Quick start
Unimodal RNA
from sparank import SpaRank, ExpConfig, ModalityConfig, SimulationConfig
# Modalities used by the experiment: a single RNA modality in this example.
rna_modality = ModalityConfig(name="rna", top_k=500)

# Controls pseudo-spot generation for training.
simulation_cfg = SimulationConfig(
    total_samples=100_000,
    cells_mean=10,  # expected cells per spot; match your section resolution
    cells_std=5,
    cell_sample_method="gaussian",
    # Cells are grouped by Tissue; pseudo-spots are simulated within each
    # group and then pooled. This key can differ from context_key.
    batch_key="Tissue",
)

cfg = ExpConfig(
    modalities=[rna_modality],
    simulation=simulation_cfg,
    celltype_key="cell_type",
    # Cells are grouped by batch; marker genes are detected within each group.
    batch_key="batch",
    # Cells are labelled by Tissue; the label is embedded into the
    # prediction head.
    context_key="Tissue",
    epochs=20,
)

# Build the model, register the RNA data, then prepare, fit and save.
sr = SpaRank(cfg, save_dir="./output")
sr.register_modality("rna", adata_sc_rna)
sr.prepare()
# Alternatively, pass markers directly:
# sr.prepare(marker_features={"rna": prior_gene_list})
sr.fit()
sr.save()

# Deploy to a spatial section; no retraining needed.
spatial_inputs = {"rna": adata_sp_rna}
df = sr.predict(spatial_inputs)
Multimodal RNA + ADT
# Modalities used by the experiment: RNA and ADT, each with its own top_k.
modality_cfgs = [
    ModalityConfig(name="rna", top_k=500),
    ModalityConfig(name="adt", top_k=200),
]

# Controls pseudo-spot generation for training.
simulation_cfg = SimulationConfig(
    total_samples=100_000,
    cells_mean=10,  # expected cells per spot; match your section resolution
    cells_std=5,
    cell_sample_method="gaussian",
    # Cells are grouped by Tissue; pseudo-spots are simulated within each
    # group and then pooled. This key can differ from context_key.
    batch_key="Tissue",
)

cfg = ExpConfig(
    modalities=modality_cfgs,
    simulation=simulation_cfg,
    celltype_key="cell_type",
    # Cells are grouped by batch; marker genes are detected within each group.
    batch_key="batch",
    # Cells are labelled by Tissue; the label is embedded into the
    # prediction head.
    context_key="Tissue",
    # One of: "concat" | "gate" | "joint_gate" | "sigmoid_gate" | "attention"
    fusion_type="sigmoid_gate",
    epochs=20,
)

# Build the model, register one dataset per modality, then prepare and fit.
sr = SpaRank(cfg, save_dir="./output")
sr.register_modality("rna", adata_sc_rna)
sr.register_modality("adt", adata_sc_adt)
sr.prepare()
sr.fit()

# Predict proportions; the input dict is keyed by modality name.
spatial_inputs = {"rna": adata_sp_rna, "adt": adata_sp_adt}
proportions_df = sr.predict(spatial_inputs)