"""Dataclass-based configuration for SpaRank experiments."""
from __future__ import annotations
import json
import os
import warnings
from dataclasses import dataclass, field, asdict
from typing import List, Optional
@dataclass
class SimulationConfig:
    """Parameters for pseudo-spot simulation via SPACEL.

    A plain value container: this class performs no validation, so
    consistency between fields (for example ``cells_min <= cells_max``)
    is the responsibility of whoever consumes the configuration.
    """

    #: Total number of simulated samples to generate.
    total_samples: int = 100_000
    #: Number of threads for parallel processing.
    n_threads: int = 8
    #: Mean number of cells per pseudo-spot.
    cells_mean: int = 10
    #: Minimum number of cells per pseudo-spot.
    cells_min: int = 1
    #: Maximum number of cells per pseudo-spot.
    cells_max: int = 50
    #: Standard deviation of cells per pseudo-spot.
    cells_std: int = 5
    #: Method for sampling cells. Options: "gaussian" or "lognormal".
    cell_sample_method: str = "gaussian"
    #: Buffer size used during simulation.
    sim_buffer: int = 10
    #: Size of batch requests.
    batch_request_size: int = 100_000
    #: Column in adata.obs to batch by during simulation (None: no batching key).
    batch_key: Optional[str] = None
@dataclass
class ModalityConfig:
    """Specification for one omics modality (RNA, ADT, ATAC, ...).

    Each modality owns its own vocabulary, token-sequence length,
    number of top markers, and optional per-tower architecture overrides.

    ``depth`` and ``heads`` default to ``None``, which is documented as
    "inherit the shared value from ExpConfig"; that fallback is not
    resolved inside this class.
    """

    #: Name of the modality (e.g., "rna", "adt").
    name: str = "rna"
    #: Number of top features to retain.
    top_k: int = 500
    #: Number of features to use for marker selection.
    k_feat: Optional[int] = 50
    #: Dropout rate for the contrastive learning layer.
    cl_dropout_rate: float = 0.3
    #: Mask rate for masked region prediction (MRP).
    mrp_mask_rate: float = 0.3
    #: Override depth for this modality's tower. If None, inherits from ExpConfig.
    depth: Optional[int] = None
    #: Override attention heads. If None, inherits from ExpConfig.
    heads: Optional[int] = None
