Source code for sparank.data.dataset

"""Memory-mapped PyTorch Dataset with on-the-fly CL / MRP augmentation.

A single ``MemmapDataset`` handles both unimodal (1 segment) and
multimodal (N segments) layouts. Context support is activated
when *context_path* is provided.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import torch
from torch.utils.data import Dataset


# -----------------------------------------------------------------------
#  Low-level augmentation helpers
# -----------------------------------------------------------------------

def _dropout_segment(seg: np.ndarray, rate: float, min_keep: int = 10) -> np.ndarray:
    """Randomly zero out a fraction of the non-zero tokens in a segment.

    The input is never modified; a copy is augmented and returned.
    Positions are sampled with NumPy's global random state
    (``np.random.choice``).

    Parameters
    ----------
    seg : np.ndarray
        Token array for one modality. Entries equal to 0 are ignored
        when counting and selecting tokens.
    rate : float
        Fraction of the non-zero tokens to set to 0.
    min_keep : int, default 10
        Lower bound on the number of non-zero tokens left after dropout.
        A segment with ``min_keep`` or fewer non-zero tokens is returned
        unchanged, and the drop count is reduced when necessary so that
        at least ``min_keep`` non-zero tokens survive.

    Returns
    -------
    np.ndarray
        A copy of ``seg`` with the selected tokens replaced by 0.
    """
    out = seg.copy()
    active = (out != 0).nonzero()[0]
    n_active = len(active)

    if n_active > min_keep:
        # Never drop so many tokens that fewer than ``min_keep`` remain.
        n_drop = min(int(n_active * rate), n_active - min_keep)
        if n_drop > 0:
            dropped = np.random.choice(active, n_drop, replace=False)
            out[dropped] = 0

    return out


def _mask_segment(
    seg: np.ndarray, rate: float, mask_id: int, min_keep: int = 10
) -> Tuple[np.ndarray, np.ndarray]:
    """Apply Masked Region Prediction (MRP) masking to a segment.

    The input is never modified; a copy is masked and returned. Positions
    are sampled with NumPy's global random state (``np.random.choice``).

    Parameters
    ----------
    seg : np.ndarray
        The input token sequence array for a specific modality. Entries
        equal to 0 are never selected for masking.
    rate : float
        The proportion of non-zero tokens to replace with `mask_id`.
        Values ``<= 0`` disable masking; values ``>= 1`` mask every
        non-zero token.
    mask_id : int
        The token ID used to represent masked positions.
    min_keep : int, default 10
        The minimum number of non-zero tokens required to perform masking.
        A segment with ``min_keep`` or fewer non-zero tokens is returned
        unmasked. Unlike dropout, this is only an eligibility threshold:
        it does not bound how many tokens remain unmasked.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        A tuple containing:
        - The masked segment (np.ndarray).
        - A boolean array indicating the masked positions (np.ndarray).
    """
    seg = seg.copy()
    pos = np.zeros(len(seg), dtype=bool)
    nz = np.where(seg != 0)[0]
    if len(nz) <= min_keep or rate <= 0:
        return seg, pos

    # Mask at least one token, but never request more than are available:
    # sampling without replacement raises ValueError when the sample size
    # exceeds the population (which happened for rate > 1).
    n_mask = min(len(nz), max(1, int(len(nz) * rate)))
    idx = np.random.choice(nz, n_mask, replace=False)
    seg[idx] = mask_id
    pos[idx] = True
    return seg, pos


# -----------------------------------------------------------------------
#  Unified Dataset
# -----------------------------------------------------------------------

class MemmapDataset(Dataset):
    """Memory-mapped dataset for 1-to-N modalities.

    The concatenated token sequence is split into per-modality segments
    according to *segment_layout*, and CL dropout / MRP masking rates are
    applied independently per segment.

    Parameters
    ----------
    input_path : str
        Path to the memmap file for token inputs (read as ``int64``).
    label_path : str
        Path to the memmap file for labels (read as ``float32``).
    valid_samples : int
        Number of valid (written) samples to read.
    max_samples : int
        Total allocated row count in the memmap files.
    seq_len : int
        Total token-sequence length (= sum of per-modality top_k).
    num_classes : int
        Number of cell-type classes.
    segment_layout : List[Dict]
        One dict per modality defining the layout::

            {"name": str, "top_k": int, "mask_id": int,
             "cl_dropout_rate": float, "mrp_mask_rate": float}

        Segments are contiguous and must sum to *seq_len*.
    context_path : str, optional
        Path to context memmap (read as ``int64``). ``None`` indicates no
        context returned.
    cl_mode : bool, default False
        If ``True``, each sample produces two dropout-augmented views
        (`view_a` and `view_b`) for Contrastive Learning.
    mrp_mode : bool, default False
        If ``True``, produces masked targets and positions for Masked
        Region Prediction.

    Raises
    ------
    ValueError
        If the ``top_k`` values in *segment_layout* do not sum to *seq_len*.
    """

    def __init__(
        self,
        input_path: str,
        label_path: str,
        valid_samples: int,
        max_samples: int,
        seq_len: int,
        num_classes: int,
        segment_layout: List[Dict],
        context_path: Optional[str] = None,
        cl_mode: bool = False,
        mrp_mode: bool = False,
    ):
        # Pre-compute contiguous [start, end) boundaries for every segment.
        bounds: List[Tuple[int, int]] = []
        offset = 0
        for seg in segment_layout:
            k = seg["top_k"]
            bounds.append((offset, offset + k))
            offset += k

        # Enforce the documented invariant up front: a mismatching layout
        # would otherwise yield augmented views whose length silently
        # differs from the raw sequence.
        if offset != seq_len:
            raise ValueError(
                f"segment_layout top_k values sum to {offset}, "
                f"expected seq_len={seq_len}"
            )

        self.valid_samples = valid_samples
        self.inputs = np.memmap(
            input_path, dtype="int64", mode="r", shape=(max_samples, seq_len)
        )
        self.labels = np.memmap(
            label_path, dtype="float32", mode="r", shape=(max_samples, num_classes)
        )
        self.has_context = context_path is not None
        if self.has_context:
            self.contexts = np.memmap(
                context_path, dtype="int64", mode="r", shape=(max_samples,)
            )
        self.cl_mode = cl_mode
        self.mrp_mode = mrp_mode
        self.layout = segment_layout
        self._bounds: List[Tuple[int, int]] = bounds

    # ── Augmentation ──────────────────────────────────────────

    def _dropout(self, x: np.ndarray) -> np.ndarray:
        """Return a dropout-augmented copy of *x*, processed per segment.

        Each segment is augmented with its own ``cl_dropout_rate`` and the
        results are concatenated back in layout order.
        """
        parts = [
            _dropout_segment(x[s:e], seg["cl_dropout_rate"])
            for seg, (s, e) in zip(self.layout, self._bounds)
        ]
        return np.concatenate(parts)

    def _mask(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return ``(masked, mask_positions, target)`` for *x*.

        Each segment is masked with its own ``mrp_mask_rate`` and
        ``mask_id``; ``target`` is an unmodified copy of *x*.
        """
        target = x.copy()
        m_parts, p_parts = [], []
        for seg, (s, e) in zip(self.layout, self._bounds):
            m, p = _mask_segment(x[s:e], seg["mrp_mask_rate"], seg["mask_id"])
            m_parts.append(m)
            p_parts.append(p)
        return np.concatenate(m_parts), np.concatenate(p_parts), target

    # ── Interface ─────────────────────────────────────────────

    def __len__(self) -> int:
        """Return the number of valid samples."""
        return self.valid_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Fetch and augment a single sample.

        Parameters
        ----------
        idx : int
            Sample index. Negative values count back from the last valid
            sample.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing 'x' and 'y'. Depending on initialization,
            it may also contain 'context', 'view_a', 'view_b', 'masked_x',
            'mask_pos', and 'target_x'.

        Raises
        ------
        IndexError
            If *idx* falls outside ``[-valid_samples, valid_samples)``.
        """
        # Resolve negative indices against the valid range rather than the
        # allocated memmap, whose trailing rows may never have been written.
        if idx < 0:
            idx += self.valid_samples
        if not 0 <= idx < self.valid_samples:
            raise IndexError(
                f"index out of range for dataset with {self.valid_samples} samples"
            )

        # Copy out of the read-only memmaps so tensors own writable memory.
        x = self.inputs[idx].copy()
        y = torch.from_numpy(self.labels[idx].copy())
        item = {"x": torch.from_numpy(x), "y": y}

        if self.has_context:
            item["context"] = torch.tensor(self.contexts[idx], dtype=torch.long)

        if self.cl_mode:
            # Two independent random views of the same sample.
            item["view_a"] = torch.from_numpy(self._dropout(x))
            item["view_b"] = torch.from_numpy(self._dropout(x))

        if self.mrp_mode:
            masked_x, mask_pos, target_x = self._mask(x)
            item["masked_x"] = torch.from_numpy(masked_x)
            item["mask_pos"] = torch.from_numpy(mask_pos)
            item["target_x"] = torch.from_numpy(target_x)

        return item
class InferenceDataset(Dataset):
    """Minimal in-memory dataset designed for inference.

    Parameters
    ----------
    tokens : np.ndarray
        Integer token matrix produced by ``tokenize_batch``, shape
        (N, seq_len).
    context_ids : np.ndarray, optional
        Optional context / state IDs per sample, shape (N,).
    """

    def __init__(
        self,
        tokens: np.ndarray,
        context_ids: Optional[np.ndarray] = None,
    ):
        # Both arrays are converted to int64 tensors once, up front.
        self.tokens = torch.from_numpy(tokens).long()
        if context_ids is None:
            self.context_ids = None
        else:
            self.context_ids = torch.from_numpy(context_ids).long()

    def __len__(self) -> int:
        """Return the number of samples (rows of the token matrix)."""
        return len(self.tokens)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get token and context for inference.

        Parameters
        ----------
        idx : int
            Sample index.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple of (tokens, context). If context was not provided,
            the context is a scalar int64 tensor holding 0.
        """
        if self.context_ids is None:
            ctx = torch.tensor(0, dtype=torch.long)
        else:
            ctx = self.context_ids[idx].clone().detach()
        return self.tokens[idx], ctx