Quick start

Unimodal RNA

from sparank import SpaRank, ExpConfig, ModalityConfig, SimulationConfig

cfg = ExpConfig(
    modalities=[ModalityConfig(name="rna", top_k=500)],
    # Controls pseudo-spot generation for training.
    simulation=SimulationConfig(
        total_samples=100_000,
        cells_mean=10,                 # expected cells per spot — match your section resolution
        cells_std=5,
        cell_sample_method="gaussian",
        # Group cells by Tissue, simulate pseudo-spots within each group, then pool them.
        # This simulation grouping is independent of the top-level batch_key/context_key.
        batch_key="Tissue",
    ),
    celltype_key="cell_type",
    batch_key="batch",                # group cells by batch; detect marker genes within each group
    context_key="Tissue",             # label cells by Tissue; embed the label into the prediction head
    epochs=20,
)

sr = SpaRank(cfg, save_dir="./output")
sr.register_modality("rna", adata_sc_rna)
sr.prepare()                           # select marker features automatically
# or pass markers directly:
# sr.prepare(marker_features={"rna": prior_gene_list})
sr.fit()
sr.save()

# Deploy to a spatial section — no retraining needed
df = sr.predict({"rna": adata_sp_rna})

Multimodal RNA + ADT

cfg = ExpConfig(
    modalities=[
        ModalityConfig(name="rna", top_k=500),
        ModalityConfig(name="adt", top_k=200),
    ],
    # Controls pseudo-spot generation for training.
    simulation=SimulationConfig(
        total_samples=100_000,
        cells_mean=10,                 # expected cells per spot — match your section resolution
        cells_std=5,
        cell_sample_method="gaussian",
        # Group cells by Tissue, simulate pseudo-spots within each group, then pool them.
        # This simulation grouping is independent of the top-level batch_key/context_key.
        batch_key="Tissue",
    ),
    celltype_key="cell_type",
    batch_key="batch",             # group cells by batch; detect marker genes within each group
    context_key="Tissue",          # label cells by Tissue; embed the label into the prediction head
    fusion_type="sigmoid_gate",    # "concat" | "gate" | "joint_gate" | "sigmoid_gate" | "attention"
    epochs=20,
)

sr = SpaRank(cfg, save_dir="./output")
# Register every modality named in cfg.modalities before calling prepare().
sr.register_modality("rna", adata_sc_rna)
sr.register_modality("adt", adata_sc_adt)
sr.prepare()
sr.fit()

# Predict proportions; pass one spatial dataset per registered modality
proportions_df = sr.predict(
    {"rna": adata_sp_rna, "adt": adata_sp_adt},
)