Source code for sparank.config

"""Dataclass-based configuration for SpaRank experiments."""

from __future__ import annotations

import json
import os
import warnings
from dataclasses import dataclass, field, asdict
from typing import List, Optional


@dataclass
class SimulationConfig:
    """Settings that drive pseudo-spot simulation through SPACEL.

    Every field carries a default, so ``SimulationConfig()`` is already a
    complete configuration; override only the knobs an experiment needs.
    """

    #: How many simulated samples to produce in total.
    total_samples: int = 100_000
    #: Worker threads used for parallel processing.
    n_threads: int = 8
    #: Average cell count placed in each pseudo-spot.
    cells_mean: int = 10
    #: Lower bound on the cell count of a pseudo-spot.
    cells_min: int = 1
    #: Upper bound on the cell count of a pseudo-spot.
    cells_max: int = 50
    #: Spread (standard deviation) of the per-spot cell count.
    cells_std: int = 5
    #: Scheme used to sample cell counts; documented choices are "gaussian" and "lognormal".
    cell_sample_method: str = "gaussian"
    #: Buffer size applied while simulating.
    sim_buffer: int = 10
    #: Number of samples requested per batch.
    batch_request_size: int = 100_000
    #: Optional ``adata.obs`` column to batch by during simulation.
    batch_key: Optional[str] = None
@dataclass
class ModalityConfig:
    """Describes a single omics modality (RNA, ADT, ATAC, ...).

    A modality carries its own vocabulary, token-sequence length and number
    of top markers. It can also override the depth and head count of its own
    tower; leaving those overrides as ``None`` means the ``ExpConfig`` values
    apply.
    """

    #: Identifier of the modality, e.g. "rna" or "adt".
    name: str = "rna"
    #: How many top features are kept.
    top_k: int = 500
    #: How many features feed marker selection.
    k_feat: Optional[int] = 50
    #: Dropout probability of the contrastive-learning layer.
    cl_dropout_rate: float = 0.3
    #: Fraction of tokens masked for masked region prediction (MRP).
    mrp_mask_rate: float = 0.3
    #: Per-tower depth override; ``None`` defers to ``ExpConfig``.
    depth: Optional[int] = None
    #: Per-tower attention-head override; ``None`` defers to ``ExpConfig``.
    heads: Optional[int] = None
