sparank.SpaRank

class sparank.SpaRank(cfg, save_dir=None, device=None)[source]

Bases: object

High-level pipeline: register modalities -> prepare -> fit -> predict.

This class serves as the main entry point for the SpaRank framework. It manages multiple AnnData objects (one per modality), handles pre-processing and pseudo-spot simulation, instantiates the multimodal model, and runs training and inference.
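The register -> prepare -> fit order, and the documented behaviour that the registration and training methods return self, can be illustrated with a small stand-in class. This is a hypothetical stub, not SpaRank's implementation: `PipelineStub`, its attributes, the modality names "rna" and "atac", and the placeholder values are assumptions made for illustration. Only the method names, the returned self, and the exception types mirror what this page documents.

```python
class PipelineStub:
    """Hypothetical stand-in that mimics the documented call order only."""

    def __init__(self, modality_names):
        self.expected = set(modality_names)  # names that would come from cfg
        self.registered = {}
        self.loader = None                   # created by prepare()
        self.trained = False

    def register_modality(self, name, adata):
        # Documented: ValueError for a name the configuration does not define.
        if name not in self.expected:
            raise ValueError(f"unexpected modality: {name!r}")
        self.registered[name] = adata
        return self                          # returning self enables chaining

    def prepare(self):
        # Placeholder for the simulator DataLoader built by the real prepare().
        self.loader = "internal-loader"
        return self

    def fit(self, loader=None):
        loader = loader if loader is not None else self.loader
        # Documented: RuntimeError if nothing has been initialized yet.
        if loader is None:
            raise RuntimeError("call prepare() before fit()")
        self.trained = True
        return self


stub = PipelineStub(["rna", "atac"])
same = (
    stub.register_modality("rna", object())
    .register_modality("atac", object())
    .prepare()
    .fit()
)
```

Because every step returns the same object, the chained expression and the original variable refer to one pipeline instance.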

Parameters:
  • cfg (ExpConfig) – Configuration object containing hyperparameters for modalities, model architecture, simulation, and training.

  • save_dir (str, optional) – Directory to save models and metadata. If None, falls back to cfg.save_dir or "./sparank_output".

  • device (Union[str, torch.device], optional) – Compute device. Automatically determined if left as None.
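The save_dir fallback described above can be sketched as a small resolution function. This is a minimal sketch, not the library's code: `resolve_save_dir` and `cfg_save_dir` are hypothetical names, and the sketch assumes that only None triggers the fallback (how an empty string is treated is not stated on this page).

```python
def resolve_save_dir(save_dir=None, cfg_save_dir=None):
    """Return the first directory that is not None, in the documented order."""
    if save_dir is not None:
        return save_dir            # explicit argument wins
    if cfg_save_dir is not None:
        return cfg_save_dir        # otherwise fall back to cfg.save_dir
    return "./sparank_output"      # documented final default
```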

__init__(cfg, save_dir=None, device=None)[source]

Methods

  • __init__(cfg[, save_dir, device])

  • fit([loader]) – Train the model.

  • load(save_dir[, device, verbose]) – Reconstruct a trained SpaRank model from saved artefacts.

  • predict(mod_adatas[, batch_size, ...]) – Run inference on spatial data.

  • prepare([marker_features]) – Execute the full pre-training pipeline.

  • register_modality(name, adata) – Register a single-modality AnnData.

  • register_mudata(mdata) – Convenience method: register all modalities from a MuData object at once.

  • save([save_dir]) – Persist config, vocabulary mappings, and final model weights.

fit(loader=None)[source]

Train the model. Uses the internal DataLoader built by prepare() unless a custom loader is passed.

Parameters:

loader (DataLoader, optional) – Custom DataLoader. If None, uses the internal simulator DataLoader.

Returns:

Returns self.

Return type:

SpaRank

Raises:

RuntimeError – If the model or dataloader has not been initialized.

classmethod load(save_dir, device=None, verbose=False)[source]

Reconstruct a trained SpaRank model from saved artefacts.

Parameters:
  • save_dir (str) – Directory where the model and metadata are stored.

  • device (Union[str, torch.device], optional) – Target compute device for the loaded model.

  • verbose (bool, default False) – If True, prints the checkpoint loading status.

Returns:

An instantiated and loaded SpaRank pipeline ready for inference.

Return type:

SpaRank

predict(mod_adatas, batch_size=512, num_workers=0, return_embeddings=False, return_gate_scores=False, return_per_dim_gates=False, gate_reduction='mean')[source]

Run inference on spatial data.

By default, returns a DataFrame of predicted cell-type proportions. Optional flags additionally return embeddings and/or fusion gate scores from the same forward pass, so nothing is recomputed.

Parameters:
  • mod_adatas (Dict[str, sc.AnnData]) – One AnnData per modality, keyed by modality name. Feature names should be unprefixed; prefixing is handled internally.

  • batch_size (int, default 512) – Batch size for the inference DataLoader.

  • num_workers (int, default 0) – Number of workers for the inference DataLoader.

  • return_embeddings (bool, default False) – If True, also return per-spot embeddings with shape (n_spots, head_in).

  • return_gate_scores (bool, default False) – If True, also return per-spot, per-modality fusion weights as a DataFrame with shape (n_spots, n_modalities). For concat or single-modality fusion, returns uniform weights of 1/M, where M is the number of modalities.

  • return_per_dim_gates (bool, default False) – If True, also return the raw per-dimension gate tensor (n_spots, n_modalities, dim). Only meaningful for fusion_type='gate'; ignored otherwise. Implies return_gate_scores=True.

  • gate_reduction (str, default "mean") – How to reduce (M, B, d) gate tensors to per-modality scalars. Options are "mean" or "max".
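The gate_reduction option and the uniform 1/M fallback can be illustrated with plain nested lists. This is a sketch under stated assumptions, not the library's code: it reads (M, B, d) as (modalities, batch of spots, gate dimension), it uses lists where the real pipeline presumably uses tensors, and it does not renormalise the reduced scores because this page does not say whether that happens. `reduce_gates` and `uniform_scores` are hypothetical names.

```python
from statistics import fmean


def reduce_gates(gates, reduction="mean"):
    """Collapse a nested list shaped (M, B, d) into B rows of M scalars."""
    if reduction not in ("mean", "max"):
        raise ValueError(f"unknown gate_reduction: {reduction!r}")
    reduce_fn = fmean if reduction == "mean" else max
    n_modalities = len(gates)
    n_spots = len(gates[0])
    # One row per spot, one column per modality, reduced over the last axis.
    return [
        [reduce_fn(gates[m][b]) for m in range(n_modalities)]
        for b in range(n_spots)
    ]


def uniform_scores(n_spots, n_modalities):
    """Uniform 1/M weights, as documented for concat or single-modality fusion."""
    return [[1.0 / n_modalities] * n_modalities for _ in range(n_spots)]


# Two modalities, two spots, gate dimension two (values chosen for illustration).
gates = [
    [[0.25, 0.25], [1.0, 0.0]],   # modality 0
    [[0.75, 0.75], [0.0, 1.0]],   # modality 1
]
mean_scores = reduce_gates(gates, "mean")
max_scores = reduce_gates(gates, "max")
```

The two reductions can disagree: for the second spot, "mean" gives both modalities 0.5, while "max" gives both 1.0, so "max" highlights a modality that dominates in any single dimension.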

Returns:

Always returned in this order, skipping items not requested:

  • proportions (pd.DataFrame)

  • embeddings (np.ndarray, if requested)

  • gate_scores (pd.DataFrame, if requested)

  • per_dim_gates (np.ndarray, if requested)

A single DataFrame is returned when only proportions are requested.
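The return ordering can be made concrete with a small sketch of how the outputs might be assembled and unpacked. `assemble_outputs` is a hypothetical helper, and the strings "P", "E", "G", and "D" are placeholders for the proportions, embeddings, gate scores, and per-dimension gates. The sketch follows the documented rule that return_per_dim_gates implies return_gate_scores; it does not model the fusion_type check.

```python
def assemble_outputs(proportions, embeddings, gate_scores, per_dim_gates,
                     return_embeddings=False, return_gate_scores=False,
                     return_per_dim_gates=False):
    """Order outputs as documented and drop the ones not requested."""
    # Documented: requesting per-dimension gates implies gate scores.
    if return_per_dim_gates:
        return_gate_scores = True
    out = [proportions]
    if return_embeddings:
        out.append(embeddings)
    if return_gate_scores:
        out.append(gate_scores)
    if return_per_dim_gates:
        out.append(per_dim_gates)
    # A single object, not a 1-tuple, when only proportions are requested.
    return out[0] if len(out) == 1 else tuple(out)


only_props = assemble_outputs("P", "E", "G", "D")
props, emb = assemble_outputs("P", "E", "G", "D", return_embeddings=True)
props2, gates, per_dim = assemble_outputs("P", "E", "G", "D",
                                          return_per_dim_gates=True)
```

Because the tuple length depends on the flags, the unpacking pattern at the call site has to match the flags passed to predict.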

Return type:

Union[pd.DataFrame, Tuple[Any, …]]

prepare(marker_features=None)[source]

Execute the full pre-training pipeline.

This aligns cells, normalizes RNA data, performs marker selection, builds multi-modal vocabularies, simulates pseudo-spots, and constructs the model and dataloaders.

Parameters:

marker_features (Dict[str, List[str]], optional) – Pre-computed marker feature lists per modality. If None, markers are discovered automatically for RNA; other modalities keep all features.

Returns:

Returns self.

Return type:

SpaRank

register_modality(name, adata)[source]

Register a single-modality AnnData. Call once per modality.

Parameters:
  • name (str) – Name of the modality (must match a modality name defined in cfg).

  • adata (sc.AnnData) – The AnnData object containing the single-cell reference for this modality.

Returns:

Returns self to allow method chaining.

Return type:

SpaRank

Raises:

ValueError – If the provided modality name is not defined in the configuration.

register_mudata(mdata)[source]

Convenience method: register all modalities from a MuData object at once.

Parameters:

mdata (MuData) – A MuData object containing multi-modal data.

Returns:

Returns self to allow method chaining.

Return type:

SpaRank

Raises:

KeyError – If a modality defined in the configuration is missing from the MuData.
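The missing-modality check can be sketched with a plain dictionary standing in for the MuData's name-to-AnnData mapping. `collect_modalities` is a hypothetical helper, and the sketch assumes that modalities present in the MuData but absent from the configuration are simply ignored, which this page does not state.

```python
def collect_modalities(expected_names, available):
    """Pick each configured modality from a name -> data mapping."""
    selected = {}
    for name in expected_names:
        # Documented: a configured modality missing from the MuData is a KeyError.
        if name not in available:
            raise KeyError(f"modality {name!r} missing from MuData")
        selected[name] = available[name]
    return selected


# Placeholder strings stand in for AnnData objects; "protein" is not configured.
available = {"rna": "rna-adata", "atac": "atac-adata", "protein": "adt-adata"}
picked = collect_modalities(["rna", "atac"], available)
```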

save(save_dir=None)[source]

Persist config, vocabulary mappings, and final model weights.

Parameters:

save_dir (str, optional) – Destination directory. If provided, overrides the save directory set when the pipeline was instantiated.

Return type:

None