sparank.modules.SpotRankTransformer
- class sparank.modules.SpotRankTransformer(tower_specs, num_cell_types, num_contexts=0, dim=128, depth=2, heads=4, mlp_dim=256, dropout=0.1, fusion_type='gate', context_dim=16, proj_dim=64, use_cl=True, use_mrp=True)[source]
Bases: Module
Unified 1-to-N modality Transformer for spatial deconvolution.
- Parameters:
tower_specs (List[Dict]) – One dict per modality, with the following keys:
{
    "name": str,        # e.g. "rna", "adt", "atac"
    "vocab_size": int,  # including PAD, UNK, MASK
    "top_k": int,       # token sequence length
    "mask_id": int,     # MASK token id for MRP
    "depth": int,       # Transformer depth (optional)
    "heads": int,       # attention heads (optional)
}
num_cell_types (int) – Deconvolution output dimension (number of classes).
num_contexts (int, default 0) – Number of contexts. 0 disables context embedding.
dim (int, default 128) – Shared token-embedding dimension.
depth (int, default 2) – Default Transformer depth (can be overridden per-tower).
heads (int, default 4) – Default attention heads (can be overridden per-tower).
mlp_dim (int, default 256) – Feed-forward hidden dimension.
dropout (float, default 0.1) – Dropout rate.
fusion_type (str, default "gate") – Fusion mechanism: "concat" | "gate" | "sigmoid_gate" | "attention". Ignored when operating unimodally (M = 1).
context_dim (int, optional) – Context-embedding dimension (defaults to dim // 8).
proj_dim (int, default 64) – Contrastive projection output dimension.
use_cl (bool, default True) – Whether to build the projection head for Contrastive Learning.
use_mrp (bool, default True) – Whether to build the reconstruction heads for Masked Region Prediction.
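As a minimal sketch of the tower_specs layout described above, the following builds specs for two modalities in plain Python and sanity-checks them. The modality sizes, the mask_id values, and the helper check_tower_specs are illustrative placeholders, not part of sparank. Treating total_seq_len as the sum of top_k over all towers is an assumption that this page does not confirm.

```python
# Hypothetical two-modality configuration; every number here is a placeholder.
tower_specs = [
    {"name": "rna", "vocab_size": 2003, "top_k": 256, "mask_id": 2},
    # "depth" and "heads" are optional per-tower overrides.
    {"name": "adt", "vocab_size": 103, "top_k": 32, "mask_id": 2, "depth": 1, "heads": 2},
]

REQUIRED_KEYS = {"name", "vocab_size", "top_k", "mask_id"}


def check_tower_specs(specs):
    """Validate the documented keys and return the assumed total sequence length."""
    seen = set()
    for spec in specs:
        missing = REQUIRED_KEYS - spec.keys()
        if missing:
            raise ValueError(f"tower spec is missing keys: {sorted(missing)}")
        if spec["name"] in seen:
            raise ValueError(f"duplicate modality name: {spec['name']}")
        if not 0 <= spec["mask_id"] < spec["vocab_size"]:
            raise ValueError(f"mask_id out of vocabulary range for {spec['name']}")
        seen.add(spec["name"])
    # Assumption: towers are concatenated along the sequence axis.
    return sum(spec["top_k"] for spec in specs)


total_seq_len = check_tower_specs(tower_specs)
```

A list like this would then be passed as the first argument of the constructor, together with num_cell_types.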
- __init__(tower_specs, num_cell_types, num_contexts=0, dim=128, depth=2, heads=4, mlp_dim=256, dropout=0.1, fusion_type='gate', context_dim=16, proj_dim=64, use_cl=True, use_mrp=True)[source]
Methods
__init__(tower_specs, num_cell_types[, ...])
forward(x[, context_ids]) – Standard forward pass for spatial deconvolution.
forward_cl(x_a, x_b[, context_ids]) – Contrastive forward pass on paired augmented views.
forward_mrp(masked_x, mask_pos) – Per-tower masked reconstruction predictions (MRP).
forward_with_gates(x[, context_ids]) – Forward pass that extracts fusion gate/attention scores for interpretation.
- forward(x, context_ids=None)[source]
Standard forward pass for spatial deconvolution.
- Parameters:
x (torch.Tensor) – Input token indices of shape (B, total_seq_len).
context_ids (torch.Tensor, optional) – Context IDs of shape (B,).
- Returns:
logits: Predicted class logits of shape (B, num_cell_types).
embedding: The fused latent representation fed into the MLP head, shape (B, head_in).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
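The returned logits are unnormalized. One common way to read them as per-spot cell-type proportions is a row-wise softmax. The sketch below does this with plain Python lists standing in for a (B, num_cell_types) tensor. Whether sparank itself normalizes with softmax is an assumption, and the logit values are made up.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    peak = max(row)  # subtract the max so exp() cannot overflow
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


# Placeholder logits for B = 2 spots and num_cell_types = 3.
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
proportions = [softmax(row) for row in logits]
```

Each row of proportions is non-negative and sums to 1, so it can be interpreted as a mixture over cell types under this assumption.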
- forward_cl(x_a, x_b, context_ids=None)[source]
Contrastive forward pass on paired augmented views.
- Parameters:
x_a (torch.Tensor) – Tokens for the first augmented view, shape (B, total_seq_len).
x_b (torch.Tensor) – Tokens for the second augmented view, shape (B, total_seq_len).
context_ids (torch.Tensor, optional) – Context IDs of shape (B,).
- Returns:
proj_a: Projected embeddings for view A, shape (B, proj_dim).
proj_b: Projected embeddings for view B, shape (B, proj_dim).
logits_a: Predicted logits for view A, shape (B, num_cell_types).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- Raises:
RuntimeError – If called when the model was initialized with use_cl=False.
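This page does not say which contrastive objective consumes proj_a and proj_b. A frequent choice for paired views is an InfoNCE-style loss, sketched below with plain Python lists standing in for the (B, proj_dim) tensors. The function info_nce, the temperature value, and the one-directional form are all assumptions made for illustration, not the library's training loss.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]


def info_nce(proj_a, proj_b, temperature=0.1):
    """One-directional InfoNCE: row i of proj_a should match row i of proj_b."""
    a = [l2_normalize(v) for v in proj_a]
    b = [l2_normalize(v) for v in proj_b]
    loss = 0.0
    for i, va in enumerate(a):
        # Cosine similarity of anchor i against every candidate, scaled by temperature.
        sims = [sum(x * y for x, y in zip(va, vb)) / temperature for vb in b]
        peak = max(sims)
        log_denom = peak + math.log(sum(math.exp(s - peak) for s in sims))
        loss += log_denom - sims[i]  # negative log-probability of the true pair
    return loss / len(a)


aligned = [[1.0, 0.0], [0.0, 1.0]]
swapped = [[0.0, 1.0], [1.0, 0.0]]
loss_aligned = info_nce(aligned, aligned)
loss_swapped = info_nce(aligned, swapped)
```

The loss is small when matching rows point in the same direction and large when the pairing is scrambled, which is the behaviour a contrastive projection head is trained toward.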
- forward_mrp(masked_x, mask_pos)[source]
Per-tower masked reconstruction predictions (MRP).
- Parameters:
masked_x (torch.Tensor) – Token indices with corrupted/masked tokens, shape (B, total_seq_len).
mask_pos (torch.Tensor) – Boolean mask indicating which tokens were altered, shape (B, total_seq_len).
- Returns:
Dictionary mapping each modality name to a tuple of:
- pred_logits: Reconstructed vocabulary logits over masked tokens.
- seg_mask_bool: The localized boolean mask for that modality segment.
- Return type:
Dict[str, Tuple[torch.Tensor, torch.Tensor]]
- Raises:
RuntimeError – If called when the model was initialized with use_mrp=False.
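To illustrate what a per-modality "localized" mask could look like, the sketch below slices one flat mask_pos row into segments using each tower's top_k. It assumes the towers are concatenated along the sequence axis in tower_specs order, which this page does not state. The helper split_by_tower and the example values are hypothetical.

```python
def split_by_tower(row, tower_specs):
    """Slice one flat sequence into per-modality segments, in tower order."""
    segments = {}
    start = 0
    for spec in tower_specs:
        end = start + spec["top_k"]
        segments[spec["name"]] = row[start:end]
        start = end
    if start != len(row):
        raise ValueError(f"expected sequence length {start}, got {len(row)}")
    return segments


# Placeholder specs with only the keys this sketch needs.
tower_specs = [{"name": "rna", "top_k": 4}, {"name": "adt", "top_k": 2}]

# One row of a (B, total_seq_len) boolean mask; True marks an altered token.
mask_pos = [False, True, False, False, True, False]
seg_masks = split_by_tower(mask_pos, tower_specs)
```

Under this assumption, the "rna" segment covers the first four positions and the "adt" segment the last two, so each reconstruction head only sees the mask entries for its own tokens.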
- forward_with_gates(x, context_ids=None)[source]
Forward pass that extracts fusion gate/attention scores for interpretation.
- Parameters:
x (torch.Tensor) – Input token indices of shape (B, total_seq_len).
context_ids (torch.Tensor, optional) – Context IDs of shape (B,).
- Returns:
logits: Predicted class logits, shape (B, num_cell_types).
emb: The latent vector fed into the MLP head, shape (B, head_in).
gates: A dictionary detailing fusion behaviour:
    type: "gate", "sigmoid_gate", "attention", "concat", or "identity".
    weights: Tensor describing the gating weights (shape depends on type).
    modalities: List of modality names in structural order.
- Return type:
Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]
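For interpretation, the gates dictionary can be summarized per modality. The sketch below averages the weights over the batch, using nested lists in place of a tensor. It assumes that weights holds one value per spot and modality, i.e. shape (B, M), for the "gate" type. The actual shape depends on type and is not specified here, and the helper mean_gate_per_modality and the example values are hypothetical.

```python
def mean_gate_per_modality(gates):
    """Average the gate weights over the batch, keyed by modality name."""
    weights = gates["weights"]   # assumed layout: B rows, each with M weights
    names = gates["modalities"]  # modality names in structural order
    if any(len(row) != len(names) for row in weights):
        raise ValueError("each weight row must have one entry per modality")
    batch_size = len(weights)
    return {
        name: sum(row[j] for row in weights) / batch_size
        for j, name in enumerate(names)
    }


# Placeholder output for B = 2 spots and two modalities.
gates = {
    "type": "gate",
    "weights": [[0.7, 0.3], [0.5, 0.5]],
    "modalities": ["rna", "adt"],
}
summary = mean_gate_per_modality(gates)
```

A higher average for one modality suggests the fusion layer relied on it more across the batch, under the shape assumption above.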