# Source code for sparank.modules.model
"""Unified SpotRankTransformer — handles 1 to N modalities.
Architecture overview
---------------------
- Each modality owns a ``_Tower`` (token embedding, positional encoding,
CLS token, Transformer encoder layers, layer norm).
- All tower CLS vectors are fused via a ``FusionLayer`` (identity when
M = 1).
- An optional context embedding is concatenated to the fused vector
before the MLP classification head and projection head.
- Each tower has its own MRP (masked reconstruction pre-training) head
so vocabularies remain independent.
When ``len(tower_specs) == 1`` and ``num_contexts > 0``, this is exactly
the "unimodal + context-conditioning" model. When
``len(tower_specs) >= 2``, it is the multi-tower model. No code path
differs.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Any
import torch
import torch.nn as nn
from sparank.modules.layers import FusionLayer
# -----------------------------------------------------------------------
# Single Transformer tower
# -----------------------------------------------------------------------
class _Tower(nn.Module):
    """Transformer encoder tower for a single modality.

    Token ids are embedded, a learnable CLS vector is prepended, learned
    positional embeddings are added, and the sequence is run through a stack
    of ``nn.TransformerEncoderLayer`` blocks followed by a final layer norm.
    Token id ``0`` is treated as padding: it is the token embedding's
    ``padding_idx`` and is masked out of the attention keys.

    Parameters
    ----------
    vocab_size : int
        Number of rows in the token-embedding table.
    max_seq : int
        Longest supported token sequence. The positional table holds
        ``max_seq + 1`` rows; the extra row serves the CLS slot.
    dim : int
        Embedding / model width.
    depth : int
        Number of encoder layers.
    heads : int
        Attention heads per encoder layer.
    mlp_dim : int
        Feed-forward hidden width inside each encoder layer.
    dropout : float
        Dropout rate for the embedding dropout and the encoder layers.
    """

    def __init__(
        self,
        vocab_size: int,
        max_seq: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dropout: float
    ):
        super().__init__()
        # NOTE: attribute order below fixes both the RNG draw order and the
        # parameter registration order, so it is kept stable on purpose.
        self.token_emb = nn.Embedding(vocab_size, dim, padding_idx=0)
        self.pos_emb = nn.Embedding(max_seq + 1, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.drop = nn.Dropout(dropout)

        encoder_blocks = []
        for _ in range(depth):
            encoder_blocks.append(
                nn.TransformerEncoderLayer(
                    d_model=dim,
                    nhead=heads,
                    dim_feedforward=mlp_dim,
                    dropout=dropout,
                    batch_first=True,
                )
            )
        self.layers = nn.ModuleList(encoder_blocks)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token sequences.

        Parameters
        ----------
        x : torch.Tensor
            Token indices of shape ``(B, K)``; entries equal to ``0`` are
            treated as padding.

        Returns
        -------
        torch.Tensor
            Encoded sequence of shape ``(B, K+1, dim)``; index 0 along the
            sequence axis is the CLS position.
        """
        batch_size, seq_len = x.shape

        # CLS vector goes in front of the embedded tokens.
        tokens = self.token_emb(x)
        cls_vec = self.cls_token.expand(batch_size, -1, -1)
        hidden = torch.cat([cls_vec, tokens], dim=1)

        # One position id per slot, CLS included; broadcast over the batch.
        positions = torch.arange(seq_len + 1, device=x.device)[None, :]
        hidden = self.drop(hidden + self.pos_emb(positions))

        # True marks keys to ignore. The CLS column is always False so every
        # row keeps at least one attendable key, even when all tokens are PAD.
        cls_visible = torch.zeros(
            (batch_size, 1), dtype=torch.bool, device=x.device
        )
        key_padding = torch.cat([cls_visible, x.eq(0)], dim=1)

        for block in self.layers:
            hidden = block(hidden, src_key_padding_mask=key_padding)
        return self.norm(hidden)
