sparank.SpaRank
- class sparank.SpaRank(cfg, save_dir=None, device=None)
Bases: object
High-level pipeline: register modalities -> prepare -> fit -> predict.
This class serves as the main entry point for the SpaRank framework. It manages multiple AnnData objects (one per modality), handles pre-processing and pseudo-spot simulation, instantiates the multimodal model, and runs training and inference.
- Parameters:
cfg (ExpConfig) – Configuration object containing hyperparameters for modalities, model architecture, simulation, and training.
save_dir (str, optional) – Directory to save models and metadata. If None, falls back to cfg.save_dir or "./sparank_output".
device (Union[str, torch.device], optional) – Compute device. Automatically determined if left as None.
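The save_dir fallback order can be sketched as follows. This is a minimal sketch, not SpaRank's code: `ToyConfig` and `resolve_save_dir` are hypothetical names, and treating an empty `cfg.save_dir` as unset is an assumption. Device resolution is not modelled because the documentation does not say how the device is chosen.

```python
from dataclasses import dataclass
from typing import Optional

DEFAULT_SAVE_DIR = "./sparank_output"


@dataclass
class ToyConfig:
    """Hypothetical stand-in for ExpConfig; only save_dir is modelled."""

    save_dir: Optional[str] = None


def resolve_save_dir(cfg: ToyConfig, save_dir: Optional[str] = None) -> str:
    """Hypothetical helper: explicit argument, then cfg.save_dir, then the default."""
    if save_dir is not None:
        return save_dir
    # An empty cfg.save_dir is treated as unset (assumption).
    return cfg.save_dir or DEFAULT_SAVE_DIR
```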
- __init__(cfg, save_dir=None, device=None)
- Parameters:
cfg (ExpConfig)
save_dir (str | None)
device (str | torch.device | None)
Methods
__init__(cfg[, save_dir, device])
fit([loader]) – Train the model.
load(save_dir[, device, verbose]) – Reconstruct a trained SpaRank model from saved artefacts.
predict(mod_adatas[, batch_size, ...]) – Run inference on spatial data.
prepare([marker_features]) – Execute the full pre-training pipeline.
register_modality(name, adata) – Register a single-modality AnnData.
register_mudata(mdata) – Convenience method: register all modalities from a MuData object at once.
save([save_dir]) – Persist config, vocabulary mappings, and final model weights.
- fit(loader=None)
Train the model. Uses the internal DataLoader generated by prepare() unless a loader is passed explicitly.
- Parameters:
loader (DataLoader, optional) – Custom DataLoader. If None, uses the internal simulator DataLoader.
- Returns:
Returns self.
- Return type:
SpaRank
- Raises:
RuntimeError – If the model or dataloader has not been initialized.
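A toy stand-in for the fit() contract described above: use the internal loader unless one is passed, raise RuntimeError when nothing has been initialized, and return self. `ToyFitPipeline`, its attributes, and the error messages are hypothetical, and the loop body is only a placeholder for the real training step.

```python
from typing import Any, Iterable, Optional


class ToyFitPipeline:
    """Hypothetical stand-in illustrating the documented fit() contract."""

    def __init__(self) -> None:
        self.model: Optional[Any] = None
        self._loader: Optional[Iterable[Any]] = None  # built by prepare()
        self.steps_run = 0

    def prepare(self) -> None:
        # Placeholder: the real prepare() builds the model and simulator loader.
        self.model = object()
        self._loader = [0, 1, 2]  # three dummy batches

    def fit(self, loader: Optional[Iterable[Any]] = None) -> "ToyFitPipeline":
        if self.model is None:
            raise RuntimeError("Model not initialized; call prepare() first.")
        active = loader if loader is not None else self._loader
        if active is None:
            raise RuntimeError("No DataLoader available; call prepare() or pass one.")
        for _batch in active:
            self.steps_run += 1  # a real loop would run forward/backward here
        return self  # returning self allows chaining, as documented
```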
- classmethod load(save_dir, device=None, verbose=False)
Reconstruct a trained SpaRank model from saved artefacts.
- Parameters:
save_dir (str) – Directory where the model and metadata are stored.
device (Union[str, torch.device], optional) – Target compute device for the loaded model.
verbose (bool, default False) – If True, prints the checkpoint loading status.
- Returns:
An instantiated and loaded SpaRank pipeline ready for inference.
- Return type:
SpaRank
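The sketch below shows the general pattern of a classmethod that rebuilds a pipeline from files written by save(). It is a toy under stated assumptions: the file names `config.json` and `vocab.json` are hypothetical, device handling is omitted, and model weights, which the real save() also persists, are not modelled.

```python
import json
import os
from typing import Dict, Optional


class ToyPersistedPipeline:
    """Hypothetical stand-in showing a save()/load() round trip."""

    CONFIG_FILE = "config.json"  # hypothetical artefact name
    VOCAB_FILE = "vocab.json"  # hypothetical artefact name

    def __init__(self, cfg: Dict, save_dir: Optional[str] = None) -> None:
        self.cfg = cfg
        self.save_dir = save_dir
        self.vocab: Dict[str, Dict[str, int]] = {}

    def save(self, save_dir: Optional[str] = None) -> None:
        out = save_dir or self.save_dir
        if out is None:
            raise ValueError("No save_dir given.")
        os.makedirs(out, exist_ok=True)
        with open(os.path.join(out, self.CONFIG_FILE), "w") as fh:
            json.dump(self.cfg, fh)
        with open(os.path.join(out, self.VOCAB_FILE), "w") as fh:
            json.dump(self.vocab, fh)
        # Model weights would also be written here in a real pipeline.

    @classmethod
    def load(cls, save_dir: str, verbose: bool = False) -> "ToyPersistedPipeline":
        with open(os.path.join(save_dir, cls.CONFIG_FILE)) as fh:
            cfg = json.load(fh)
        obj = cls(cfg, save_dir=save_dir)  # rebuild from the stored config
        with open(os.path.join(save_dir, cls.VOCAB_FILE)) as fh:
            obj.vocab = json.load(fh)  # restore the vocabulary mappings
        if verbose:
            print(f"Loaded config and vocabulary from {save_dir}")
        return obj
```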
- predict(mod_adatas, batch_size=512, num_workers=0, return_embeddings=False, return_gate_scores=False, return_per_dim_gates=False, gate_reduction='mean')
Run inference on spatial data.
By default, returns a DataFrame of predicted cell-type proportions. Optional flags additionally return embeddings and/or fusion gate scores from the same forward pass, so nothing is recomputed.
- Parameters:
mod_adatas (Dict[str, sc.AnnData]) – One AnnData per modality, keyed by modality name. Feature names should be unprefixed; prefixing is handled internally.
batch_size (int, default 512) – Batch size for the inference DataLoader.
num_workers (int, default 0) – Number of workers for the inference DataLoader.
return_embeddings (bool, default False) – If True, also return per-spot embeddings of shape (n_spots, head_in).
return_gate_scores (bool, default False) – If True, also return per-spot, per-modality fusion weights as a DataFrame of shape (n_spots, n_modalities). For concat or single-modality fusion, uniform weights of 1/M are returned.
return_per_dim_gates (bool, default False) – If True, also return the raw per-dimension gate tensor of shape (n_spots, n_modalities, dim). Only meaningful for fusion_type='gate'; ignored otherwise. Implies return_gate_scores=True.
gate_reduction (str, default "mean") – How to reduce (M, B, d) gate tensors, where M is the number of modalities, B the batch size, and d the gate dimension, to per-modality scalars. Options are "mean" or "max".
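One plausible reading of gate_reduction is a mean or max over the last (feature) axis of the (M, B, d) gate tensor, followed by a transpose so that rows are spots and columns are modalities, matching the (n_spots, n_modalities) shape of the gate scores. The NumPy sketch below implements that reading; the function names are hypothetical, and the choice of reduction axis is an assumption rather than SpaRank's confirmed behaviour.

```python
import numpy as np


def reduce_gates(gates: np.ndarray, reduction: str = "mean") -> np.ndarray:
    """Reduce an (M, B, d) gate tensor to per-spot scores of shape (B, M)."""
    if reduction == "mean":
        per_modality = gates.mean(axis=-1)  # (M, B)
    elif reduction == "max":
        per_modality = gates.max(axis=-1)  # (M, B)
    else:
        raise ValueError(f"reduction must be 'mean' or 'max', got {reduction!r}")
    return per_modality.T  # (B, M): rows are spots, columns are modalities


def per_dim_view(gates: np.ndarray) -> np.ndarray:
    """Reorder (M, B, d) to the documented (n_spots, n_modalities, dim) layout."""
    return np.transpose(gates, (1, 0, 2))


def uniform_scores(n_spots: int, n_modalities: int) -> np.ndarray:
    """Uniform 1/M scores, as documented for concat or single-modality fusion."""
    return np.full((n_spots, n_modalities), 1.0 / n_modalities)
```

Note that "mean" and "max" can rank modalities differently: a modality whose gate is large in a few dimensions but near zero elsewhere scores low under "mean" and high under "max".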
- Returns:
Always returned in this order, skipping items not requested:
- proportions (pd.DataFrame)
- embeddings (np.ndarray, if requested)
- gate_scores (pd.DataFrame, if requested)
- per_dim_gates (np.ndarray, if requested)
A single DataFrame is returned when only proportions are requested.
- Return type:
Union[pd.DataFrame, Tuple[Any, …]]
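The documented ordering rules (optional items appended in a fixed order, return_per_dim_gates implying return_gate_scores, and a bare DataFrame when nothing else is requested) can be captured in a small helper. `assemble_outputs` is a hypothetical name used only for illustration, and it ignores the fusion_type='gate' restriction on per-dimension gates.

```python
from typing import Any, Optional, Tuple, Union


def assemble_outputs(
    proportions: Any,
    embeddings: Optional[Any] = None,
    gate_scores: Optional[Any] = None,
    per_dim_gates: Optional[Any] = None,
    return_embeddings: bool = False,
    return_gate_scores: bool = False,
    return_per_dim_gates: bool = False,
) -> Union[Any, Tuple[Any, ...]]:
    """Hypothetical helper mirroring the documented return order of predict()."""
    # Requesting per-dimension gates implies requesting gate scores.
    return_gate_scores = return_gate_scores or return_per_dim_gates
    outputs = [proportions]
    if return_embeddings:
        outputs.append(embeddings)
    if return_gate_scores:
        outputs.append(gate_scores)
    if return_per_dim_gates:
        outputs.append(per_dim_gates)
    # A bare value, not a 1-tuple, when only proportions were requested.
    return outputs[0] if len(outputs) == 1 else tuple(outputs)
```

Under this ordering, a call such as `props, emb, gates = model.predict(mod_adatas, return_embeddings=True, return_gate_scores=True)` unpacks cleanly, where `model` stands for a prepared and fitted SpaRank instance.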
- prepare(marker_features=None)
Execute the full pre-training pipeline.
This aligns cells, normalizes RNA data, performs marker selection, builds multi-modal vocabularies, simulates pseudo-spots, and constructs the model and dataloaders.
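The documentation does not describe how SpaRank simulates pseudo-spots. The sketch below shows one common generic approach for a single modality: draw a fixed number of reference cells per spot, sum their profiles, and record the fraction of each cell type as the training target. The function name, `cells_per_spot`, and uniform sampling with replacement are all assumptions; a multimodal simulator would additionally need to draw the same cells in every modality, which is presumably one reason cells are aligned first.

```python
import numpy as np


def simulate_pseudo_spots(X, labels, n_spots, cells_per_spot=8, seed=0):
    """Generic pseudo-spot simulation (hypothetical helper, single modality).

    X: (n_cells, n_features) array; labels: length-n_cells cell-type labels.
    Returns (spots, proportions, cell_types).
    """
    rng = np.random.default_rng(seed)
    # codes[i] is the column index of cell i's type in cell_types (sorted).
    cell_types, codes = np.unique(np.asarray(labels), return_inverse=True)
    spots = np.zeros((n_spots, X.shape[1]))
    proportions = np.zeros((n_spots, len(cell_types)))
    for i in range(n_spots):
        idx = rng.choice(X.shape[0], size=cells_per_spot, replace=True)
        spots[i] = X[idx].sum(axis=0)  # pseudo-spot profile = sum of drawn cells
        counts = np.bincount(codes[idx], minlength=len(cell_types))
        proportions[i] = counts / cells_per_spot  # ground-truth mixture
    return spots, proportions, cell_types
```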
- register_modality(name, adata)
Register a single-modality AnnData. Call once per modality.
- Parameters:
name (str) – Name of the modality (must match a modality name defined in cfg).
adata (sc.AnnData) – The AnnData object containing the single-cell reference for this modality.
- Returns:
Returns self to allow method chaining.
- Return type:
SpaRank
- Raises:
ValueError – If the provided modality name is not expected in the configuration.
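A minimal stand-in for the register_modality() contract: reject names the configuration does not expect, store the AnnData otherwise, and return self so calls can be chained. `ToyRegistry` is hypothetical and takes the expected names directly instead of reading them from cfg; the modality names in the usage are illustrative. `register_mapping` loosely mirrors register_mudata(), assuming that method iterates over the MuData's `.mod` mapping of modality names to AnnData objects.

```python
from typing import Any, Dict, Iterable


class ToyRegistry:
    """Hypothetical stand-in for the documented register_modality() contract."""

    def __init__(self, expected_modalities: Iterable[str]) -> None:
        # In SpaRank the expected names come from cfg; here they are passed directly.
        self.expected = set(expected_modalities)
        self.adatas: Dict[str, Any] = {}

    def register_modality(self, name: str, adata: Any) -> "ToyRegistry":
        if name not in self.expected:
            raise ValueError(
                f"Unexpected modality {name!r}; expected one of {sorted(self.expected)}"
            )
        self.adatas[name] = adata
        return self  # enables method chaining

    def register_mapping(self, mods: Dict[str, Any]) -> "ToyRegistry":
        # Assumed analogue of register_mudata(): register every name -> AnnData pair.
        for name, adata in mods.items():
            self.register_modality(name, adata)
        return self
```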