@dataclass
class ExpConfig:
    """Unified configuration covering both unimodal and multimodal setups.

    When ``len(modalities) == 1``, the model reduces to a single-tower
    encoder (the "unimodal" case). Context conditioning is activated
    whenever ``context_key`` is not None.

    Example — unimodal with context:

    >>> cfg = ExpConfig(
    ...     modalities=[ModalityConfig(name="rna", top_k=500)],
    ...     simulation=SimulationConfig(
    ...         total_samples=50_000,
    ...         cells_mean=20,
    ...         cells_std=10,
    ...         cell_sample_method="lognormal",
    ...         batch_key="Tissue",
    ...     ),
    ...     context_key="subtype",
    ... )

    Example — trimodal:

    >>> cfg = ExpConfig(
    ...     modalities=[
    ...         ModalityConfig(name="rna", top_k=500),
    ...         ModalityConfig(name="adt", top_k=200),
    ...         ModalityConfig(name="atac", top_k=300),
    ...     ],
    ...     fusion_type="gate",
    ... )
    """

    # ── Modalities ─────────────────────────────────────────────
    #: List of modality configurations.
    modalities: List[ModalityConfig] = field(
        default_factory=lambda: [ModalityConfig(name="rna", top_k=500, k_feat=50)]
    )

    # ── Annotation keys ────────────────────────────────────────
    #: Column in adata.obs denoting cell types.
    celltype_key: str = "cell_type"
    #: Column in adata.obs denoting batch (for marker selection).
    batch_key: Optional[str] = None
    #: Key for context conditioning. None indicates no state embedding.
    context_key: Optional[str] = None

    # ── Shared architecture ────────────────────────────────────
    #: Base hidden dimension size.
    dim: int = 64
    #: Number of transformer layers (depth).
    depth: int = 1
    #: Number of attention heads.
    heads: int = 1
    #: Hidden dimension size of the MLP layer.
    mlp_dim: int = 256
    #: Global dropout rate.
    dropout: float = 0.1
    #: Type of multimodal fusion. Options: "concat", "gate", "attention", "sigmoid_gate".
    fusion_type: str = "sigmoid_gate"
    #: Dimension of context embedding. Defaults to `dim` if None.
    context_dim: Optional[int] = None

    # ── Training ───────────────────────────────────────────────
    #: Training batch size.
    batch_size: int = 256
    #: Learning rate.
    lr: float = 1e-4
    #: Number of training epochs.
    epochs: int = 10
    #: Weight for Contrastive Learning loss. Set to <= 0 to disable.
    cl_weight: float = 0.5
    #: Temperature scaling parameter for contrastive loss.
    cl_temperature: float = 0.1
    #: Weight for Masked Region Prediction loss. Set to <= 0 to disable.
    mrp_weight: float = 0.5
    #: Weight for classification loss.
    cls_weight: float = 1.0
    #: Number of Dataloader workers.
    num_workers: int = 1

    # ── Simulation / I/O ───────────────────────────────────────
    #: Configuration block for pseudo-spot simulation.
    simulation: SimulationConfig = field(default_factory=SimulationConfig)
    #: Directory path to save outputs.
    save_dir: str = "output"

    # ── Convenience ────────────────────────────────────────────
    @property
    def modality_names(self) -> List[str]:
        """Get a list of all configured modality names, in configured order."""
        return [m.name for m in self.modalities]

    @property
    def n_modalities(self) -> int:
        """Get the total number of configured modalities."""
        return len(self.modalities)

    @property
    def total_seq_len(self) -> int:
        """Calculate the sum of `top_k` features across all modalities."""
        return sum(m.top_k for m in self.modalities)

    @property
    def use_context(self) -> bool:
        """Check if context conditioning is active (``context_key`` is set)."""
        return self.context_key is not None

    @property
    def use_cl(self) -> bool:
        """Check if Contrastive Learning (CL) is active based on its weight."""
        return self.cl_weight > 0

    @property
    def use_mrp(self) -> bool:
        """Check if Masked Region Prediction (MRP) is active based on its weight."""
        return self.mrp_weight > 0

    def get_modality(self, name: str) -> ModalityConfig:
        """Retrieve a specific modality configuration by its name.

        If several modalities share the same name, the first one in
        ``modalities`` is returned.

        Args:
            name (str): The name of the modality to search for.

        Returns:
            ModalityConfig: The matching modality configuration.

        Raises:
            KeyError: If the specified modality name is not found.
        """
        for m in self.modalities:
            if m.name == name:
                return m
        raise KeyError(f"Modality '{name}' not found")

    def save(self, path: str) -> None:
        """Save the configuration to a JSON file.

        Missing parent directories are created. Nested ``modalities`` and
        ``simulation`` blocks are serialized via ``dataclasses.asdict``.

        Args:
            path (str): The destination file path.
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        # Explicit UTF-8 so the file does not depend on the platform locale.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(asdict(self), f, indent=4)

    @classmethod
    def load(cls, path: str) -> "ExpConfig":
        """Load an ExpConfig instance from a JSON file.

        Keys absent from the file fall back to the dataclass defaults. In
        particular, a missing (or null) ``"modalities"`` entry yields the
        default RNA modality rather than an empty modality list.

        Args:
            path (str): The path to the JSON configuration file.

        Returns:
            ExpConfig: The instantiated configuration object.

        Raises:
            TypeError: If the file contains keys that are not fields of
                ExpConfig, SimulationConfig or ModalityConfig.
        """
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
        # ``or {}`` also tolerates an explicit ``"simulation": null``.
        d["simulation"] = SimulationConfig(**(d.pop("simulation", None) or {}))
        mods = d.pop("modalities", None)
        # Only override the default when the file actually lists modalities;
        # passing ``[]`` for a missing key would silently drop every tower.
        if mods is not None:
            d["modalities"] = [ModalityConfig(**m) for m in mods]
        return cls(**d)

    def __post_init__(self):
        """Perform post-initialization validation and sanitization.

        Negative loss weights are clamped to 0.0 with a warning, so that
        ``use_cl`` / ``use_mrp`` report the loss as disabled and the stored
        weight is never negative.
        """
        # stacklevel=3 skips __post_init__ and the generated __init__ so the
        # warning points at the code that constructed the config.
        if self.cl_weight < 0:
            warnings.warn("cl_weight is negative. Clamping to 0.0.", stacklevel=3)
            self.cl_weight = 0.0
        if self.mrp_weight < 0:
            warnings.warn("mrp_weight is negative. Clamping to 0.0.", stacklevel=3)
            self.mrp_weight = 0.0