sparank.ExpConfig
- class sparank.ExpConfig(modalities=<factory>, celltype_key='cell_type', batch_key=None, context_key=None, dim=64, depth=1, heads=1, mlp_dim=256, dropout=0.1, fusion_type='sigmoid_gate', context_dim=None, batch_size=256, lr=0.0001, epochs=10, cl_weight=0.5, cl_temperature=0.1, mrp_weight=0.5, cls_weight=1.0, num_workers=1, simulation=<factory>, save_dir='output')
Bases: object
Unified configuration covering both unimodal and multimodal setups.
When len(modalities) == 1, the model reduces to a single-tower encoder (the “unimodal” case). Context conditioning is activated whenever context_key is not None.
- Example (unimodal with context):
>>> cfg = ExpConfig(
...     modalities=[ModalityConfig(name="rna", top_k=500)],
...     simulation=SimulationConfig(
...         total_samples=50_000,
...         cells_mean=20,
...         cells_std=10,
...         cell_sample_method="lognormal",
...         batch_key="Tissue",
...     ),
...     context_key="subtype",
... )
- Example (trimodal):
>>> cfg = ExpConfig(
...     modalities=[
...         ModalityConfig(name="rna", top_k=500),
...         ModalityConfig(name="adt", top_k=200),
...         ModalityConfig(name="atac", top_k=300),
...     ],
...     fusion_type="gate",
... )
- Parameters:
modalities (List[ModalityConfig])
celltype_key (str)
batch_key (str | None)
context_key (str | None)
dim (int)
depth (int)
heads (int)
mlp_dim (int)
dropout (float)
fusion_type (str)
context_dim (int | None)
batch_size (int)
lr (float)
epochs (int)
cl_weight (float)
cl_temperature (float)
mrp_weight (float)
cls_weight (float)
num_workers (int)
simulation (SimulationConfig)
save_dir (str)
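Fields shown as `<factory>` in the signature are how Python dataclasses render defaults declared with `field(default_factory=...)`, so each new config gets its own `modalities` list and `simulation` block instead of sharing one mutable default. The sketch below assumes `ExpConfig` is a dataclass (the `<factory>` rendering suggests this, but the page does not state it). `ModalityStub`, `ExpConfigSketch`, `is_unimodal`, and the single "rna" default are hypothetical stand-ins, not the `sparank` API. The sketch also encodes the rule that exactly one modality means the unimodal case.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ModalityStub:
    # Hypothetical stand-in for sparank.ModalityConfig (only fields used in the examples).
    name: str
    top_k: int = 500


@dataclass
class ExpConfigSketch:
    # Hypothetical stand-in for sparank.ExpConfig; only a few fields are reproduced.
    # default_factory builds a fresh list per instance; dataclasses render it as <factory>.
    modalities: List[ModalityStub] = field(
        default_factory=lambda: [ModalityStub(name="rna")]  # assumed default, for illustration
    )
    context_key: Optional[str] = None
    fusion_type: str = "sigmoid_gate"

    def is_unimodal(self) -> bool:
        # Hypothetical helper: a single modality means the single-tower (unimodal) case.
        return len(self.modalities) == 1


a = ExpConfigSketch()
b = ExpConfigSketch()
a.modalities.append(ModalityStub(name="adt", top_k=200))
print(len(a.modalities), len(b.modalities))  # 2 1 (b keeps its own default list)
print(a.is_unimodal(), b.is_unimodal())      # False True
```

With a plain mutable default (`modalities: list = [...]`), `dataclasses` raises `ValueError` at class creation, which is why a factory is the idiomatic choice here.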
- __init__(modalities=<factory>, celltype_key='cell_type', batch_key=None, context_key=None, dim=64, depth=1, heads=1, mlp_dim=256, dropout=0.1, fusion_type='sigmoid_gate', context_dim=None, batch_size=256, lr=0.0001, epochs=10, cl_weight=0.5, cl_temperature=0.1, mrp_weight=0.5, cls_weight=1.0, num_workers=1, simulation=<factory>, save_dir='output')
- Return type:
None
Methods
__init__([modalities, celltype_key, ...])
get_modality(name): Retrieve a specific modality configuration by its name.
load(path): Load an ExpConfig instance from a JSON file.
save(path): Save the configuration to a JSON file.
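`save(path)` and `load(path)` are documented only as a JSON round-trip. A minimal sketch of how such a round-trip can work for a nested dataclass, assuming (not confirmed by this page) that `ExpConfig` is a dataclass and that `load` rebuilds nested configs from plain dicts. `ConfigSketch` and `ModalityStub` are hypothetical names, and making `load` a classmethod is also an assumption.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class ModalityStub:
    # Hypothetical stand-in for sparank.ModalityConfig.
    name: str
    top_k: int = 500


@dataclass
class ConfigSketch:
    # Hypothetical stand-in for sparank.ExpConfig with two fields only.
    modalities: List[ModalityStub] = field(default_factory=list)
    lr: float = 1e-4

    def save(self, path: str) -> None:
        # asdict() recurses into nested dataclasses, so modalities become plain dicts.
        with open(path, "w") as fh:
            json.dump(asdict(self), fh, indent=2)

    @classmethod
    def load(cls, path: str) -> "ConfigSketch":
        with open(path) as fh:
            data = json.load(fh)
        # JSON gives back dicts; rebuild the nested dataclass instances by hand.
        data["modalities"] = [ModalityStub(**m) for m in data.get("modalities", [])]
        return cls(**data)


cfg = ConfigSketch(modalities=[ModalityStub("rna", 500), ModalityStub("adt", 200)])
path = os.path.join(tempfile.mkdtemp(), "config.json")
cfg.save(path)
restored = ConfigSketch.load(path)
print(restored == cfg)  # True
```

The manual rebuild in `load` is the part a real implementation must get right: without it, `restored.modalities` would be a list of dicts and the equality check would fail.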
Attributes
batch_key: Column in adata.obs denoting batch (for marker selection).
batch_size: Training batch size.
celltype_key: Column in adata.obs denoting cell types.
cl_temperature: Temperature scaling parameter for the contrastive loss.
cl_weight: Weight for the Contrastive Learning loss.
cls_weight: Weight for the classification loss.
context_dim: Dimension of the context embedding.
context_key: Key for context conditioning.
depth: Number of transformer layers.
dim: Base hidden dimension size.
dropout: Global dropout rate.
epochs: Number of training epochs.
fusion_type: Type of multimodal fusion. Options: "concat", "gate", "attention", "sigmoid_gate".
heads: Number of attention heads.
lr: Learning rate.
mlp_dim: Hidden dimension size of the MLP layer.
Get a list of all configured modality names.
mrp_weight: Weight for the Masked Region Prediction loss.
Get the total number of configured modalities.
num_workers: Number of DataLoader workers.
save_dir: Directory path to save outputs.
Calculate the sum of top_k features across all modalities.
Check if Contrastive Learning (CL) is active based on its weight.
Check if context conditioning is active.
Check if Masked Region Prediction (MRP) is active based on its weight.
modalities: List of modality configurations.
simulation: Configuration block for pseudo-spot simulation.
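Several attributes listed above are derived values: the modality names, the modality count, the summed top_k, and whether CL, MRP, or context conditioning is active. A minimal sketch of how they could be computed from the fields, assuming a loss term counts as "active" when its weight is strictly positive. The property names (`modality_names`, `num_modalities`, `total_top_k`, `use_cl`, `use_mrp`, `use_context`) and the `PropsSketch` class are hypothetical, because this page does not show the real names.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ModalityStub:
    # Hypothetical stand-in for sparank.ModalityConfig.
    name: str
    top_k: int


@dataclass
class PropsSketch:
    # Hypothetical stand-in; the property names below are invented for illustration.
    modalities: List[ModalityStub] = field(default_factory=list)
    context_key: Optional[str] = None
    cl_weight: float = 0.5
    mrp_weight: float = 0.5

    @property
    def modality_names(self) -> List[str]:
        return [m.name for m in self.modalities]

    @property
    def num_modalities(self) -> int:
        return len(self.modalities)

    @property
    def total_top_k(self) -> int:
        # Sum of top_k features across all modalities.
        return sum(m.top_k for m in self.modalities)

    @property
    def use_cl(self) -> bool:
        # Assumption: a loss term is active when its weight is strictly positive.
        return self.cl_weight > 0

    @property
    def use_mrp(self) -> bool:
        return self.mrp_weight > 0

    @property
    def use_context(self) -> bool:
        # Mirrors the documented rule: context conditioning is on whenever context_key is set.
        return self.context_key is not None


cfg = PropsSketch(
    modalities=[ModalityStub("rna", 500), ModalityStub("adt", 200), ModalityStub("atac", 300)],
    mrp_weight=0.0,
)
print(cfg.modality_names, cfg.num_modalities, cfg.total_top_k)  # ['rna', 'adt', 'atac'] 3 1000
print(cfg.use_cl, cfg.use_mrp, cfg.use_context)                 # True False False
```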
- fusion_type: str = 'sigmoid_gate'
Type of multimodal fusion. Options: "concat", "gate", "attention", "sigmoid_gate".
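The page lists four accepted values for `fusion_type` but does not say whether `ExpConfig` validates them. A small sketch of an up-front check against exactly those four options; `FUSION_TYPES` and `check_fusion_type` are hypothetical helpers, not part of `sparank`.

```python
from typing import FrozenSet

# The four fusion_type options listed on this page.
FUSION_TYPES: FrozenSet[str] = frozenset({"concat", "gate", "attention", "sigmoid_gate"})


def check_fusion_type(fusion_type: str) -> str:
    """Hypothetical helper: reject values outside the documented options."""
    if fusion_type not in FUSION_TYPES:
        raise ValueError(
            f"fusion_type must be one of {sorted(FUSION_TYPES)}, got {fusion_type!r}"
        )
    return fusion_type


print(check_fusion_type("sigmoid_gate"))  # sigmoid_gate
try:
    check_fusion_type("sum")
except ValueError as err:
    print(err)  # names the allowed options and the rejected value
```

Failing at configuration time like this surfaces a typo such as `"sigmoid-gate"` before any model is built.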
- get_modality(name)
Retrieve a specific modality configuration by its name.
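A minimal sketch of a lookup by name over a list of modality configs. The page does not say what `get_modality` does when the name is missing, so raising `KeyError` here is an assumption. `ModalityStub` and the free function `get_modality` are hypothetical stand-ins for the real method.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ModalityStub:
    # Hypothetical stand-in for sparank.ModalityConfig.
    name: str
    top_k: int


def get_modality(modalities: List[ModalityStub], name: str) -> ModalityStub:
    """Hypothetical lookup by name; raising KeyError on a miss is an assumption."""
    for modality in modalities:
        if modality.name == name:
            return modality
    available = [m.name for m in modalities]
    raise KeyError(f"no modality named {name!r}; available: {available}")


mods = [ModalityStub("rna", 500), ModalityStub("adt", 200), ModalityStub("atac", 300)]
print(get_modality(mods, "adt").top_k)  # 200
```

A linear scan is adequate here because a config holds only a handful of modalities; a dict keyed by name would only matter with many lookups.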
- modalities: List[ModalityConfig]
List of modality configurations.
- simulation: SimulationConfig
Configuration block for pseudo-spot simulation.
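The unimodal example configures pseudo-spot simulation with `cells_mean=20`, `cells_std=10`, and `cell_sample_method="lognormal"`. This page does not define those semantics. The sketch assumes they are the arithmetic mean and standard deviation of the number of cells per pseudo-spot, and converts them to lognormal parameters by moment matching: sigma^2 = ln(1 + std^2 / mean^2) and mu = ln(mean) - sigma^2 / 2. `lognormal_params` and `sample_cells_per_spot` are hypothetical helpers, and rounding with a floor of one cell is an illustrative choice, not necessarily what `SimulationConfig` does.

```python
import numpy as np


def lognormal_params(mean: float, std: float) -> tuple:
    # Moment matching: convert the target arithmetic mean/std of X into
    # the (mu, sigma) of log(X) for a lognormal distribution.
    sigma2 = np.log1p((std / mean) ** 2)
    mu = np.log(mean) - sigma2 / 2.0
    return float(mu), float(np.sqrt(sigma2))


def sample_cells_per_spot(n_spots: int, mean: float, std: float, seed: int = 0) -> np.ndarray:
    """Hypothetical sampler: integer cell counts per pseudo-spot, at least 1 each."""
    rng = np.random.default_rng(seed)
    mu, sigma = lognormal_params(mean, std)
    draws = rng.lognormal(mean=mu, sigma=sigma, size=n_spots)
    # Round to whole cells and never allow an empty spot.
    return np.maximum(1, np.rint(draws)).astype(int)


counts = sample_cells_per_spot(50_000, mean=20, std=10)
print(counts.min() >= 1, counts.mean())  # True, with a sample mean near 20
```

Note that NumPy's `lognormal(mean=..., sigma=...)` takes the parameters of the underlying normal, not of the counts themselves, which is why the conversion step is needed.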