Source code for sparank.modules.model

"""Unified SpotRankTransformer — handles 1 to N modalities.

Architecture overview
---------------------
- Each modality owns a ``_Tower`` (token embedding, positional encoding,
  CLS token, Transformer encoder layers, layer norm).
- All tower CLS vectors are fused via a ``FusionLayer`` (identity when
  M = 1).
- An optional context embedding is concatenated to the fused vector
  before the MLP classification head and projection head.
- Each tower has its own MRP (masked reconstruction pre-training) head
  so vocabularies remain independent.

When ``len(tower_specs) == 1`` and ``num_contexts > 0``, this is exactly
the "unimodal + context-conditioning" model. When
``len(tower_specs) >= 2``, it is the multi-tower model. No code path
differs.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Any

import torch
import torch.nn as nn

from sparank.modules.layers import FusionLayer


# -----------------------------------------------------------------------
#  Single Transformer tower
# -----------------------------------------------------------------------

class _Tower(nn.Module):
    """Transformer encoder for the token sequence of a single modality.

    The tower embeds token ids, prepends a learned CLS vector, adds learned
    positional embeddings, runs ``depth`` encoder layers that ignore padded
    positions as attention keys, and finishes with a layer norm.

    Parameters
    ----------
    vocab_size : int
        Number of rows in the token-embedding table. Id ``0`` is the
        padding id.
    max_seq : int
        Longest token sequence the tower accepts (CLS excluded).
    dim : int
        Embedding / model width.
    depth : int
        Number of stacked encoder layers.
    heads : int
        Attention heads per encoder layer.
    mlp_dim : int
        Hidden width of each layer's feed-forward sub-block.
    dropout : float
        Dropout rate used on the embedded input and inside every layer.
    """

    def __init__(
        self,
        vocab_size: int,
        max_seq: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dropout: float
    ):
        super().__init__()
        # Token id 0 is reserved for padding.
        self.token_emb = nn.Embedding(vocab_size, dim, padding_idx=0)
        # One extra position because CLS occupies slot 0.
        self.pos_emb = nn.Embedding(max_seq + 1, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.drop = nn.Dropout(dropout)

        encoder_blocks = []
        for _ in range(depth):
            encoder_blocks.append(
                nn.TransformerEncoderLayer(
                    d_model=dim,
                    nhead=heads,
                    dim_feedforward=mlp_dim,
                    dropout=dropout,
                    batch_first=True,
                )
            )
        self.layers = nn.ModuleList(encoder_blocks)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token sequences.

        Parameters
        ----------
        x : torch.Tensor
            Token indices of shape ``(B, K)``; entries equal to ``0`` are
            treated as padding.

        Returns
        -------
        torch.Tensor
            Normalised hidden states of shape ``(B, K+1, dim)``. Index 0
            along the sequence axis is the CLS position.
        """
        batch, length = x.shape

        # True marks keys attention must skip; the CLS slot is always visible.
        cls_visible = x.new_zeros((batch, 1), dtype=torch.bool)
        key_padding = torch.cat((cls_visible, x.eq(0)), dim=1)

        tokens = torch.cat(
            (self.cls_token.expand(batch, -1, -1), self.token_emb(x)),
            dim=1,
        )
        positions = torch.arange(length + 1, device=x.device)[None, :]
        hidden = self.drop(tokens + self.pos_emb(positions))

        for block in self.layers:
            hidden = block(hidden, src_key_padding_mask=key_padding)
        return self.norm(hidden)


# -----------------------------------------------------------------------
#  Unified model
# -----------------------------------------------------------------------

