sparank.ExpConfig

class sparank.ExpConfig(modalities=<factory>, celltype_key='cell_type', batch_key=None, context_key=None, dim=64, depth=1, heads=1, mlp_dim=256, dropout=0.1, fusion_type='sigmoid_gate', context_dim=None, batch_size=256, lr=0.0001, epochs=10, cl_weight=0.5, cl_temperature=0.1, mrp_weight=0.5, cls_weight=1.0, num_workers=1, simulation=<factory>, save_dir='output')

Bases: object

Unified configuration covering both unimodal and multimodal setups.

When len(modalities) == 1, the model reduces to a single-tower encoder (the “unimodal” case). Context conditioning is activated whenever context_key is not None.
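The two switches in the paragraph above can be sketched with a stripped-down stand-in class. `MiniConfig` and its `is_unimodal` property are hypothetical and exist only for illustration; the real ExpConfig stores ModalityConfig objects and exposes `n_modalities` and `use_context`.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class MiniConfig:
    """Hypothetical stand-in for ExpConfig, keeping only the fields that
    decide the unimodal/multimodal and context switches."""
    modality_names: List[str] = field(default_factory=lambda: ["rna"])
    context_key: Optional[str] = None

    @property
    def n_modalities(self) -> int:
        return len(self.modality_names)

    @property
    def is_unimodal(self) -> bool:
        # Exactly one modality -> single-tower encoder.
        return self.n_modalities == 1

    @property
    def use_context(self) -> bool:
        # Context conditioning is on whenever a context key is given.
        return self.context_key is not None


uni = MiniConfig(modality_names=["rna"], context_key="subtype")
tri = MiniConfig(modality_names=["rna", "adt", "atac"])
print(uni.is_unimodal, uni.use_context)  # True True
print(tri.is_unimodal, tri.use_context)  # False False
```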

Example — unimodal with context:
>>> cfg = ExpConfig(
...     modalities=[ModalityConfig(name="rna", top_k=500)],
...     simulation=SimulationConfig(
...         total_samples=50_000,
...         cells_mean=20,
...         cells_std=10,
...         cell_sample_method="lognormal",
...         batch_key="Tissue",
...     ),
...     context_key="subtype",
... )
Example — trimodal:
>>> cfg = ExpConfig(
...     modalities=[
...         ModalityConfig(name="rna",  top_k=500),
...         ModalityConfig(name="adt",  top_k=200),
...         ModalityConfig(name="atac", top_k=300),
...     ],
...     fusion_type="gate",
... )

__init__(modalities=<factory>, celltype_key='cell_type', batch_key=None, context_key=None, dim=64, depth=1, heads=1, mlp_dim=256, dropout=0.1, fusion_type='sigmoid_gate', context_dim=None, batch_size=256, lr=0.0001, epochs=10, cl_weight=0.5, cl_temperature=0.1, mrp_weight=0.5, cls_weight=1.0, num_workers=1, simulation=<factory>, save_dir='output')

Return type:

None

Methods

__init__([modalities, celltype_key, ...])

get_modality(name)

Retrieve a specific modality configuration by its name.

load(path)

Load an ExpConfig instance from a JSON file.

save(path)

Save the configuration to a JSON file.

Attributes

batch_key

Column in adata.obs denoting batch (for marker selection).

batch_size

Training batch size.

celltype_key

Column in adata.obs denoting cell types.

cl_temperature

Temperature scaling parameter for contrastive loss.

cl_weight

Weight for Contrastive Learning loss.

cls_weight

Weight for classification loss.

context_dim

Dimension of context embedding.

context_key

Key for context conditioning.

depth

Number of transformer layers (depth).

dim

Base hidden dimension size.

dropout

Global dropout rate.

epochs

Number of training epochs.

fusion_type

Type of multimodal fusion. Options: "concat", "gate", "attention", "sigmoid_gate".

heads

Number of attention heads.

lr

Learning rate.

mlp_dim

Hidden dimension size of the MLP layer.

modality_names

Get a list of all configured modality names.

mrp_weight

Weight for Masked Region Prediction loss.

n_modalities

Get the total number of configured modalities.

num_workers

Number of data loader workers.

save_dir

Directory path to save outputs.

total_seq_len

Calculate the sum of top_k features across all modalities.

use_cl

Whether Contrastive Learning (CL) is active, i.e. cl_weight > 0.

use_context

Whether context conditioning is active, i.e. context_key is not None.

use_mrp

Whether Masked Region Prediction (MRP) is active, i.e. mrp_weight > 0.

modalities

List of modality configurations.

simulation

Configuration block for pseudo-spot simulation.

batch_key: str | None = None

Column in adata.obs denoting batch (for marker selection).

batch_size: int = 256

Training batch size.

celltype_key: str = 'cell_type'

Column in adata.obs denoting cell types.

cl_temperature: float = 0.1

Temperature scaling parameter for contrastive loss.
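This page does not specify which contrastive objective sparank uses. Purely as an illustration of what a temperature such as `cl_temperature` does, the sketch below computes a standard single-anchor InfoNCE loss over cosine similarities; `info_nce` and `cosine` are hypothetical helpers, not the library's implementation.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def info_nce(anchor: Sequence[float], positive: Sequence[float],
             negatives: List[Sequence[float]], temperature: float = 0.1) -> float:
    """Hypothetical InfoNCE loss for one anchor: the negative log-softmax
    of the positive pair's temperature-scaled similarity."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, n) / temperature for n in negatives]
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[0]


anchor = [1.0, 0.0]
positive = [0.9, 0.1]
negatives = [[0.0, 1.0], [-1.0, 0.0]]
sharp = info_nce(anchor, positive, negatives, temperature=0.1)
soft = info_nce(anchor, positive, negatives, temperature=1.0)
print(sharp < soft)  # True: the lower temperature penalizes this easy case less
```

A lower temperature sharpens the softmax, so well-separated positives cost almost nothing while hard negatives dominate the gradient.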

cl_weight: float = 0.5

Weight for Contrastive Learning loss. Set to <= 0 to disable.

cls_weight: float = 1.0

Weight for classification loss.
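The page names three loss weights (cls_weight, cl_weight, mrp_weight, written below as w_cls, w_cl, w_mrp) but does not state the combined objective. Assuming they enter a plain weighted sum, the total training loss would be:

```latex
\mathcal{L} = w_{\mathrm{cls}}\,\mathcal{L}_{\mathrm{cls}}
            + w_{\mathrm{cl}}\,\mathcal{L}_{\mathrm{cl}}
            + w_{\mathrm{mrp}}\,\mathcal{L}_{\mathrm{mrp}},
\qquad (w_{\mathrm{cls}}, w_{\mathrm{cl}}, w_{\mathrm{mrp}}) = (1.0,\ 0.5,\ 0.5) \text{ by default}
```

Under that assumption, the CL and MRP terms are switched off when cl_weight or mrp_weight is <= 0, as documented for those fields.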

context_dim: int | None = None

Dimension of context embedding. Defaults to dim if None.
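The documented fallback for `context_dim` amounts to a single conditional. `resolve_context_dim` is a hypothetical helper, not part of sparank:

```python
from typing import Optional


def resolve_context_dim(dim: int, context_dim: Optional[int]) -> int:
    """Hypothetical helper mirroring the documented rule for
    ExpConfig.context_dim: fall back to dim when context_dim is None."""
    return dim if context_dim is None else context_dim


print(resolve_context_dim(64, None))  # 64
print(resolve_context_dim(64, 16))    # 16
```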

context_key: str | None = None

Key for context conditioning. If None, context conditioning is disabled and no state embedding is used.

depth: int = 1

Number of transformer layers (depth).

dim: int = 64

Base hidden dimension size.

dropout: float = 0.1

Global dropout rate.

epochs: int = 10

Number of training epochs.

fusion_type: str = 'sigmoid_gate'

Type of multimodal fusion. Options: "concat", "gate", "attention", "sigmoid_gate".
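This page lists four accepted values for `fusion_type` but does not say whether ExpConfig validates the field. A minimal check, using the hypothetical names `FUSION_TYPES` and `check_fusion_type`, might look like:

```python
FUSION_TYPES = ("concat", "gate", "attention", "sigmoid_gate")


def check_fusion_type(fusion_type: str) -> str:
    """Hypothetical validator for the documented fusion_type options."""
    if fusion_type not in FUSION_TYPES:
        raise ValueError(
            f"fusion_type must be one of {FUSION_TYPES}, got {fusion_type!r}"
        )
    return fusion_type


print(check_fusion_type("sigmoid_gate"))  # sigmoid_gate
try:
    check_fusion_type("sum")
except ValueError as err:
    print(err)
```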

get_modality(name)

Retrieve a specific modality configuration by its name.

Parameters:

name (str) – The name of the modality to search for.

Returns:

The matching modality configuration.

Return type:

ModalityConfig

Raises:

KeyError – If the specified modality name is not found.
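The documented lookup behavior of `get_modality` can be sketched as a linear search that raises `KeyError` on a miss. `Modality` is a hypothetical two-field stand-in for ModalityConfig, and returning the first match assumes modality names are unique:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Modality:
    """Hypothetical stand-in for ModalityConfig (name and top_k only)."""
    name: str
    top_k: int


def get_modality(modalities: List[Modality], name: str) -> Modality:
    """Return the modality whose name matches, else raise KeyError."""
    for m in modalities:
        if m.name == name:
            return m
    available = [m.name for m in modalities]
    raise KeyError(f"Modality {name!r} not found; available: {available}")


mods = [Modality("rna", 500), Modality("adt", 200)]
print(get_modality(mods, "adt").top_k)  # 200
```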

heads: int = 1

Number of attention heads.

classmethod load(path)

Load an ExpConfig instance from a JSON file.

Parameters:

path (str) – The path to the JSON configuration file.

Returns:

The instantiated configuration object.

Return type:

ExpConfig

lr: float = 0.0001

Learning rate.

mlp_dim: int = 256

Hidden dimension size of the MLP layer.

modalities: List[ModalityConfig]

List of modality configurations.

property modality_names: List[str]

Get a list of all configured modality names.

mrp_weight: float = 0.5

Weight for Masked Region Prediction loss. Set to <= 0 to disable.

property n_modalities: int

Get the total number of configured modalities.

num_workers: int = 1

Number of data loader workers.

save(path)

Save the configuration to a JSON file.

Parameters:

path (str) – The destination file path.

Return type:

None
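The JSON layout written by `save` is not documented here. The sketch below shows one plausible round trip for a flat dataclass using `dataclasses.asdict` and the standard `json` module; `TinyConfig` is hypothetical, and the real ExpConfig would also need to rebuild its nested ModalityConfig and SimulationConfig objects on load.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class TinyConfig:
    """Hypothetical flat config used only to illustrate a JSON round trip."""
    dim: int = 64
    lr: float = 1e-4
    fusion_type: str = "sigmoid_gate"

    def save(self, path: str) -> None:
        # Serialize every dataclass field to a JSON object.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(asdict(self), f, indent=2)

    @classmethod
    def load(cls, path: str) -> "TinyConfig":
        # Rebuild the instance from the stored field values.
        with open(path, encoding="utf-8") as f:
            return cls(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    original = TinyConfig(dim=128)
    original.save(path)
    restored = TinyConfig.load(path)

print(restored == original)  # True
```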

save_dir: str = 'output'

Directory path to save outputs.

simulation: SimulationConfig

Configuration block for pseudo-spot simulation.

property total_seq_len: int

Calculate the sum of top_k features across all modalities.
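For the trimodal example at the top of this page, the documented definition of `total_seq_len` reduces to a short sum:

```python
# top_k per modality, taken from the trimodal ExpConfig example
top_k = {"rna": 500, "adt": 200, "atac": 300}

# total_seq_len is documented as the sum of top_k over all modalities
total_seq_len = sum(top_k.values())
print(total_seq_len)  # 1000
```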

property use_cl: bool

Whether Contrastive Learning (CL) is active, i.e. cl_weight > 0.

property use_context: bool

Whether context conditioning is active, i.e. context_key is not None.

property use_mrp: bool

Whether Masked Region Prediction (MRP) is active, i.e. mrp_weight > 0.
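The documented rule behind `use_cl` and `use_mrp` (a weight <= 0 disables the term) reduces to a positivity check. `loss_term_active` is a hypothetical helper written for illustration:

```python
def loss_term_active(weight: float) -> bool:
    """Hypothetical helper for the rule behind use_cl and use_mrp:
    an auxiliary loss is active only when its weight is positive."""
    return weight > 0


print(loss_term_active(0.5))   # True  (default cl_weight / mrp_weight)
print(loss_term_active(0.0))   # False (disabled)
print(loss_term_active(-1.0))  # False (any weight <= 0 disables the term)
```

With the real class, `ExpConfig(mrp_weight=0.0).use_mrp` would therefore be expected to return False.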