# -----------------------------------------------------------------------
# Unified model
# -----------------------------------------------------------------------
class SpotRankTransformer(nn.Module):
    """Unified 1-to-N modality Transformer for spatial deconvolution.

    Each modality is encoded by its own ``_Tower``; the towers' CLS vectors
    are combined by a ``FusionLayer``; an optional context embedding is
    concatenated to the fused vector before the classification and
    projection heads.

    Parameters
    ----------
    tower_specs : List[Dict]
        Each dict describes one modality::

            {
                "name": str,        # e.g. "rna", "adt", "atac"
                "vocab_size": int,  # including PAD, UNK, MASK
                "top_k": int,       # token sequence length
                "mask_id": int,     # MASK token id for MRP
                "depth": int,       # Transformer depth (optional)
                "heads": int,       # attention heads (optional)
            }

        The input token matrix is the concatenation of the modalities'
        ``top_k``-wide segments, in the order given here. ``mask_id`` is
        not read by this class.
    num_cell_types : int
        Deconvolution output dimension (number of classes).
    num_contexts : int, default 0
        Number of contexts. ``0`` disables context embedding.
    dim : int, default 128
        Shared token-embedding dimension.
    depth : int, default 2
        Default Transformer depth (can be overridden per-tower).
    heads : int, default 4
        Default attention heads (can be overridden per-tower).
    mlp_dim : int, default 256
        Feed-forward hidden dimension.
    dropout : float, default 0.1
        Dropout rate.
    fusion_type : str, default "gate"
        Fusion mechanism: ``"concat"`` | ``"gate"`` | ``"sigmoid_gate"`` | ``"attention"``.
        Ignored when operating unimodally (M = 1).
    context_dim : int, optional, default 16
        Context-embedding dimension. Pass ``None`` to derive it as
        ``dim // 8``.
    proj_dim : int, default 64
        Contrastive projection output dimension.
    use_cl : bool, default True
        Whether to build the projection head for Contrastive Learning.
    use_mrp : bool, default True
        Whether to build the reconstruction heads for Masked Region Prediction.
    """

    def __init__(
        self,
        tower_specs: List[Dict],
        num_cell_types: int,
        num_contexts: int = 0,
        dim: int = 128,
        depth: int = 2,
        heads: int = 4,
        mlp_dim: int = 256,
        dropout: float = 0.1,
        fusion_type: str = "gate",
        context_dim: Optional[int] = 16,
        proj_dim: int = 64,
        use_cl: bool = True,
        use_mrp: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.specs = tower_specs
        self.n_modalities = len(tower_specs)
        self.use_context = num_contexts > 0
        self._context_dim = context_dim if context_dim is not None else dim // 8

        # ── Towers + Segments ─────────────────────────────────────────────
        # _seg_bounds[i] is the half-open column range of modality i inside
        # the concatenated token matrix.
        self.towers = nn.ModuleDict()
        self._seg_bounds: List[Tuple[int, int]] = []
        offset = 0
        for spec in tower_specs:
            name = spec["name"]
            k = spec["top_k"]
            self.towers[name] = _Tower(
                vocab_size=spec["vocab_size"],
                max_seq=k,
                dim=dim,
                depth=spec.get("depth", depth),
                heads=spec.get("heads", heads),
                mlp_dim=mlp_dim,
                dropout=dropout,
            )
            self._seg_bounds.append((offset, offset + k))
            offset += k

        # ── Fusion ────────────────────────────────────────────────────────
        self.fusion = FusionLayer(
            dim=dim, n_modalities=self.n_modalities, fusion_type=fusion_type,
        )
        fused_dim = self.fusion.out_dim

        # ── Optional context embedding ────────────────────────────────────
        if self.use_context:
            self.context_emb = nn.Embedding(num_contexts, self._context_dim)
            head_in = fused_dim + self._context_dim
        else:
            head_in = fused_dim

        # ── Downstream heads ──────────────────────────────────────────────
        self.mlp_head = nn.Sequential(
            nn.Linear(head_in, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_cell_types),
        )
        if use_cl:
            self.proj_head = nn.Sequential(
                nn.Linear(head_in, head_in), nn.ReLU(),
                nn.Linear(head_in, proj_dim),
            )
        else:
            self.proj_head = None
        if use_mrp:
            # One reconstruction head per tower: vocabularies are independent.
            self.mrp_heads = nn.ModuleDict({
                spec["name"]: nn.Linear(dim, spec["vocab_size"])
                for spec in tower_specs
            })
        else:
            self.mrp_heads = None

        self._init_weights()

    def _init_weights(self):
        """Initialise network parameters using Xavier Uniform.

        Every parameter with more than one dimension is re-initialised.
        That sweep also overwrites the zero row ``nn.Embedding`` keeps at
        ``padding_idx``, so those rows are reset to zero afterwards.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        with torch.no_grad():
            for module in self.modules():
                if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                    module.weight[module.padding_idx].zero_()

    # ── Helpers ───────────────────────────────────────────────────────────
    def _split(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Split concatenated tokens into per-modality segments.

        Columns ``[s, e)`` of ``x`` are returned under each modality's name,
        using the bounds recorded at construction time.
        """
        return {
            spec["name"]: x[:, s:e]
            for spec, (s, e) in zip(self.specs, self._seg_bounds)
        }

    def _require_context(self, context_ids: Optional[torch.Tensor]) -> None:
        """Fail early when a context-conditioned model gets no context IDs.

        Without this guard the context concat is skipped and the heads fail
        with an opaque shape-mismatch error, because they were sized for
        ``fused_dim + context_dim`` inputs.
        """
        if self.use_context and context_ids is None:
            raise RuntimeError(
                "context_ids is required because the model was built with "
                "num_contexts > 0."
            )

    def _tower_cls(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run every tower on its segment and return the CLS vectors in spec order."""
        segments = self._split(x)
        return [
            self.towers[spec["name"]](segments[spec["name"]])[:, 0, :]
            for spec in self.specs
        ]

    def _append_context(
        self,
        fused: torch.Tensor,
        context_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Concatenate the context embedding to ``fused`` when both are available."""
        if self.use_context and context_ids is not None:
            return torch.cat([fused, self.context_emb(context_ids)], dim=-1)
        return fused

    def _get_embedding(
        self,
        x: torch.Tensor,
        context_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Process all towers, apply fusion, and optionally concat context.

        The context embedding is appended only when the model has one and
        ``context_ids`` is given; otherwise the bare fused vector is returned.
        """
        fused = self.fusion(self._tower_cls(x))
        return self._append_context(fused, context_ids)

    # ── Forward variants ──────────────────────────────────────────────────
    def forward(
        self,
        x: torch.Tensor,
        context_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standard forward pass for spatial deconvolution.

        Parameters
        ----------
        x : torch.Tensor
            Input token indices of shape ``(B, total_seq_len)``.
        context_ids : torch.Tensor, optional
            Context IDs of shape ``(B,)``. Required when the model was built
            with ``num_contexts > 0``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            - **logits**: Predicted class logits of shape ``(B, num_cell_types)``.
            - **embedding**: The fused latent representation fed into the MLP head,
              shape ``(B, head_in)``.

        Raises
        ------
        RuntimeError
            If the model uses context embeddings and ``context_ids`` is ``None``.
        """
        self._require_context(context_ids)
        emb = self._get_embedding(x, context_ids)
        return self.mlp_head(emb), emb

    def forward_cl(
        self,
        x_a: torch.Tensor,
        x_b: torch.Tensor,
        context_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Contrastive forward pass on paired augmented views.

        Parameters
        ----------
        x_a : torch.Tensor
            Tokens for the first augmented view, shape ``(B, total_seq_len)``.
        x_b : torch.Tensor
            Tokens for the second augmented view, shape ``(B, total_seq_len)``.
        context_ids : torch.Tensor, optional
            Context IDs of shape ``(B,)``, shared by both views. Required
            when the model was built with ``num_contexts > 0``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            - **proj_a**: Projected embeddings for view A, shape ``(B, proj_dim)``.
            - **proj_b**: Projected embeddings for view B, shape ``(B, proj_dim)``.
            - **logits_a**: Predicted logits for view A, shape ``(B, num_cell_types)``.

        Raises
        ------
        RuntimeError
            If called when the model was initialized with ``use_cl=False``,
            or if the model uses context embeddings and ``context_ids`` is
            ``None``.
        """
        if self.proj_head is None:
            raise RuntimeError("Contrastive learning is disabled (use_cl=False). Cannot call forward_cl.")
        self._require_context(context_ids)
        ea = self._get_embedding(x_a, context_ids)
        eb = self._get_embedding(x_b, context_ids)
        return self.proj_head(ea), self.proj_head(eb), self.mlp_head(ea)

    def forward_mrp(
        self,
        masked_x: torch.Tensor,
        mask_pos: torch.Tensor,
    ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
        """Per-tower masked reconstruction predictions (MRP).

        Parameters
        ----------
        masked_x : torch.Tensor
            Token indices with corrupted/masked tokens, shape ``(B, total_seq_len)``.
        mask_pos : torch.Tensor
            Boolean mask indicating which tokens were altered, shape ``(B, total_seq_len)``.

        Returns
        -------
        Dict[str, Tuple[torch.Tensor, torch.Tensor]]
            Dictionary mapping modality names to tuples of:

            - **pred_logits**: Reconstructed vocabulary logits over masked
              tokens, one row per masked position in that segment.
            - **seg_mask_bool**: The localized boolean mask for that modality segment.

        Raises
        ------
        RuntimeError
            If called when the model was initialized with ``use_mrp=False``.
        """
        if self.mrp_heads is None:
            raise RuntimeError("Masked Region Prediction is disabled (use_mrp=False). Cannot call forward_mrp.")
        segments = self._split(masked_x)
        results: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
        for spec, (s, e) in zip(self.specs, self._seg_bounds):
            name = spec["name"]
            seg_mask = mask_pos[:, s:e]
            tok = self.towers[name](segments[name])[:, 1:, :]  # skip CLS
            # Boolean indexing keeps only the masked positions.
            pred = self.mrp_heads[name](tok[seg_mask])
            results[name] = (pred, seg_mask)
        return results

    def forward_with_gates(
        self,
        x: torch.Tensor,
        context_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Forward pass that extracts fusion gate/attention scores for interpretation.

        Parameters
        ----------
        x : torch.Tensor
            Input token indices of shape ``(B, total_seq_len)``.
        context_ids : torch.Tensor, optional
            Context IDs of shape ``(B,)``. Required when the model was built
            with ``num_contexts > 0``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]
            - **logits**: Predicted class logits, shape ``(B, num_cell_types)``.
            - **emb**: The latent vector fed into the MLP head, shape ``(B, head_in)``.
            - **gates**: The dictionary returned by the fusion layer, with a
              ``modalities`` entry added here (modality names in structural
              order). The fusion layer is expected to supply ``type``
              ("gate", "sigmoid_gate", "attention", "concat", or "identity")
              and ``weights`` (shape depends on type).

        Raises
        ------
        RuntimeError
            If the model uses context embeddings and ``context_ids`` is ``None``.
        """
        self._require_context(context_ids)
        fused, gates = self.fusion(self._tower_cls(x), return_gates=True)
        gates["modalities"] = [spec["name"] for spec in self.specs]
        fused = self._append_context(fused, context_ids)
        logits = self.mlp_head(fused)
        return logits, fused, gates