sparank.modules.SpotRankTransformer

class sparank.modules.SpotRankTransformer(tower_specs, num_cell_types, num_contexts=0, dim=128, depth=2, heads=4, mlp_dim=256, dropout=0.1, fusion_type='gate', context_dim=16, proj_dim=64, use_cl=True, use_mrp=True)[source]

Bases: Module

Unified 1-to-N modality Transformer for spatial deconvolution.

Parameters:
  • tower_specs (List[Dict]) –

    Each dict describes one modality:

    {
        "name":       str,   # e.g. "rna", "adt", "atac"
        "vocab_size": int,   # including PAD, UNK, MASK
        "top_k":      int,   # token sequence length
        "mask_id":    int,   # MASK token id for MRP
        "depth":      int,   # Transformer depth (optional)
        "heads":      int,   # attention heads (optional)
    }
    

  • num_cell_types (int) – Deconvolution output dimension (number of classes).

  • num_contexts (int, default 0) – Number of contexts. 0 disables context embedding.

  • dim (int, default 128) – Shared token-embedding dimension.

  • depth (int, default 2) – Default Transformer depth (can be overridden per-tower).

  • heads (int, default 4) – Default attention heads (can be overridden per-tower).

  • mlp_dim (int, default 256) – Feed-forward hidden dimension.

  • dropout (float, default 0.1) – Dropout rate.

  • fusion_type (str, default "gate") – Fusion mechanism: "concat" | "gate" | "sigmoid_gate" | "attention". Ignored when operating unimodally (M = 1).

  • context_dim (int, default 16) – Context-embedding dimension. The docstring describes the default as dim // 8, which is 16 at the default dim of 128.

  • proj_dim (int, default 64) – Contrastive projection output dimension.

  • use_cl (bool, default True) – Whether to build the projection head for Contrastive Learning.

  • use_mrp (bool, default True) – Whether to build the reconstruction heads for Masked Region Prediction.
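The tower_specs layout and the per-tower depth/heads overrides described above can be sketched in plain Python. This is an illustration only, not the library's own code: the helper `resolve_tower_specs`, the example vocabulary sizes, the range check on mask_id, and the assumption that total_seq_len equals the sum of the towers' top_k values are all hypothetical.

```python
from typing import Dict, List

REQUIRED_KEYS = ("name", "vocab_size", "top_k", "mask_id")


def resolve_tower_specs(specs: List[Dict], depth: int = 2, heads: int = 4) -> List[Dict]:
    """Hypothetical helper: check required keys and fill optional depth/heads."""
    resolved = []
    for spec in specs:
        missing = [key for key in REQUIRED_KEYS if key not in spec]
        if missing:
            raise KeyError(f"tower spec {spec.get('name', '?')!r} is missing {missing}")
        # Assumption: MASK is part of the vocabulary, so its id must be in range.
        if not 0 <= spec["mask_id"] < spec["vocab_size"]:
            raise ValueError(f"mask_id out of range for tower {spec['name']!r}")
        # Per-tower values win; otherwise fall back to the model-level defaults.
        resolved.append({
            **spec,
            "depth": spec.get("depth", depth),
            "heads": spec.get("heads", heads),
        })
    return resolved


# Example values are made up for illustration.
tower_specs = [
    {"name": "rna", "vocab_size": 2003, "top_k": 256, "mask_id": 2},
    {"name": "adt", "vocab_size": 103, "top_k": 32, "mask_id": 2, "depth": 1},
]

resolved = resolve_tower_specs(tower_specs)

# Assumption: the model input concatenates one top_k-long segment per tower.
total_seq_len = sum(spec["top_k"] for spec in resolved)
```

Here the "adt" tower keeps its own depth of 1 while "rna" inherits the model-level default of 2, mirroring the "can be overridden per-tower" wording of the depth and heads parameters.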

Methods

__init__(tower_specs, num_cell_types[, ...])

forward(x[, context_ids])

Standard forward pass for spatial deconvolution.

forward_cl(x_a, x_b[, context_ids])

Contrastive forward pass on paired augmented views.

forward_mrp(masked_x, mask_pos)

Per-tower masked reconstruction predictions (MRP).

forward_with_gates(x[, context_ids])

Forward pass that extracts fusion gate/attention scores for interpretation.

forward(x, context_ids=None)[source]

Standard forward pass for spatial deconvolution.

Parameters:
  • x (torch.Tensor) – Input token indices of shape (B, total_seq_len).

  • context_ids (torch.Tensor, optional) – Context IDs of shape (B,).

Returns:

  • logits: Predicted class logits of shape (B, num_cell_types).

  • embedding: The fused latent representation fed into the MLP head, shape (B, head_in).

Return type:

Tuple[torch.Tensor, torch.Tensor]
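For deconvolution, the returned logits are typically turned into per-spot cell-type proportions. A minimal sketch follows, assuming a softmax over the class dimension is the intended normalization (the source does not state this). Plain Python lists stand in for the (B, num_cell_types) tensor so the example runs without PyTorch.

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


# Hand-written stand-in for the `logits` output: B = 2 spots, num_cell_types = 3.
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]

# Assumption: each row of proportions is read as a cell-type mixture for one spot.
proportions = [softmax(row) for row in logits]
```

Each row of proportions is non-negative and sums to 1, so it can be read as a mixture over the num_cell_types classes.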

forward_cl(x_a, x_b, context_ids=None)[source]

Contrastive forward pass on paired augmented views.

Parameters:
  • x_a (torch.Tensor) – Tokens for the first augmented view, shape (B, total_seq_len).

  • x_b (torch.Tensor) – Tokens for the second augmented view, shape (B, total_seq_len).

  • context_ids (torch.Tensor, optional) – Context IDs of shape (B,).

Returns:

  • proj_a: Projected embeddings for view A, shape (B, proj_dim).

  • proj_b: Projected embeddings for view B, shape (B, proj_dim).

  • logits_a: Predicted logits for view A, shape (B, num_cell_types).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Raises:

RuntimeError – If called when the model was initialized with use_cl=False.
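The source does not say which loss consumes proj_a and proj_b. One common choice for paired augmented views is an InfoNCE-style objective, sketched below in plain Python. The function `info_nce` and the temperature value are assumptions for illustration, not part of the documented API.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    """Scale a vector to unit length (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def info_nce(proj_a: List[Vector], proj_b: List[Vector], temperature: float = 0.1) -> float:
    """One-directional InfoNCE: row i of proj_a should match row i of proj_b."""
    a = [l2_normalize(v) for v in proj_a]
    b = [l2_normalize(v) for v in proj_b]
    loss = 0.0
    for i, anchor in enumerate(a):
        # Cosine similarity of this anchor to every view-B embedding, scaled by temperature.
        sims = [sum(x * y for x, y in zip(anchor, other)) / temperature for other in b]
        # Stable log-sum-exp for the softmax denominator.
        peak = max(sims)
        log_denominator = peak + math.log(sum(math.exp(s - peak) for s in sims))
        loss += log_denominator - sims[i]
    return loss / len(a)


# Hand-written stand-ins for proj_a / proj_b with B = 2, proj_dim = 2.
aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
shuffled = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

The loss is small when each row of view A is closest to its own counterpart in view B, and large when the pairing is broken.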

forward_mrp(masked_x, mask_pos)[source]

Per-tower masked reconstruction predictions (MRP).

Parameters:
  • masked_x (torch.Tensor) – Token indices with corrupted/masked tokens, shape (B, total_seq_len).

  • mask_pos (torch.Tensor) – Boolean mask indicating which tokens were altered, shape (B, total_seq_len).

Returns:

Dictionary mapping each modality name to a tuple of:

  • pred_logits: Reconstructed vocabulary logits over masked tokens.

  • seg_mask_bool: The localized boolean mask for that modality segment.

Return type:

Dict[str, Tuple[torch.Tensor, torch.Tensor]]

Raises:

RuntimeError – If called when the model was initialized with use_mrp=False.
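How masked_x and mask_pos relate to the per-modality segments can be sketched as follows. The helpers `mask_tokens` and `split_by_tower` are hypothetical, and the sketch assumes each tower occupies a contiguous top_k-long slice of the sequence in tower_specs order, with a fixed masking ratio per tower. Plain lists stand in for one row of the (B, total_seq_len) tensors.

```python
import random
from typing import Dict, List, Tuple


def mask_tokens(row: List[int], specs: List[Dict], ratio: float,
                rng: random.Random) -> Tuple[List[int], List[bool]]:
    """Hypothetical helper: replace a fraction of each segment with that tower's mask_id."""
    masked_row = list(row)
    mask_pos = [False] * len(row)
    start = 0
    for spec in specs:
        stop = start + spec["top_k"]
        n_mask = max(1, int(round(ratio * spec["top_k"])))
        # Sample distinct positions inside this tower's segment only.
        for pos in rng.sample(range(start, stop), n_mask):
            masked_row[pos] = spec["mask_id"]
            mask_pos[pos] = True
        start = stop
    return masked_row, mask_pos


def split_by_tower(values: List, specs: List[Dict]) -> Dict[str, List]:
    """Slice a flat per-position list into per-modality segments."""
    segments, start = {}, 0
    for spec in specs:
        segments[spec["name"]] = values[start:start + spec["top_k"]]
        start += spec["top_k"]
    return segments


# Example specs and tokens are made up for illustration.
specs = [
    {"name": "rna", "vocab_size": 50, "top_k": 6, "mask_id": 2},
    {"name": "adt", "vocab_size": 20, "top_k": 4, "mask_id": 2},
]
row = [10, 11, 12, 13, 14, 15, 5, 6, 7, 8]  # rna tokens followed by adt tokens
masked_row, mask_pos = mask_tokens(row, specs, ratio=0.5, rng=random.Random(0))
seg_masks = split_by_tower(mask_pos, specs)
```

The per-tower slices in seg_masks play the role that seg_mask_bool plays in the return value: each one marks the altered positions within a single modality's segment.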

forward_with_gates(x, context_ids=None)[source]

Forward pass that extracts fusion gate/attention scores for interpretation.

Parameters:
  • x (torch.Tensor) – Input token indices of shape (B, total_seq_len).

  • context_ids (torch.Tensor, optional) – Context IDs of shape (B,).

Returns:

  • logits: Predicted class logits, shape (B, num_cell_types).

  • emb: The latent vector fed into the MLP head, shape (B, head_in).

  • gates: A dictionary detailing fusion behaviour:
    • type: "gate", "sigmoid_gate", "attention", "concat", or "identity".

    • weights: Tensor describing the gating weights (shape depends on type).

    • modalities: List of modality names in structural order.

Return type:

Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]
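One way to use the gates dictionary for interpretation is to average the fusion weights per modality across a batch. The sketch below assumes weights has one value per sample and modality, i.e. shape (B, M). The source only says the shape depends on type, so this layout is an assumption, and `mean_modality_weights` is a hypothetical helper. Nested lists stand in for the tensor.

```python
from typing import Dict, List


def mean_modality_weights(gates: Dict) -> Dict[str, float]:
    """Average per-sample fusion weights per modality; assumes weights is (B, M)."""
    weights: List[List[float]] = gates["weights"]
    names: List[str] = gates["modalities"]
    if any(len(row) != len(names) for row in weights):
        raise ValueError("expected one weight per modality in every row")
    n_samples = len(weights)
    return {
        name: sum(row[j] for row in weights) / n_samples
        for j, name in enumerate(names)
    }


# Hand-written stand-in for the `gates` dictionary (B = 2 samples, M = 2 modalities).
gates = {
    "type": "gate",
    "weights": [[0.7, 0.3], [0.5, 0.5]],
    "modalities": ["rna", "adt"],
}
summary = mean_modality_weights(gates)
```

With a real model output, the weights tensor could first be converted with `Tensor.tolist()` before being passed to a helper like this one.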