class SpotRankTransformer(nn.Module):
    """Unified 1-to-N modality Transformer for spatial deconvolution.

    The input is one token matrix whose columns are the per-modality
    segments laid side by side in ``tower_specs`` order. Each segment is
    encoded by its own tower, the towers' CLS vectors are fused, an optional
    context embedding is appended, and the result feeds the heads.

    Parameters
    ----------
    tower_specs : List[Dict]
        Each dict describes one modality::

            {
                "name": str,        # e.g. "rna", "adt", "atac"
                "vocab_size": int,  # including PAD, UNK, MASK
                "top_k": int,       # token sequence length
                "mask_id": int,     # MASK token id for MRP
                "depth": int,       # Transformer depth (optional)
                "heads": int,       # attention heads (optional)
            }

        Only ``name``, ``vocab_size``, ``top_k``, ``depth`` and ``heads``
        are read by this class.
    num_cell_types : int
        Deconvolution output dimension (number of classes).
    num_contexts : int, default 0
        Number of contexts. ``0`` disables context embedding.
    dim : int, default 128
        Shared token-embedding dimension.
    depth : int, default 2
        Default Transformer depth (can be overridden per-tower).
    heads : int, default 4
        Default attention heads (can be overridden per-tower).
    mlp_dim : int, default 256
        Feed-forward hidden dimension.
    dropout : float, default 0.1
        Dropout rate.
    fusion_type : str, default "gate"
        Fusion mechanism: ``"concat"`` | ``"gate"`` | ``"sigmoid_gate"`` |
        ``"attention"``. Ignored when operating unimodally (M = 1).
    context_dim : int, optional, default 16
        Context-embedding dimension. Pass ``None`` to use ``dim // 8``.
    proj_dim : int, default 64
        Contrastive projection output dimension.
    use_cl : bool, default True
        Whether to build the projection head for Contrastive Learning.
    use_mrp : bool, default True
        Whether to build the reconstruction heads for Masked Region
        Prediction.
    """

    def __init__(
        self,
        tower_specs: List[Dict],
        num_cell_types: int,
        num_contexts: int = 0,
        dim: int = 128,
        depth: int = 2,
        heads: int = 4,
        mlp_dim: int = 256,
        dropout: float = 0.1,
        fusion_type: str = "gate",
        context_dim: Optional[int] = 16,
        proj_dim: int = 64,
        use_cl: bool = True,
        use_mrp: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.specs = tower_specs
        self.n_modalities = len(tower_specs)
        self.use_context = num_contexts > 0
        self._context_dim = context_dim if context_dim is not None else dim // 8

        # ── Towers + Segments ─────────────────────────────────────────────
        # _seg_bounds[i] is the [start, end) column range of modality i
        # inside the concatenated input.
        self.towers = nn.ModuleDict()
        self._seg_bounds: List[Tuple[int, int]] = []
        offset = 0
        for spec in tower_specs:
            name = spec["name"]
            k = spec["top_k"]
            self.towers[name] = _Tower(
                vocab_size=spec["vocab_size"],
                max_seq=k,
                dim=dim,
                depth=spec.get("depth", depth),
                heads=spec.get("heads", heads),
                mlp_dim=mlp_dim,
                dropout=dropout,
            )
            self._seg_bounds.append((offset, offset + k))
            offset += k

        # ── Fusion ────────────────────────────────────────────────────────
        self.fusion = FusionLayer(
            dim=dim,
            n_modalities=self.n_modalities,
            fusion_type=fusion_type,
        )
        fused_dim = self.fusion.out_dim

        # ── Optional context embedding ────────────────────────────────────
        if self.use_context:
            self.context_emb = nn.Embedding(num_contexts, self._context_dim)
            head_in = fused_dim + self._context_dim
        else:
            head_in = fused_dim

        # ── Downstream heads ──────────────────────────────────────────────
        self.mlp_head = nn.Sequential(
            nn.Linear(head_in, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_cell_types),
        )

        if use_cl:
            self.proj_head = nn.Sequential(
                nn.Linear(head_in, head_in),
                nn.ReLU(),
                nn.Linear(head_in, proj_dim),
            )
        else:
            self.proj_head = None

        if use_mrp:
            # One head per tower so each modality keeps its own vocabulary.
            self.mrp_heads = nn.ModuleDict({
                spec["name"]: nn.Linear(dim, spec["vocab_size"])
                for spec in tower_specs
            })
        else:
            self.mrp_heads = None

        self._init_weights()

    def _init_weights(self):
        """Initialise every parameter with more than one dim via Xavier Uniform."""
        # NOTE(review): this loop also covers embedding tables, so a row that
        # was zero-initialised through ``padding_idx`` receives non-zero
        # values here — confirm that is intended.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    # ── Helpers ───────────────────────────────────────────────────────────

    def _split(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Split concatenated tokens into per-modality segments."""
        return {
            spec["name"]: x[:, s:e]
            for spec, (s, e) in zip(self.specs, self._seg_bounds)
        }

    def _encode_cls(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run every tower on its segment and return the CLS vectors in spec order."""
        segments = self._split(x)
        return [
            self.towers[spec["name"]](segments[spec["name"]])[:, 0, :]
            for spec in self.specs
        ]

    def _append_context(
        self,
        fused: torch.Tensor,
        context_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Concatenate the context embedding to ``fused`` when both are available."""
        if self.use_context and context_ids is not None:
            return torch.cat([fused, self.context_emb(context_ids)], dim=-1)
        return fused

    def _require_context(self, context_ids: Optional[torch.Tensor]) -> None:
        """Raise a clear error when a context-conditioned model gets no ``context_ids``.

        The heads are sized for the fused vector plus the context embedding,
        so without this check the call would fail later inside ``nn.Linear``
        with an opaque shape-mismatch error.
        """
        if self.use_context and context_ids is None:
            raise ValueError(
                "context_ids is required: this model was built with "
                "num_contexts > 0 and its heads expect the context embedding."
            )

    def _get_embedding(
        self,
        x: torch.Tensor,
        context_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Process all towers, apply fusion, and optionally concat context.

        When ``context_ids`` is ``None`` the fused vector is returned without
        a context part, even on a context-conditioned model.
        """
        fused = self.fusion(self._encode_cls(x))
        return self._append_context(fused, context_ids)

    # ── Forward variants ──────────────────────────────────────────────────

    def forward(
        self,
        x: torch.Tensor,
        context_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standard forward pass for spatial deconvolution.

        Parameters
        ----------
        x : torch.Tensor
            Input token indices of shape ``(B, total_seq_len)``.
        context_ids : torch.Tensor, optional
            Context IDs of shape ``(B,)``. Required when the model was built
            with ``num_contexts > 0``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            - **logits**: Predicted class logits of shape
              ``(B, num_cell_types)``.
            - **embedding**: The fused latent representation fed into the
              MLP head, shape ``(B, head_in)``.

        Raises
        ------
        ValueError
            If the model uses context embeddings and ``context_ids`` is
            ``None``.
        """
        self._require_context(context_ids)
        emb = self._get_embedding(x, context_ids)
        return self.mlp_head(emb), emb

    def forward_cl(
        self,
        x_a: torch.Tensor,
        x_b: torch.Tensor,
        context_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Contrastive forward pass on paired augmented views.

        Parameters
        ----------
        x_a : torch.Tensor
            Tokens for the first augmented view, shape
            ``(B, total_seq_len)``.
        x_b : torch.Tensor
            Tokens for the second augmented view, shape
            ``(B, total_seq_len)``.
        context_ids : torch.Tensor, optional
            Context IDs of shape ``(B,)``, shared by both views. Required
            when the model was built with ``num_contexts > 0``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            - **proj_a**: Projected embeddings for view A, shape
              ``(B, proj_dim)``.
            - **proj_b**: Projected embeddings for view B, shape
              ``(B, proj_dim)``.
            - **logits_a**: Predicted logits for view A, shape
              ``(B, num_cell_types)``.

        Raises
        ------
        RuntimeError
            If called when the model was initialized with ``use_cl=False``.
        ValueError
            If the model uses context embeddings and ``context_ids`` is
            ``None``.
        """
        if self.proj_head is None:
            raise RuntimeError("Contrastive learning is disabled (use_cl=False). Cannot call forward_cl.")
        self._require_context(context_ids)

        ea = self._get_embedding(x_a, context_ids)
        eb = self._get_embedding(x_b, context_ids)
        return self.proj_head(ea), self.proj_head(eb), self.mlp_head(ea)

    def forward_mrp(
        self,
        masked_x: torch.Tensor,
        mask_pos: torch.Tensor,
    ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
        """Per-tower masked reconstruction predictions (MRP).

        Parameters
        ----------
        masked_x : torch.Tensor
            Token indices with corrupted/masked tokens, shape
            ``(B, total_seq_len)``.
        mask_pos : torch.Tensor
            Boolean mask indicating which tokens were altered, shape
            ``(B, total_seq_len)``.

        Returns
        -------
        Dict[str, Tuple[torch.Tensor, torch.Tensor]]
            Dictionary mapping modality names to tuples of:

            - **pred_logits**: Reconstructed vocabulary logits over masked
              tokens.
            - **seg_mask_bool**: The localized boolean mask for that
              modality segment.

        Raises
        ------
        RuntimeError
            If called when the model was initialized with ``use_mrp=False``.
        """
        if self.mrp_heads is None:
            raise RuntimeError("Masked Region Prediction is disabled (use_mrp=False). Cannot call forward_mrp.")

        segments = self._split(masked_x)
        results: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
        for spec, (s, e) in zip(self.specs, self._seg_bounds):
            name = spec["name"]
            seg_mask = mask_pos[:, s:e]
            tok = self.towers[name](segments[name])[:, 1:, :]  # skip CLS
            # Only the hidden states at masked positions go through the head.
            pred = self.mrp_heads[name](tok[seg_mask])
            results[name] = (pred, seg_mask)
        return results

    def forward_with_gates(
        self,
        x: torch.Tensor,
        context_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Forward pass that extracts fusion gate/attention scores for interpretation.

        Parameters
        ----------
        x : torch.Tensor
            Input token indices of shape ``(B, total_seq_len)``.
        context_ids : torch.Tensor, optional
            Context IDs of shape ``(B,)``. Required when the model was built
            with ``num_contexts > 0``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]
            - **logits**: Predicted class logits, shape
              ``(B, num_cell_types)``.
            - **emb**: The latent vector fed into the MLP head, shape
              ``(B, head_in)``.
            - **gates**: A dictionary detailing fusion behaviour:

              - ``type``: "gate", "sigmoid_gate", "attention", "concat", or
                "identity".
              - ``weights``: Tensor describing the gating weights (shape
                depends on type).
              - ``modalities``: List of modality names in structural order.

        Raises
        ------
        ValueError
            If the model uses context embeddings and ``context_ids`` is
            ``None``.
        """
        self._require_context(context_ids)

        fused, gates = self.fusion(self._encode_cls(x), return_gates=True)
        gates["modalities"] = [spec["name"] for spec in self.specs]

        fused = self._append_context(fused, context_ids)
        logits = self.mlp_head(fused)
        return logits, fused, gates