Source code for sparank.framework

"""High-level framework encompassing data registration, preparation, and training."""

from __future__ import annotations

import json
import os
from collections import OrderedDict
from os.path import join
from typing import Dict, List, Optional, Sequence, Tuple, Union, Any

import numpy as np
import pandas as pd
import scanpy as sc

import torch
from torch.utils.data import DataLoader

from sparank.config import ExpConfig, ModalityConfig
from sparank.data import (
    MemmapDataset,
    InferenceDataset,
    build_vocab,
    find_sc_markers,
    normalize_rna,
    simulate,
    tokenize_batch
)
from sparank.modules import SpotRankTransformer
from sparank.training import Trainer


class SpaRank:
    """High-level pipeline: register modalities -> prepare -> fit -> predict.

    This class serves as the main entry point for the SpaRank framework. It
    manages multiple AnnData objects (one per modality), handles
    pre-processing and pseudo-spot simulation, instantiates the multimodal
    model, and runs training and inference.

    Parameters
    ----------
    cfg : ExpConfig
        Configuration object containing hyperparameters for modalities,
        model architecture, simulation, and training.
    save_dir : str, optional
        Directory to save models and metadata. If None, falls back to
        ``cfg.save_dir`` or ``"./sparank_output"``.
    device : Union[str, torch.device], optional
        Compute device. Automatically determined if left as None.
    """

    # ──────────────────────────────────────────────────────────────────────
    # Init
    # ──────────────────────────────────────────────────────────────────────
    def __init__(
        self,
        cfg: ExpConfig,
        save_dir: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        self.cfg = cfg
        # Chain the fallbacks so that a config carrying ``save_dir=None``
        # still ends up with a usable directory instead of ``None``.
        self.save_dir = (
            save_dir
            or getattr(cfg, "save_dir", None)
            or "./sparank_output"
        )
        if device is not None:
            self.device = torch.device(device)
        else:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )

        # ordered – insertion order == modality order
        self._raw_adatas: OrderedDict[str, sc.AnnData] = OrderedDict()

        # populated during prepare()
        self.mod_adatas: OrderedDict[str, sc.AnnData] = OrderedDict()
        self.modality_features: OrderedDict[str, List[str]] = OrderedDict()
        self.cell_types: List[str] = []
        self.context_cats: List[str] = []
        self.vocab: dict = {}
        self.top_ks: Dict[str, int] = {}
        self.model: Optional[SpotRankTransformer] = None
        self.trainer: Optional[Trainer] = None
        self._loader: Optional[DataLoader] = None

        # internal simulation paths
        self._real_n: int = 0
        self._inp_path: str = ""
        self._lbl_path: str = ""
        self._context_path: Optional[str] = None

    # ──────────────────────────────────────────────────────────────────────
    # 1. Register raw AnnData for each modality
    # ──────────────────────────────────────────────────────────────────────
    def register_modality(self, name: str, adata: sc.AnnData) -> "SpaRank":
        """Register a single-modality AnnData. Call once per modality.

        Note that ``adata.var_names_make_unique()`` is called on the object
        that is passed in, i.e. the caller's AnnData is modified in place.

        Parameters
        ----------
        name : str
            Name of the modality (must match a modality name defined in
            ``cfg``).
        adata : sc.AnnData
            The AnnData object containing the single-cell reference for this
            modality.

        Returns
        -------
        SpaRank
            Returns self to allow method chaining.

        Raises
        ------
        ValueError
            If the provided modality name is not expected in the
            configuration.
        """
        expected = {m.name for m in self.cfg.modalities}
        if name not in expected:
            raise ValueError(
                f"Modality '{name}' not in config. Expected one of {expected}."
            )
        adata.var_names_make_unique()
        self._raw_adatas[name] = adata
        return self

    def register_mudata(self, mdata: Any) -> "SpaRank":
        """Convenience method: register all modalities from a MuData at once.

        All configured modalities are checked before any of them is
        registered, so a failed call leaves the instance untouched.

        Parameters
        ----------
        mdata : MuData
            A MuData object containing multi-modal data. Only its ``.mod``
            mapping is accessed.

        Returns
        -------
        SpaRank
            Returns self to allow method chaining.

        Raises
        ------
        KeyError
            If a modality defined in the configuration is missing from the
            MuData.
        """
        names = [mcfg.name for mcfg in self.cfg.modalities]

        # Validate everything first to avoid a half-registered state.
        for name in names:
            if name not in mdata.mod:
                raise KeyError(
                    f"MuData has no modality '{name}'. "
                    f"Available: {list(mdata.mod.keys())}"
                )

        for name in names:
            self.register_modality(name, mdata.mod[name])
        return self

    # ──────────────────────────────────────────────────────────────────────
    # 2. Prepare: normalise, select markers, build vocab, simulate
    # ──────────────────────────────────────────────────────────────────────
    def prepare(
        self,
        marker_features: Optional[Dict[str, List[str]]] = None,
    ) -> "SpaRank":
        """Execute the full pre-training pipeline.

        This aligns cells, normalizes RNA data, performs marker selection,
        builds multi-modal vocabularies, simulates pseudo-spots, and
        constructs the model and dataloaders.

        Parameters
        ----------
        marker_features : Dict[str, List[str]], optional
            Pre-computed marker feature lists per modality. If *None*,
            markers are discovered automatically for RNA; other modalities
            keep all features.

        Returns
        -------
        SpaRank
            Returns self.
        """
        self._validate_registered()

        # 2a Intersect to common cells
        self._align_cells()

        # 2b Normalise & select features per modality
        self._preprocess_modalities(marker_features=marker_features)

        # 2c Prefix feature names & concatenate
        ad_concat, self.modality_features = self._prefix_and_concat(
            self.mod_adatas
        )

        # 2d Build vocab
        self._build_vocab(ad_concat)

        # 2e Simulate pseudo-spots
        self._simulate(ad_concat)

        # 2f Build DataLoader
        self._build_loader()

        # 2g Build model
        self._build_model()
        return self

    # ──────────────────────────────────────────────────────────────────────
    # 3. Fit
    # ──────────────────────────────────────────────────────────────────────
    def fit(self, loader: Optional[DataLoader] = None) -> "SpaRank":
        """Train the model.

        Uses the internal loader generated by ``prepare`` unless overridden.

        Parameters
        ----------
        loader : DataLoader, optional
            Custom DataLoader. If None, uses the internal simulator
            DataLoader.

        Returns
        -------
        SpaRank
            Returns self.

        Raises
        ------
        RuntimeError
            If the model or dataloader has not been initialized.
        """
        # Compare against None explicitly: truth-testing a DataLoader calls
        # its __len__, which raises for iterable-style datasets and would
        # also discard a user-supplied loader that happens to be empty.
        if loader is None:
            loader = self._loader
        if loader is None:
            raise RuntimeError("No DataLoader available. Call .prepare() first.")
        if self.model is None:
            raise RuntimeError("No model built. Call .prepare() first.")

        self.trainer = Trainer(self.model, self.cfg, device=self.device)
        self.trainer.fit(loader, save_dir=self.save_dir)
        return self

    # ──────────────────────────────────────────────────────────────────────
    # 4. Save / Load
    # ──────────────────────────────────────────────────────────────────────
    def save(self, save_dir: Optional[str] = None) -> None:
        """Persist config, vocabulary mappings, and final model weights.

        Writes ``metadata/config.json`` and ``metadata/vocab_meta.json``
        below the target directory and, if a model exists, its state dict as
        ``model_epoch{cfg.epochs}.pth``.

        Parameters
        ----------
        save_dir : str, optional
            Destination directory. Overrides class instantiation path if
            provided.
        """
        save_dir = save_dir or self.save_dir
        meta_dir = join(save_dir, "metadata")
        os.makedirs(meta_dir, exist_ok=True)

        # config
        self.cfg.save(join(meta_dir, "config.json"))

        # vocab + top_ks
        self._write_vocab_meta(meta_dir)

        # model weights (last epoch)
        if self.model is not None:
            torch.save(
                self.model.state_dict(),
                join(save_dir, f"model_epoch{self.cfg.epochs}.pth"),
            )

    @classmethod
    def load(
        cls,
        save_dir: str,
        device: Optional[Union[str, torch.device]] = None,
        verbose: bool = False,
    ) -> "SpaRank":
        """Reconstruct a trained SpaRank model from saved artefacts.

        Parameters
        ----------
        save_dir : str
            Directory where the model and metadata are stored.
        device : Union[str, torch.device], optional
            Target compute device for the loaded model.
        verbose : bool, default False
            If True, prints the checkpoint loading status.

        Returns
        -------
        SpaRank
            An instantiated SpaRank pipeline in eval mode. If the expected
            checkpoint file is absent, a ``RuntimeWarning`` is emitted and
            the returned model carries freshly initialised weights.
        """
        import warnings

        meta_dir = join(save_dir, "metadata")
        cfg = ExpConfig.load(join(meta_dir, "config.json"))
        sr = cls(cfg, save_dir=save_dir, device=device)

        with open(join(meta_dir, "vocab_meta.json"), encoding="utf-8") as f:
            vm = json.load(f)

        sr.vocab = vm
        sr.top_ks = vm.get("top_ks", {})
        # NOTE(review): relies on the training-time order being the sorted
        # label order (as produced by _build_vocab) — confirm for labels
        # that were not strings before JSON serialisation.
        sr.cell_types = sorted((vm.get("type2id") or {}).keys())
        sr.context_cats = sorted((vm.get("context2id") or {}).keys())

        # Reconstruct modality_features from vocab
        saved_vocabs = vm.get("vocabs", {})
        for mcfg in cfg.modalities:
            name = mcfg.name
            if name in saved_vocabs:
                # vocab keys include <PAD>, <UNK> — filter to real features
                sr.modality_features[name] = sorted(
                    k for k, v in saved_vocabs[name].items()
                    if v >= 2  # skip <PAD>=0, <UNK>=1
                )

        # Rebuild model architecture
        sr._build_model()

        # Load weights
        ckpt = join(save_dir, f"model_epoch{cfg.epochs}.pth")
        if os.path.isfile(ckpt):
            # NOTE(review): torch.load unpickles the file — only load
            # checkpoints that come from a trusted source.
            state = torch.load(ckpt, map_location=sr.device)
            msg = sr.model.load_state_dict(state, strict=False)
            if verbose:
                print(f"Loaded model from {ckpt} with message: {msg}")
        else:
            # Do not fail silently: predictions from this model would come
            # from randomly initialised weights.
            warnings.warn(
                f"Checkpoint not found at {ckpt}; "
                "the model has NOT been loaded with trained weights.",
                RuntimeWarning,
                stacklevel=2,
            )
        sr.model.eval()
        return sr

    # ──────────────────────────────────────────────────────────────────────
    # 5. Predict
    # ──────────────────────────────────────────────────────────────────────
    def predict(
        self,
        mod_adatas: Dict[str, sc.AnnData],
        batch_size: int = 512,
        num_workers: int = 0,
        return_embeddings: bool = False,
        return_gate_scores: bool = False,
        return_per_dim_gates: bool = False,
        gate_reduction: str = "mean",
    ) -> Union[pd.DataFrame, Tuple[Any, ...]]:
        """Run inference on spatial data.

        By default returns a DataFrame of predicted cell-type proportions.
        Optional flags additionally return embeddings and/or fusion gate
        scores in a single forward pass — nothing is recomputed.

        Parameters
        ----------
        mod_adatas : Dict[str, sc.AnnData]
            One AnnData per modality, keyed by modality name. Feature names
            should be unprefixed; prefixing is handled internally.
        batch_size : int, default 512
            Batch size for the Inference DataLoader.
        num_workers : int, default 0
            Workers for the Inference DataLoader.
        return_embeddings : bool, default False
            If True, also return per-spot embeddings ``(n_spots, head_in)``.
        return_gate_scores : bool, default False
            If True, also return per-spot, per-modality fusion weights as a
            DataFrame ``(n_spots, n_modalities)``. For ``concat`` /
            single-modality fusion, returns uniform 1/M.
        return_per_dim_gates : bool, default False
            If True, also return the raw per-dimension gate tensor
            ``(n_spots, n_modalities, dim)``. Only meaningful for gate-type
            fusion; for other fusion types ``None`` is returned in its
            place. Implies ``return_gate_scores=True``.
        gate_reduction : str, default "mean"
            How to reduce ``(M, B, d)`` gate tensors to per-modality
            scalars. Options are ``"mean"`` or ``"max"``.

        Returns
        -------
        Union[pd.DataFrame, Tuple[Any, ...]]
            Always returned in this order, skipping items not requested:

            - **proportions** (pd.DataFrame)
            - **embeddings** (np.ndarray, if requested)
            - **gate_scores** (pd.DataFrame, if requested)
            - **per_dim_gates** (np.ndarray or None, if requested)

            A single DataFrame is returned when only proportions are
            requested.

        Raises
        ------
        RuntimeError
            If no model is available.
        ValueError
            If the input modalities differ from the ones used in training.
        """
        if self.model is None:
            raise RuntimeError(
                "Model not available. Call .prepare() + .fit(), or use SpaRank.load()."
            )

        # per_dim implies gate_scores
        if return_per_dim_gates:
            return_gate_scores = True
        want_gates = return_gate_scores  # alias for readability

        # ── A. Validate modalities ───────────────────────────────────────
        ref_mods = set(self.modality_features.keys())
        inp_mods = set(mod_adatas.keys())
        if inp_mods != ref_mods:
            raise ValueError(
                f"Modality mismatch: input has {inp_mods}, training used {ref_mods}."
            )

        # ── B. Prefix + concat + tokenize ────────────────────────────────
        ad_concat = self._prefix_and_concat(mod_adatas, "test")[0]

        context2id = self.vocab.get("context2id", None)
        context_key = self.cfg.context_key if context2id else None

        tokens, _, context_ids = tokenize_batch(
            adata=ad_concat,
            vocabs=self.vocab["vocabs"],
            modality_names=self.cfg.modality_names,
            top_ks=self.top_ks,
            cell_types=self.cell_types,
            context2id=context2id,
            context_key=context_key,
            mode="test",
        )

        ds = InferenceDataset(tokens, context_ids)
        loader = DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        # ── C. Forward pass (choose method based on what's needed) ───────
        self.model.eval()

        all_probs: List[np.ndarray] = []
        all_embs: List[np.ndarray] = []
        all_scores: List[np.ndarray] = []
        all_per_dim: List[np.ndarray] = []

        with torch.no_grad():
            for tok_batch, ctx_batch in loader:
                tok_batch = tok_batch.to(self.device)
                ctx_batch = ctx_batch.to(self.device) if context2id else None

                if want_gates:
                    logits, emb, gates = self.model.forward_with_gates(
                        tok_batch, ctx_batch,
                    )
                else:
                    logits, emb = self.model(tok_batch, context_ids=ctx_batch)

                all_probs.append(torch.softmax(logits, dim=-1).cpu().numpy())

                if return_embeddings:
                    all_embs.append(emb.cpu().numpy())

                if want_gates:
                    scores, per_dim = self._reduce_gates(
                        gates, tok_batch.shape[0], gate_reduction,
                    )
                    all_scores.append(scores)
                    if return_per_dim_gates and per_dim is not None:
                        all_per_dim.append(per_dim)

        # ── D. Assemble outputs ──────────────────────────────────────────
        proportions_df = pd.DataFrame(
            np.concatenate(all_probs, axis=0),
            index=ad_concat.obs_names,
            columns=self.cell_types,
        )

        outputs: List[Any] = [proportions_df]

        if return_embeddings:
            outputs.append(np.concatenate(all_embs, axis=0))

        if want_gates:
            outputs.append(pd.DataFrame(
                np.concatenate(all_scores, axis=0),
                index=ad_concat.obs_names,
                columns=self.cfg.modality_names,
            ))

        if return_per_dim_gates:
            if all_per_dim:
                outputs.append(np.concatenate(all_per_dim, axis=0))
            else:
                # fusion_type doesn't produce per-dim gates
                # (concat / attention / identity)
                outputs.append(None)

        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    def _reduce_gates(
        self,
        gates: dict,
        batch_size: int,
        reduction: str,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Reduce a fusion-layer gates dict to per-modality (B, M) scores.

        Parameters
        ----------
        gates : dict
            The gate dictionary returned by the fusion layer. The keys
            ``"weights"``, ``"type"`` and ``"modalities"`` are read.
        batch_size : int
            The number of samples in the current batch.
        reduction : str
            Reduction method, either ``"mean"`` or ``"max"``. Only consulted
            for gate-type fusion.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            A tuple of ``(scores_bm, per_dim_bmd_or_None)``. The ``per_dim``
            array is only non-None when ``gates["type"]`` is ``"gate"`` or
            ``"sigmoid_gate"``.

        Raises
        ------
        ValueError
            If gate-type fusion is used with an unknown ``reduction``.
        """
        w = gates["weights"]
        gtype = gates["type"]
        n_modalities = len(gates["modalities"])

        if gtype in ("gate", "sigmoid_gate"):
            # Reject bad input before paying for the device->host copy.
            if reduction not in ("mean", "max"):
                raise ValueError(f"Unknown gate_reduction: {reduction!r}")

            # w: (M, B, d) -> (B, M, d)
            w_bmd = w.permute(1, 0, 2).contiguous()
            if reduction == "mean":
                scores = w_bmd.mean(dim=-1)
            else:
                scores = w_bmd.max(dim=-1).values
            return scores.cpu().numpy(), w_bmd.cpu().numpy()

        if gtype == "attention":
            # w: (B, M, M) — average over query axis
            return w.mean(dim=1).cpu().numpy(), None

        # concat / identity: no learned weighting
        uniform = np.full(
            (batch_size, n_modalities), 1.0 / n_modalities, dtype=np.float32
        )
        return uniform, None

    # ==================================================================
    # Private helpers
    # ==================================================================
    def _validate_registered(self) -> None:
        """Raise ``RuntimeError`` unless every configured modality is registered."""
        expected = {m.name for m in self.cfg.modalities}
        registered = set(self._raw_adatas.keys())
        missing = expected - registered
        if missing:
            raise RuntimeError(
                f"Missing modalities: {missing}. "
                f"Call register_modality() for each before prepare()."
            )

    def _write_vocab_meta(self, meta_dir: str) -> None:
        """Write the vocabulary plus ``top_ks`` to ``meta_dir/vocab_meta.json``."""
        os.makedirs(meta_dir, exist_ok=True)
        payload = {**self.vocab, "top_ks": self.top_ks}
        with open(join(meta_dir, "vocab_meta.json"), "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=4, default=str)

    # ---- 2a: align cells across modalities ----------------------------
    def _align_cells(self) -> None:
        """Subset every registered AnnData to the cells shared by all of them.

        Copies are stored in ``self.mod_adatas``. Missing batch columns are
        filled with a constant label on the first configured modality only.
        """
        common_cells = None
        for ad in self._raw_adatas.values():
            if common_cells is None:
                common_cells = ad.obs_names
            else:
                common_cells = np.intersect1d(common_cells, ad.obs_names)

        for name, ad in self._raw_adatas.items():
            self.mod_adatas[name] = ad[common_cells].copy()

        # ensure batch_key exists on the primary modality
        primary = self.mod_adatas[self.cfg.modalities[0].name]
        batch_key = self.cfg.batch_key
        if batch_key and batch_key not in primary.obs.columns:
            primary.obs[batch_key] = "batch_1"
        sim_batch_key = self.cfg.simulation.batch_key
        if sim_batch_key and sim_batch_key not in primary.obs.columns:
            primary.obs[sim_batch_key] = "sim_batch_1"

    # ---- 2b: normalise & marker selection -----------------------------
    def _preprocess_modalities(
        self,
        marker_features: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Select the features kept for each modality.

        For the modality named ``"rna"``, supplied markers are used when
        given; otherwise the data is normalised and markers are discovered
        with ``find_sc_markers``. Other modalities are subset only when
        markers are supplied for them and keep all features otherwise.
        Supplied markers that are absent from ``var_names`` are dropped.
        """
        marker_features = marker_features or {}
        n_tops = {m.name: m.k_feat or 50 for m in self.cfg.modalities}

        for mcfg in self.cfg.modalities:
            name = mcfg.name
            ad = self.mod_adatas[name]

            if name == "rna":
                if name in marker_features:
                    # use supplied markers if given
                    feats = [f for f in marker_features[name] if f in ad.var_names]
                else:
                    ad = normalize_rna(ad, layer_key="log1p")
                    feats = find_sc_markers(
                        ad,
                        self.cfg.celltype_key,
                        batch_key=self.cfg.batch_key,
                        layer="log1p",
                        deg_method="t-test",
                        log2fc_min=0.5,
                        pval_cutoff=0.01,
                        n_top_markers=n_tops[name],
                        pct_min=0.1,
                    )
                self.mod_adatas[name] = ad[:, feats].copy()
            elif name in marker_features:
                # non-RNA modalities: use supplied markers or keep all
                feats = [f for f in marker_features[name] if f in ad.var_names]
                self.mod_adatas[name] = ad[:, feats].copy()

    # ---- 2c: prefix & concatenate -------------------------------------
    def _prefix_and_concat(
        self,
        mod_adatas: Dict[str, sc.AnnData],
        mode: str = "train",
    ) -> Tuple[sc.AnnData, OrderedDict]:
        """Prefix var_names and concatenate across modalities.

        Works on **copies** of the input adatas to avoid mutating the
        caller's objects (critical for predict, where the user's spatial
        adatas must not be modified).

        Parameters
        ----------
        mod_adatas : Dict[str, sc.AnnData]
            Dictionary of modality AnnData objects.
        mode : str, default "train"
            In "train" mode, obs metadata (celltype, batch, simulation
            batch, context) is propagated to the concatenated AnnData. In
            any other mode only the context column is carried over.

        Returns
        -------
        Tuple[sc.AnnData, OrderedDict]
            - Concatenated AnnData object.
            - Ordered dictionary of modality feature lists.
        """
        modality_features = OrderedDict()
        parts = []

        for name in self.cfg.modality_names:
            ad_copy = mod_adatas[name].copy()

            # Detect whether features are already prefixed (judged by the
            # first feature name only).
            prefix = f"{name}-"
            sample_var = ad_copy.var_names[0] if len(ad_copy.var_names) > 0 else ""
            if not sample_var.startswith(prefix):
                ad_copy.var_names = [f"{name}-{g}" for g in ad_copy.var_names]

            modality_features[name] = list(ad_copy.var_names)
            parts.append(ad_copy)

        ad_concat = sc.concat(parts, axis=1)

        # Obs metadata always comes from the first modality.
        if mode == "train":
            wanted_cols = [
                self.cfg.celltype_key,
                self.cfg.batch_key,
                self.cfg.simulation.batch_key,
                self.cfg.context_key,
            ]
        else:
            # predict mode: carry context_key if present
            wanted_cols = [self.cfg.context_key]

        source_obs = parts[0].obs
        # Plain positional copy is only valid when both objects list the
        # same observations in the same order; otherwise align by label.
        same_index = source_obs.index.equals(ad_concat.obs_names)
        for col in wanted_cols:
            if col and col in source_obs.columns:
                column = source_obs[col]
                if not same_index:
                    column = column.reindex(ad_concat.obs_names)
                ad_concat.obs[col] = column.values

        return ad_concat, modality_features

    # ---- 2d: vocab ----------------------------------------------------
    def _build_vocab(self, ad_concat: sc.AnnData) -> None:
        """Derive cell types, context categories, vocab and clamped ``top_k``.

        Side effects: ``mcfg.top_k`` of every modality config is overwritten
        with the clamped value, and the vocabulary is written to
        ``<save_dir>/metadata/vocab_meta.json``.
        """
        self.cell_types = sorted(
            ad_concat.obs[self.cfg.celltype_key].unique().tolist()
        )

        context_cats: Optional[List[str]] = None
        if self.cfg.context_key and self.cfg.context_key in ad_concat.obs:
            self.context_cats = sorted(
                ad_concat.obs[self.cfg.context_key].astype(str).unique().tolist()
            )
            context_cats = self.context_cats

        self.vocab = build_vocab(
            modality_features=self.modality_features,
            cell_types=self.cell_types,
            context_categories=context_cats,
        )

        # clamp top_k to actual feature count
        self.top_ks = {}
        for mcfg in self.cfg.modalities:
            actual = len(self.modality_features.get(mcfg.name, []))
            self.top_ks[mcfg.name] = min(mcfg.top_k, actual)
            mcfg.top_k = self.top_ks[mcfg.name]

        # persist
        self._write_vocab_meta(join(self.save_dir, "metadata"))

    # ---- 2e: simulate -------------------------------------------------
    def _simulate(self, ad_concat: sc.AnnData) -> None:
        """Run ``simulate`` and remember the sample count and file paths it returns."""
        all_features: List[str] = []
        for name in self.cfg.modality_names:
            all_features.extend(self.modality_features[name])

        context2id = self.vocab.get("context2id", None)
        context_key = self.cfg.context_key if context2id else None

        (
            self._real_n,
            self._inp_path,
            self._lbl_path,
            self._context_path,
        ) = simulate(
            ad_concat,
            vocabs=self.vocab["vocabs"],
            modality_names=self.cfg.modality_names,
            top_ks=self.top_ks,
            cell_types=self.cell_types,
            all_features=all_features,
            celltype_key=self.cfg.celltype_key,
            sim_batch_key=self.cfg.simulation.batch_key,
            save_dir=self.save_dir,
            cfg=self.cfg.simulation,
            context2id=context2id,
            context_key=context_key,
        )

    # ---- 2f: dataloader -----------------------------------------------
    def _build_loader(self) -> None:
        """Create the shuffled training DataLoader over the simulated data."""
        segment_layout = [
            {
                "name": mcfg.name,
                "top_k": self.top_ks[mcfg.name],
                "mask_id": self.vocab["mask_ids"][mcfg.name],
                "cl_dropout_rate": mcfg.cl_dropout_rate,
                "mrp_mask_rate": mcfg.mrp_mask_rate,
            }
            for mcfg in self.cfg.modalities
        ]

        ds = MemmapDataset(
            input_path=self._inp_path,
            label_path=self._lbl_path,
            context_path=self._context_path,
            valid_samples=self._real_n,
            max_samples=self.cfg.simulation.total_samples,
            seq_len=self.cfg.total_seq_len,
            num_classes=len(self.cell_types),
            segment_layout=segment_layout,
            cl_mode=self.cfg.use_cl,
            mrp_mode=self.cfg.use_mrp,
        )

        self._loader = DataLoader(
            ds,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            num_workers=self.cfg.num_workers,
        )

    # ---- 2g: model architecture ---------------------------------------
    def _build_model(self) -> None:
        """Instantiate ``SpotRankTransformer`` from config and vocab, on ``self.device``."""
        tower_specs = [
            {
                "name": mcfg.name,
                "vocab_size": len(self.vocab["vocabs"][mcfg.name]) + 1,
                "top_k": self.top_ks[mcfg.name],
                "mask_id": self.vocab["mask_ids"][mcfg.name],
                # per-modality override, else the global setting
                "depth": mcfg.depth or self.cfg.depth,
                "heads": mcfg.heads or self.cfg.heads,
            }
            for mcfg in self.cfg.modalities
        ]

        # Treat a missing or null mapping as "no context".
        num_contexts = len(self.vocab.get("context2id") or {})

        self.model = SpotRankTransformer(
            tower_specs=tower_specs,
            num_cell_types=len(self.cell_types),
            num_contexts=num_contexts,
            dim=self.cfg.dim,
            depth=self.cfg.depth,
            heads=self.cfg.heads,
            mlp_dim=self.cfg.mlp_dim,
            dropout=self.cfg.dropout,
            fusion_type=getattr(self.cfg, "fusion_type", None),
            context_dim=getattr(self.cfg, "context_dim", None),
            use_cl=self.cfg.use_cl,
            use_mrp=self.cfg.use_mrp,
        ).to(self.device)

    # ──────────────────────────────────────────────────────────────────────
    # Repr
    # ──────────────────────────────────────────────────────────────────────
    def __repr__(self):
        mods = ", ".join(m.name for m in self.cfg.modalities)
        is_trained = self.model is not None and self.trainer is not None
        status = "trained" if is_trained else "untrained"
        return f"SpaRank(modalities=[{mods}], status={status}, device={self.device})"