@dataclass
class ExpConfig:
    """Unified configuration covering both unimodal and multimodal setups.

    When ``len(modalities) == 1``, the model reduces to a single-tower
    encoder (the "unimodal" case). Context conditioning is activated
    whenever ``context_key`` is not None.

    Negative ``cl_weight`` / ``mrp_weight`` values are clamped to ``0.0``
    (with a warning) right after construction, which disables the
    corresponding loss.

    Example — unimodal with context:

    >>> cfg = ExpConfig(
    ...     modalities=[ModalityConfig(name="rna", top_k=500)],
    ...     simulation=SimulationConfig(
    ...         total_samples=50_000,
    ...         cells_mean=20,
    ...         cells_std=10,
    ...         cell_sample_method="lognormal",
    ...         batch_key="Tissue",
    ...     ),
    ...     context_key="subtype",
    ... )

    Example — trimodal:

    >>> cfg = ExpConfig(
    ...     modalities=[
    ...         ModalityConfig(name="rna", top_k=500),
    ...         ModalityConfig(name="adt", top_k=200),
    ...         ModalityConfig(name="atac", top_k=300),
    ...     ],
    ...     fusion_type="gate",
    ... )
    """

    # ── Modalities ─────────────────────────────────────────────
    #: List of modality configurations.
    modalities: List[ModalityConfig] = field(
        default_factory=lambda: [ModalityConfig(name="rna", top_k=500, k_feat=50)]
    )

    # ── Annotation keys ────────────────────────────────────────
    #: Column in adata.obs denoting cell types.
    celltype_key: str = "cell_type"
    #: Column in adata.obs denoting batch (for marker selection).
    batch_key: Optional[str] = None
    #: Key for context conditioning. None indicates no state embedding.
    context_key: Optional[str] = None

    # ── Shared architecture ────────────────────────────────────
    #: Base hidden dimension size.
    dim: int = 64
    #: Number of transformer layers (depth).
    depth: int = 1
    #: Number of attention heads.
    heads: int = 1
    #: Hidden dimension size of the MLP layer.
    mlp_dim: int = 256
    #: Global dropout rate.
    dropout: float = 0.1
    #: Type of multimodal fusion. Options: "concat", "gate", "attention", "sigmoid_gate".
    fusion_type: str = "sigmoid_gate"
    #: Dimension of context embedding. Defaults to `dim` if None.
    context_dim: Optional[int] = None

    # ── Training ───────────────────────────────────────────────
    #: Training batch size.
    batch_size: int = 256
    #: Learning rate.
    lr: float = 1e-4
    #: Number of training epochs.
    epochs: int = 10
    #: Weight for Contrastive Learning loss. Set to <= 0 to disable.
    cl_weight: float = 0.5
    #: Temperature scaling parameter for contrastive loss.
    cl_temperature: float = 0.1
    #: Weight for Masked Region Prediction loss. Set to <= 0 to disable.
    mrp_weight: float = 0.5
    #: Weight for classification loss.
    cls_weight: float = 1.0
    #: Number of Dataloader workers.
    num_workers: int = 1

    # ── Simulation / I/O ───────────────────────────────────────
    #: Configuration block for pseudo-spot simulation.
    simulation: SimulationConfig = field(default_factory=SimulationConfig)
    #: Directory path to save outputs.
    save_dir: str = "output"

    # ── Convenience ────────────────────────────────────────────
    @property
    def modality_names(self) -> List[str]:
        """Get a list of all configured modality names, in list order."""
        return [m.name for m in self.modalities]

    @property
    def n_modalities(self) -> int:
        """Get the total number of configured modalities."""
        return len(self.modalities)

    @property
    def total_seq_len(self) -> int:
        """Calculate the sum of `top_k` features across all modalities."""
        return sum(m.top_k for m in self.modalities)

    @property
    def use_context(self) -> bool:
        """Check if context conditioning is active (``context_key`` is set)."""
        return self.context_key is not None

    @property
    def use_cl(self) -> bool:
        """Check if Contrastive Learning (CL) is active based on its weight."""
        return self.cl_weight > 0

    @property
    def use_mrp(self) -> bool:
        """Check if Masked Region Prediction (MRP) is active based on its weight."""
        return self.mrp_weight > 0

    def get_modality(self, name: str) -> ModalityConfig:
        """Retrieve a specific modality configuration by its name.

        If several modalities share the same name, the first one in
        ``modalities`` is returned.

        Args:
            name (str): The name of the modality to search for.

        Returns:
            ModalityConfig: The matching modality configuration.

        Raises:
            KeyError: If the specified modality name is not found.
        """
        for m in self.modalities:
            if m.name == name:
                return m
        raise KeyError(f"Modality '{name}' not found")

    def save(self, path: str) -> None:
        """Save the configuration to a JSON file.

        The parent directory of ``path`` is created if it does not exist.
        Nested dataclasses (modalities, simulation) are serialized
        recursively as plain JSON objects.

        Args:
            path (str): The destination file path.
        """
        # dirname is "" for a bare filename; fall back to the current directory.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        # Explicit encoding keeps the file independent of the platform locale.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(asdict(self), f, indent=4)

    @classmethod
    def load(cls, path: str) -> "ExpConfig":
        """Load an ExpConfig instance from a JSON file.

        Keys that are absent from the file fall back to the dataclass
        defaults; this includes ``modalities`` and ``simulation``. An
        explicitly empty ``"modalities": []`` is preserved as an empty list.

        Args:
            path (str): The path to the JSON configuration file.

        Returns:
            ExpConfig: The instantiated configuration object.

        Raises:
            TypeError: If the file contains keys that are not fields of
                the corresponding dataclass.
        """
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
        # Rebuild nested sections only when present, so a missing key yields
        # the field default rather than an empty placeholder.
        if "simulation" in d:
            d["simulation"] = SimulationConfig(**d["simulation"])
        if "modalities" in d:
            d["modalities"] = [ModalityConfig(**m) for m in d["modalities"]]
        return cls(**d)

    def __post_init__(self) -> None:
        """Perform post-initialization validation and sanitization.

        Negative loss weights are clamped to ``0.0`` with a warning, so the
        ``use_cl`` / ``use_mrp`` flags and any downstream loss arithmetic
        never see a negative weight.
        """
        # stacklevel=3 skips this method and the generated ``__init__`` so the
        # warning is attributed to the code that constructed the config.
        if self.cl_weight < 0:
            warnings.warn("cl_weight is negative. Clamping to 0.0.", stacklevel=3)
            self.cl_weight = 0.0
        if self.mrp_weight < 0:
            warnings.warn("mrp_weight is negative. Clamping to 0.0.", stacklevel=3)
            self.mrp_weight = 0.0