"""High-level framework encompassing data registration, preparation, and training."""
from __future__ import annotations

import json
import os
import warnings
from collections import OrderedDict
from os.path import join
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import scanpy as sc
import torch
from torch.utils.data import DataLoader

from sparank.config import ExpConfig, ModalityConfig
from sparank.data import (
    MemmapDataset,
    InferenceDataset,
    build_vocab,
    find_sc_markers,
    normalize_rna,
    simulate,
    tokenize_batch,
)
from sparank.modules import SpotRankTransformer
from sparank.training import Trainer
class SpaRank:
    """High-level pipeline: register modalities -> prepare -> fit -> predict.

    This class serves as the main entry point for the SpaRank framework. It
    manages multiple AnnData objects (one per modality), handles
    pre-processing and pseudo-spot simulation, instantiates the multimodal
    model, and runs training and inference.

    Parameters
    ----------
    cfg : ExpConfig
        Configuration object containing hyperparameters for modalities, model
        architecture, simulation, and training.
    save_dir : str, optional
        Directory to save models and metadata. If None (or empty), falls back
        to ``cfg.save_dir`` when that attribute is set, otherwise to
        ``"./sparank_output"``.
    device : Union[str, torch.device], optional
        Compute device. If None, ``"cuda"`` is used when available, otherwise
        ``"cpu"``.
    """

    # Fallback output directory when neither ``save_dir`` nor ``cfg.save_dir``
    # is provided.
    _DEFAULT_SAVE_DIR = "./sparank_output"
    # Accepted values for the ``gate_reduction`` argument of ``predict``.
    _GATE_REDUCTIONS = ("mean", "max")
    # Vocabulary ids below this value are special tokens (<PAD>=0, <UNK>=1).
    _FIRST_FEATURE_ID = 2

    # ──────────────────────────────────────────────────────────────────────
    # Init
    # ──────────────────────────────────────────────────────────────────────
    def __init__(
        self,
        cfg: ExpConfig,
        save_dir: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        self.cfg = cfg
        # ``cfg.save_dir`` may exist but be None, so chain with ``or`` instead
        # of relying on getattr's default alone.
        self.save_dir = (
            save_dir
            or getattr(cfg, "save_dir", None)
            or self._DEFAULT_SAVE_DIR
        )
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # ordered – insertion order == modality order
        self._raw_adatas: OrderedDict[str, sc.AnnData] = OrderedDict()

        # populated during prepare()
        self.mod_adatas: OrderedDict[str, sc.AnnData] = OrderedDict()
        self.modality_features: OrderedDict[str, List[str]] = OrderedDict()
        self.cell_types: List[str] = []
        self.context_cats: List[str] = []
        self.vocab: dict = {}
        self.top_ks: Dict[str, int] = {}
        self.model: Optional[SpotRankTransformer] = None
        self.trainer: Optional[Trainer] = None
        self._loader: Optional[DataLoader] = None
        # True once load() has restored weights from a checkpoint file.
        self._checkpoint_loaded: bool = False

        # internal simulation outputs (set by _simulate)
        self._real_n: int = 0
        self._inp_path: str = ""
        self._lbl_path: str = ""
        self._context_path: Optional[str] = None

    # ──────────────────────────────────────────────────────────────────────
    # 1. Register raw AnnData for each modality
    # ──────────────────────────────────────────────────────────────────────
    def register_modality(self, name: str, adata: sc.AnnData) -> "SpaRank":
        """Register a single-modality AnnData. Call once per modality.

        Note that ``adata.var_names_make_unique()`` is applied in place, so the
        caller's object is modified.

        Parameters
        ----------
        name : str
            Name of the modality (must match a modality name defined in ``cfg``).
        adata : sc.AnnData
            The AnnData object containing the single-cell reference for this
            modality.

        Returns
        -------
        SpaRank
            Returns self to allow method chaining.

        Raises
        ------
        ValueError
            If the provided modality name is not expected in the configuration.
        """
        expected = {m.name for m in self.cfg.modalities}
        if name not in expected:
            raise ValueError(
                f"Modality '{name}' not in config. Expected one of {expected}."
            )
        adata.var_names_make_unique()
        self._raw_adatas[name] = adata
        return self

    def register_mudata(self, mdata: Any) -> "SpaRank":
        """Register every configured modality from a MuData object at once.

        Parameters
        ----------
        mdata : MuData
            A MuData object; ``mdata.mod[name]`` is read for each configured
            modality.

        Returns
        -------
        SpaRank
            Returns self to allow method chaining.

        Raises
        ------
        KeyError
            If a modality defined in the configuration is missing from the MuData.
        """
        for mcfg in self.cfg.modalities:
            if mcfg.name not in mdata.mod:
                raise KeyError(
                    f"MuData has no modality '{mcfg.name}'. "
                    f"Available: {list(mdata.mod.keys())}"
                )
            self.register_modality(mcfg.name, mdata.mod[mcfg.name])
        return self

    # ──────────────────────────────────────────────────────────────────────
    # 2. Prepare: normalise, select markers, build vocab, simulate
    # ──────────────────────────────────────────────────────────────────────
    def prepare(
        self,
        marker_features: Optional[Dict[str, List[str]]] = None,
    ) -> "SpaRank":
        """Execute the full pre-training pipeline.

        This aligns cells, normalizes RNA data, performs marker selection,
        builds multi-modal vocabularies, simulates pseudo-spots, and
        constructs the dataloader and model.

        Parameters
        ----------
        marker_features : Dict[str, List[str]], optional
            Pre-computed marker feature lists per modality. If *None*, markers
            are discovered automatically for RNA; other modalities keep all
            features.

        Returns
        -------
        SpaRank
            Returns self.

        Raises
        ------
        RuntimeError
            If a configured modality has not been registered.
        ValueError
            If the modalities share no cells or a modality ends up with no
            selected features.
        """
        self._validate_registered()
        # 2a Intersect to common cells
        self._align_cells()
        # 2b Normalise & select features per modality
        self._preprocess_modalities(marker_features=marker_features)
        # 2c Prefix feature names & concatenate
        ad_concat, self.modality_features = self._prefix_and_concat(self.mod_adatas)
        # 2d Build vocab (also persists vocab metadata)
        self._build_vocab(ad_concat)
        # 2e Simulate pseudo-spots
        self._simulate(ad_concat)
        # 2f Build DataLoader
        self._build_loader()
        # 2g Build model
        self._build_model()
        return self

    # ──────────────────────────────────────────────────────────────────────
    # 3. Fit
    # ──────────────────────────────────────────────────────────────────────
    def fit(self, loader: Optional[DataLoader] = None) -> "SpaRank":
        """Train the model, by default on the loader built by ``prepare``.

        Parameters
        ----------
        loader : DataLoader, optional
            Custom DataLoader. If None, the internal simulation DataLoader is
            used.

        Returns
        -------
        SpaRank
            Returns self.

        Raises
        ------
        RuntimeError
            If the model or dataloader has not been initialized.
        """
        # Explicit None check: ``loader or ...`` would call DataLoader.__len__,
        # which raises TypeError for IterableDataset-backed loaders and treats
        # an empty loader as "not given".
        if loader is None:
            loader = self._loader
        if loader is None:
            raise RuntimeError("No DataLoader available. Call .prepare() first.")
        if self.model is None:
            raise RuntimeError("No model built. Call .prepare() first.")
        self.trainer = Trainer(self.model, self.cfg, device=self.device)
        self.trainer.fit(loader, save_dir=self.save_dir)
        return self

    # ──────────────────────────────────────────────────────────────────────
    # 4. Save / Load
    # ──────────────────────────────────────────────────────────────────────
    def save(self, save_dir: Optional[str] = None) -> None:
        """Persist config, vocabulary mappings, and final model weights.

        Writes ``metadata/config.json``, ``metadata/vocab_meta.json`` and, if a
        model exists, ``model_epoch{cfg.epochs}.pth`` under ``save_dir``.

        Parameters
        ----------
        save_dir : str, optional
            Destination directory. Overrides the instance's ``save_dir`` if
            provided.
        """
        save_dir = save_dir or self.save_dir
        meta_dir = join(save_dir, "metadata")
        os.makedirs(meta_dir, exist_ok=True)
        self.cfg.save(join(meta_dir, "config.json"))
        self._write_vocab_meta(meta_dir)
        # model weights (last epoch)
        if self.model is not None:
            torch.save(
                self.model.state_dict(),
                self._checkpoint_path(save_dir, self.cfg.epochs),
            )

    @classmethod
    def load(
        cls,
        save_dir: str,
        device: Optional[Union[str, torch.device]] = None,
        verbose: bool = False,
    ) -> "SpaRank":
        """Reconstruct a trained SpaRank model from saved artefacts.

        Parameters
        ----------
        save_dir : str
            Directory where the model and metadata are stored.
        device : Union[str, torch.device], optional
            Target compute device for the loaded model.
        verbose : bool, default False
            If True, prints the checkpoint loading status.

        Returns
        -------
        SpaRank
            An instantiated SpaRank pipeline ready for inference. If the
            checkpoint file is missing, a ``RuntimeWarning`` is emitted and
            the model keeps its freshly initialised weights.
        """
        meta_dir = join(save_dir, "metadata")
        cfg = ExpConfig.load(join(meta_dir, "config.json"))
        sr = cls(cfg, save_dir=save_dir, device=device)

        with open(join(meta_dir, "vocab_meta.json"), encoding="utf-8") as fh:
            vm = json.load(fh)
        sr.vocab = vm
        sr.top_ks = vm.get("top_ks", {})
        sr.cell_types = sorted(vm.get("type2id", {}).keys())
        # ``context2id`` may be absent or stored as null.
        sr.context_cats = sorted((vm.get("context2id") or {}).keys())

        # Reconstruct modality_features from the per-modality vocabs.
        vocabs = vm.get("vocabs", {})
        for mcfg in cfg.modalities:
            feature_vocab = vocabs.get(mcfg.name)
            if feature_vocab is None:
                continue
            # vocab keys include <PAD>, <UNK> — keep only real features
            sr.modality_features[mcfg.name] = sorted(
                key
                for key, idx in feature_vocab.items()
                if idx >= cls._FIRST_FEATURE_ID
            )

        sr._build_model()

        ckpt = cls._checkpoint_path(save_dir, cfg.epochs)
        if not os.path.isfile(ckpt):
            warnings.warn(
                f"No checkpoint found at {ckpt}; the returned model has "
                f"untrained weights.",
                RuntimeWarning,
                stacklevel=2,
            )
            return sr
        msg = sr.model.load_state_dict(
            torch.load(ckpt, map_location=sr.device),
            strict=False,
        )
        sr.model.eval()
        sr._checkpoint_loaded = True
        if verbose:
            print(f"Loaded model from {ckpt} with message: {msg}")
        return sr

    # ──────────────────────────────────────────────────────────────────────
    # 5. Predict
    # ──────────────────────────────────────────────────────────────────────
    def predict(
        self,
        mod_adatas: Dict[str, sc.AnnData],
        batch_size: int = 512,
        num_workers: int = 0,
        return_embeddings: bool = False,
        return_gate_scores: bool = False,
        return_per_dim_gates: bool = False,
        gate_reduction: str = "mean",
    ) -> Union[pd.DataFrame, Tuple[Any, ...]]:
        """Run inference on spatial data.

        By default returns a DataFrame of predicted cell-type proportions.
        Optional flags additionally return embeddings and/or fusion gate
        scores from the same forward pass — nothing is recomputed.

        Parameters
        ----------
        mod_adatas : Dict[str, sc.AnnData]
            One AnnData per modality, keyed by modality name. Feature names
            should be unprefixed; prefixing is handled internally. The input
            objects are not modified.
        batch_size : int, default 512
            Batch size for the inference DataLoader.
        num_workers : int, default 0
            Workers for the inference DataLoader.
        return_embeddings : bool, default False
            If True, also return the per-spot embeddings produced by the model.
        return_gate_scores : bool, default False
            If True, also return per-spot, per-modality fusion weights as a
            DataFrame ``(n_spots, n_modalities)``. For ``concat`` / single-
            modality fusion, returns uniform 1/M.
        return_per_dim_gates : bool, default False
            If True, also return the raw per-dimension gate array
            ``(n_spots, n_modalities, dim)``. Only produced for gate-type
            fusion; ``None`` is returned in its slot otherwise. Implies
            ``return_gate_scores=True``.
        gate_reduction : str, default "mean"
            How to reduce ``(M, B, d)`` gate tensors to per-modality scalars:
            ``"mean"`` or ``"max"``.

        Returns
        -------
        Union[pd.DataFrame, Tuple[Any, ...]]
            Always returned in this order, skipping items not requested:

            - **proportions** (pd.DataFrame)
            - **embeddings** (np.ndarray, if requested)
            - **gate_scores** (pd.DataFrame, if requested)
            - **per_dim_gates** (np.ndarray or None, if requested)

            A single DataFrame is returned when only proportions are requested.

        Raises
        ------
        RuntimeError
            If no model is available.
        ValueError
            On a modality mismatch with training, an unknown
            ``gate_reduction`` (when gates are requested), or an empty set of
            spots after concatenation.
        """
        if self.model is None:
            raise RuntimeError(
                "Model not available. Call .prepare() + .fit(), or use SpaRank.load()."
            )
        # Per-dim gates come from the same gate tensors, so they imply scores.
        want_gates = return_gate_scores or return_per_dim_gates
        # Validate before running a potentially long forward pass.
        if want_gates and gate_reduction not in self._GATE_REDUCTIONS:
            raise ValueError(f"Unknown gate_reduction: {gate_reduction!r}")

        # ── A. Validate modalities ───────────────────────────────────────
        ref_mods = set(self.modality_features.keys())
        inp_mods = set(mod_adatas.keys())
        if inp_mods != ref_mods:
            raise ValueError(
                f"Modality mismatch: input has {inp_mods}, training used {ref_mods}."
            )

        # ── B. Prefix + concat + tokenize ────────────────────────────────
        ad_concat, _ = self._prefix_and_concat(mod_adatas, "test")
        if ad_concat.n_obs == 0:
            raise ValueError(
                "No spots to predict: the input modalities share no observations."
            )
        context2id = self.vocab.get("context2id", None)
        context_key = self.cfg.context_key if context2id else None
        tokens, _, context_ids = tokenize_batch(
            adata=ad_concat,
            vocabs=self.vocab["vocabs"],
            modality_names=self.cfg.modality_names,
            top_ks=self.top_ks,
            cell_types=self.cell_types,
            context2id=context2id,
            context_key=context_key,
            mode="test",
        )
        loader = DataLoader(
            InferenceDataset(tokens, context_ids),
            batch_size=batch_size,
            shuffle=False,  # output rows must stay aligned with obs_names
            num_workers=num_workers,
        )

        # ── C. Single forward pass ───────────────────────────────────────
        self.model.eval()
        all_probs: List[np.ndarray] = []
        all_embs: List[np.ndarray] = []
        all_scores: List[np.ndarray] = []
        all_per_dim: List[np.ndarray] = []
        with torch.no_grad():
            for tok_batch, ctx_batch in loader:
                tok_batch = tok_batch.to(self.device)
                ctx_batch = ctx_batch.to(self.device) if context2id else None
                if want_gates:
                    logits, emb, gates = self.model.forward_with_gates(
                        tok_batch, ctx_batch,
                    )
                else:
                    logits, emb = self.model(tok_batch, context_ids=ctx_batch)

                all_probs.append(torch.softmax(logits, dim=-1).cpu().numpy())
                if return_embeddings:
                    all_embs.append(emb.cpu().numpy())
                if want_gates:
                    scores, per_dim = self._reduce_gates(
                        gates, tok_batch.shape[0], gate_reduction,
                    )
                    all_scores.append(scores)
                    if return_per_dim_gates and per_dim is not None:
                        all_per_dim.append(per_dim)

        # ── D. Assemble outputs ──────────────────────────────────────────
        obs_names = ad_concat.obs_names
        outputs: List[Any] = [
            pd.DataFrame(
                np.concatenate(all_probs, axis=0),
                index=obs_names,
                columns=self.cell_types,
            )
        ]
        if return_embeddings:
            outputs.append(np.concatenate(all_embs, axis=0))
        if want_gates:
            outputs.append(pd.DataFrame(
                np.concatenate(all_scores, axis=0),
                index=obs_names,
                columns=self.cfg.modality_names,
            ))
        if return_per_dim_gates:
            # Non-gate fusion types (concat / attention / identity) yield no
            # per-dim gates, so the slot is filled with None.
            outputs.append(
                np.concatenate(all_per_dim, axis=0) if all_per_dim else None
            )
        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    def _reduce_gates(
        self,
        gates: dict,
        batch_size: int,
        reduction: str,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Reduce a fusion-layer gates dict to per-modality ``(B, M)`` scores.

        Parameters
        ----------
        gates : dict
            Gate dictionary with keys ``"weights"``, ``"type"`` and
            ``"modalities"``.
        batch_size : int
            Number of samples in the current batch (used for uniform scores).
        reduction : str
            Reduction method, ``"mean"`` or ``"max"``.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            ``(scores_bm, per_dim_bmd_or_None)``. ``per_dim`` is only non-None
            for ``"gate"`` / ``"sigmoid_gate"`` types.

        Raises
        ------
        ValueError
            If ``reduction`` is not one of the supported values.
        """
        if reduction not in self._GATE_REDUCTIONS:
            raise ValueError(f"Unknown gate_reduction: {reduction!r}")
        w = gates["weights"]
        gtype = gates["type"]
        n_mod = len(gates["modalities"])

        if gtype in ("gate", "sigmoid_gate"):
            # w: (M, B, d) -> (B, M, d)
            w_bmd = w.permute(1, 0, 2).contiguous()
            per_dim = w_bmd.cpu().numpy()
            if reduction == "mean":
                scores = w_bmd.mean(dim=-1)
            else:
                scores = w_bmd.max(dim=-1).values
            return scores.cpu().numpy(), per_dim

        if gtype == "attention":
            # w: (B, M, M) — average over the query axis
            return w.mean(dim=1).cpu().numpy(), None

        # concat / identity: no learned weighting
        return np.full((batch_size, n_mod), 1.0 / n_mod, dtype=np.float32), None

    # ==================================================================
    # Private helpers
    # ==================================================================
    def _validate_registered(self) -> None:
        """Raise RuntimeError if any configured modality is not registered."""
        expected = {m.name for m in self.cfg.modalities}
        missing = expected - set(self._raw_adatas.keys())
        if missing:
            raise RuntimeError(
                f"Missing modalities: {missing}. "
                f"Call register_modality() for each before prepare()."
            )

    @staticmethod
    def _checkpoint_path(save_dir: str, epochs: int) -> str:
        """Return the path of the final-epoch weights file under ``save_dir``."""
        return join(save_dir, f"model_epoch{epochs}.pth")

    def _write_vocab_meta(self, meta_dir: str) -> None:
        """Write ``self.vocab`` plus ``top_ks`` to ``meta_dir/vocab_meta.json``."""
        os.makedirs(meta_dir, exist_ok=True)
        payload = {**self.vocab, "top_ks": self.top_ks}
        with open(join(meta_dir, "vocab_meta.json"), "w", encoding="utf-8") as fh:
            json.dump(payload, fh, indent=4, default=str)

    @staticmethod
    def _select_present(candidates: Sequence[str], available: Any) -> List[str]:
        """Return *candidates* that occur in *available*, in order, de-duplicated."""
        present = set(available)
        return [feat for feat in dict.fromkeys(candidates) if feat in present]

    @staticmethod
    def _obs_column_aligned(
        obs: pd.DataFrame,
        col: str,
        target_index: pd.Index,
    ) -> Any:
        """Return ``obs[col]`` values ordered to match ``target_index``.

        Values are matched by obs name rather than by position, so a source
        whose rows are ordered differently from the target is still aligned.
        Target names missing from ``obs`` become NaN.
        """
        series = obs[col]
        if not obs.index.equals(target_index):
            series = series.reindex(target_index)
        return series.values

    # ---- 2a: align cells across modalities ----------------------------
    def _align_cells(self) -> None:
        """Subset every registered modality to the cells present in all of them.

        Also adds constant ``batch_key`` / simulation ``batch_key`` columns to
        the primary (first configured) modality when they are configured but
        absent.
        """
        common_cells = None
        for ad in self._raw_adatas.values():
            if common_cells is None:
                common_cells = ad.obs_names
            else:
                common_cells = np.intersect1d(common_cells, ad.obs_names)
        if common_cells is not None and len(common_cells) == 0:
            raise ValueError(
                "Registered modalities share no common cells (obs_names)."
            )

        for name, ad in self._raw_adatas.items():
            self.mod_adatas[name] = ad[common_cells].copy()

        # ensure batch keys exist on the primary modality
        primary = self.mod_adatas[self.cfg.modalities[0].name]
        if self.cfg.batch_key and self.cfg.batch_key not in primary.obs.columns:
            primary.obs[self.cfg.batch_key] = "batch_1"
        sim_key = self.cfg.simulation.batch_key
        if sim_key and sim_key not in primary.obs.columns:
            primary.obs[sim_key] = "sim_batch_1"

    # ---- 2b: normalise & marker selection -----------------------------
    def _preprocess_modalities(
        self,
        marker_features: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Subset each modality to its selected features.

        Supplied markers take precedence for any modality. Otherwise RNA is
        normalised (``log1p`` layer) and markers are discovered with
        ``find_sc_markers``; other modalities keep all features.

        Raises
        ------
        ValueError
            If the selection for a modality is empty.
        """
        marker_features = marker_features or {}
        for mcfg in self.cfg.modalities:
            name = mcfg.name
            ad = self.mod_adatas[name]
            if name in marker_features:
                feats = self._select_present(marker_features[name], ad.var_names)
            elif name == "rna":
                ad = normalize_rna(ad, layer_key="log1p")
                feats = find_sc_markers(
                    ad,
                    self.cfg.celltype_key,
                    batch_key=self.cfg.batch_key,
                    layer="log1p",
                    deg_method="t-test",
                    log2fc_min=0.5,
                    pval_cutoff=0.01,
                    n_top_markers=mcfg.k_feat or 50,
                    pct_min=0.1,
                )
            else:
                # non-RNA modality without supplied markers: keep all features
                continue
            if len(feats) == 0:
                raise ValueError(
                    f"No features selected for modality '{name}': none of the "
                    f"supplied/discovered markers are present in its var_names."
                )
            self.mod_adatas[name] = ad[:, feats].copy()

    # ---- 2c: prefix & concatenate -------------------------------------
    def _prefix_and_concat(
        self,
        mod_adatas: Dict[str, sc.AnnData],
        mode: str = "train",
    ) -> Tuple[sc.AnnData, OrderedDict]:
        """Prefix var_names with ``"<modality>-"`` and concatenate modalities.

        Works on **copies** of the input adatas so the caller's objects (e.g.
        the user's spatial adatas in ``predict``) are never modified.

        Parameters
        ----------
        mod_adatas : Dict[str, sc.AnnData]
            Dictionary of modality AnnData objects.
        mode : str, default "train"
            In ``"train"`` mode, celltype / batch / simulation-batch / context
            obs columns of the first modality are carried over; otherwise only
            the context column is.

        Returns
        -------
        Tuple[sc.AnnData, OrderedDict]
            - Concatenated AnnData object.
            - Ordered dictionary of (prefixed) modality feature lists.
        """
        modality_features: OrderedDict = OrderedDict()
        parts = []
        for name in self.cfg.modality_names:
            ad_copy = mod_adatas[name].copy()
            prefix = f"{name}-"
            # Only the first feature is inspected: inputs are expected to be
            # either fully prefixed already or not prefixed at all.
            if len(ad_copy.var_names) > 0 and not ad_copy.var_names[0].startswith(prefix):
                ad_copy.var_names = [f"{prefix}{g}" for g in ad_copy.var_names]
            modality_features[name] = list(ad_copy.var_names)
            parts.append(ad_copy)

        ad_concat = sc.concat(parts, axis=1)
        target_index = ad_concat.obs_names

        if mode == "train":
            primary_obs = parts[0].obs
            carried = (
                self.cfg.celltype_key,
                self.cfg.batch_key,
                self.cfg.simulation.batch_key,
                self.cfg.context_key,
            )
        else:
            # predict mode: carry context_key only
            primary_obs = mod_adatas[self.cfg.modality_names[0]].obs
            carried = (self.cfg.context_key,)

        for col in carried:
            if col and col in primary_obs.columns:
                # Align by obs name: concat's rows need not match the source
                # modality's row order.
                ad_concat.obs[col] = self._obs_column_aligned(
                    primary_obs, col, target_index,
                )
        return ad_concat, modality_features

    # ---- 2d: vocab ----------------------------------------------------
    def _build_vocab(self, ad_concat: sc.AnnData) -> None:
        """Build vocabularies, clamp ``top_k`` per modality, and persist metadata.

        ``mcfg.top_k`` on the config is overwritten with the clamped value so
        a later ``save()`` of the config reflects it.
        """
        self.cell_types = sorted(
            ad_concat.obs[self.cfg.celltype_key].unique().tolist()
        )
        context_cats: Optional[List[str]] = None
        ctx_key = self.cfg.context_key
        if ctx_key and ctx_key in ad_concat.obs:
            self.context_cats = sorted(
                ad_concat.obs[ctx_key].astype(str).unique().tolist()
            )
            context_cats = self.context_cats

        self.vocab = build_vocab(
            modality_features=self.modality_features,
            cell_types=self.cell_types,
            context_categories=context_cats,
        )

        # clamp top_k to the actual feature count
        self.top_ks = {}
        for mcfg in self.cfg.modalities:
            n_available = len(self.modality_features.get(mcfg.name, []))
            clamped = min(mcfg.top_k, n_available)
            self.top_ks[mcfg.name] = clamped
            mcfg.top_k = clamped

        self._write_vocab_meta(join(self.save_dir, "metadata"))

    # ---- 2e: simulate -------------------------------------------------
    def _simulate(self, ad_concat: sc.AnnData) -> None:
        """Run ``simulate`` on ``ad_concat`` and record its sample count and paths."""
        all_features: List[str] = [
            feat
            for name in self.cfg.modality_names
            for feat in self.modality_features[name]
        ]
        context2id = self.vocab.get("context2id", None)
        context_key = self.cfg.context_key if context2id else None
        self._real_n, self._inp_path, self._lbl_path, self._context_path = simulate(
            ad_concat,
            vocabs=self.vocab["vocabs"],
            modality_names=self.cfg.modality_names,
            top_ks=self.top_ks,
            cell_types=self.cell_types,
            all_features=all_features,
            celltype_key=self.cfg.celltype_key,
            sim_batch_key=self.cfg.simulation.batch_key,
            save_dir=self.save_dir,
            cfg=self.cfg.simulation,
            context2id=context2id,
            context_key=context_key,
        )

    # ---- 2f: dataloader -----------------------------------------------
    def _build_loader(self) -> None:
        """Wrap the simulated data in a ``MemmapDataset`` and a shuffling DataLoader."""
        segment_layout = [
            {
                "name": mcfg.name,
                "top_k": self.top_ks[mcfg.name],
                "mask_id": self.vocab["mask_ids"][mcfg.name],
                "cl_dropout_rate": mcfg.cl_dropout_rate,
                "mrp_mask_rate": mcfg.mrp_mask_rate,
            }
            for mcfg in self.cfg.modalities
        ]
        ds = MemmapDataset(
            input_path=self._inp_path,
            label_path=self._lbl_path,
            context_path=self._context_path,
            valid_samples=self._real_n,
            max_samples=self.cfg.simulation.total_samples,
            seq_len=self.cfg.total_seq_len,
            num_classes=len(self.cell_types),
            segment_layout=segment_layout,
            cl_mode=self.cfg.use_cl,
            mrp_mode=self.cfg.use_mrp,
        )
        self._loader = DataLoader(
            ds,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            num_workers=self.cfg.num_workers,
        )

    # ---- 2g: model architecture ---------------------------------------
    def _build_model(self) -> None:
        """Instantiate ``SpotRankTransformer`` from the config and vocab and move it to the device."""
        tower_specs = [
            {
                "name": mcfg.name,
                # +1 presumably reserves one id beyond the vocab (e.g. for the
                # mask token) — NOTE(review): confirm against build_vocab.
                "vocab_size": len(self.vocab["vocabs"][mcfg.name]) + 1,
                "top_k": self.top_ks[mcfg.name],
                "mask_id": self.vocab["mask_ids"][mcfg.name],
                "depth": mcfg.depth or self.cfg.depth,
                "heads": mcfg.heads or self.cfg.heads,
            }
            for mcfg in self.cfg.modalities
        ]
        # ``context2id`` may be absent or stored as None.
        num_contexts = len(self.vocab.get("context2id") or {})
        self.model = SpotRankTransformer(
            tower_specs=tower_specs,
            num_cell_types=len(self.cell_types),
            num_contexts=num_contexts,
            dim=self.cfg.dim,
            depth=self.cfg.depth,
            heads=self.cfg.heads,
            mlp_dim=self.cfg.mlp_dim,
            dropout=self.cfg.dropout,
            fusion_type=getattr(self.cfg, "fusion_type", None),
            context_dim=getattr(self.cfg, "context_dim", None),
            use_cl=self.cfg.use_cl,
            use_mrp=self.cfg.use_mrp,
        ).to(self.device)

    # ──────────────────────────────────────────────────────────────────────
    # Repr
    # ──────────────────────────────────────────────────────────────────────
    def __repr__(self):
        mods = ", ".join(m.name for m in self.cfg.modalities)
        # A model counts as trained after fit() or after load() restored weights.
        trained = self.model is not None and (
            self.trainer is not None or self._checkpoint_loaded
        )
        status = "trained" if trained else "untrained"
        return f"SpaRank(modalities=[{mods}], status={status}, device={self.device})"