"""Memory-mapped PyTorch Dataset with on-the-fly augmentation for
contrastive learning (CL) and masked region prediction (MRP).

A single ``MemmapDataset`` handles both unimodal (1 segment) and
multimodal (N segments) layouts. Per-sample context IDs are loaded
and returned only when *context_path* is provided.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Any
import numpy as np
import torch
from torch.utils.data import Dataset
# -----------------------------------------------------------------------
# Low-level augmentation helpers
# -----------------------------------------------------------------------
def _dropout_segment(seg: np.ndarray, rate: float, min_keep: int = 10) -> np.ndarray:
"""Apply random dropout to a modality segment.
Parameters
----------
seg : np.ndarray
The input token sequence array for a specific modality.
rate : float
The proportion of non-zero tokens to drop (set to 0).
min_keep : int, default 10
The minimum number of non-zero tokens to retain. If the segment
has fewer non-zero tokens than this, no dropout is applied.
Returns
-------
np.ndarray
The augmented segment with tokens dropped out.
"""
seg = seg.copy()
nz = np.where(seg != 0)[0]
if len(nz) <= min_keep:
return seg
n_drop = int(len(nz) * rate)
if len(nz) - n_drop < min_keep:
n_drop = len(nz) - min_keep
if n_drop <= 0:
return seg
seg[np.random.choice(nz, n_drop, replace=False)] = 0
return seg
def _mask_segment(
seg: np.ndarray, rate: float, mask_id: int, min_keep: int = 10
) -> Tuple[np.ndarray, np.ndarray]:
"""Apply Masked Region Prediction (MRP) masking to a segment.
Parameters
----------
seg : np.ndarray
The input token sequence array for a specific modality.
rate : float
The proportion of non-zero tokens to replace with `mask_id`.
mask_id : int
The token ID used to represent masked positions.
min_keep : int, default 10
The minimum number of non-zero tokens required to perform masking.
Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing:
- The masked segment (np.ndarray).
- A boolean array indicating the masked positions (np.ndarray).
"""
seg = seg.copy()
pos = np.zeros(len(seg), dtype=bool)
nz = np.where(seg != 0)[0]
if len(nz) <= min_keep or rate <= 0:
return seg, pos
n_mask = max(1, int(len(nz) * rate))
idx = np.random.choice(nz, n_mask, replace=False)
seg[idx] = mask_id
pos[idx] = True
return seg, pos
# -----------------------------------------------------------------------
# Unified Dataset
# -----------------------------------------------------------------------
class MemmapDataset(Dataset):
    """Memory-mapped dataset for 1-to-N modalities.

    The concatenated token sequence is split into per-modality segments
    according to *segment_layout*, and CL dropout / MRP masking rates
    are applied independently per segment.

    Parameters
    ----------
    input_path : str
        Path to the memmap file for token inputs (read as ``int64``,
        shape ``(max_samples, seq_len)``).
    label_path : str
        Path to the memmap file for labels (read as ``float32``,
        shape ``(max_samples, num_classes)``).
    valid_samples : int
        Number of valid (written) samples to read. Must satisfy
        ``0 <= valid_samples <= max_samples``.
    max_samples : int
        Total allocated row count in the memmap files.
    seq_len : int
        Total token-sequence length (= sum of per-modality top_k).
    num_classes : int
        Number of cell-type classes.
    segment_layout : List[Dict]
        One dict per modality defining the layout::

            {"name": str, "top_k": int, "mask_id": int,
             "cl_dropout_rate": float, "mrp_mask_rate": float}

        Segments are contiguous and must sum to *seq_len*.
    context_path : str, optional
        Path to a context memmap (read as ``int64``, one entry per row).
        ``None`` indicates no context is returned.
    cl_mode : bool, default False
        If ``True``, each sample produces two dropout-augmented views
        (`view_a` and `view_b`) for Contrastive Learning.
    mrp_mode : bool, default False
        If ``True``, produces masked targets and positions for Masked
        Region Prediction.

    Raises
    ------
    ValueError
        If the ``top_k`` values in *segment_layout* do not sum to
        *seq_len*, or if *valid_samples* is outside ``[0, max_samples]``.
    """

    def __init__(
        self,
        input_path: str,
        label_path: str,
        valid_samples: int,
        max_samples: int,
        seq_len: int,
        num_classes: int,
        segment_layout: List[Dict],
        context_path: Optional[str] = None,
        cl_mode: bool = False,
        mrp_mode: bool = False,
    ):
        # Validate the configuration before opening any file. A layout that
        # does not cover exactly ``seq_len`` tokens would otherwise yield
        # augmented views whose length silently differs from ``x``.
        layout_len = sum(seg["top_k"] for seg in segment_layout)
        if layout_len != seq_len:
            raise ValueError(
                f"segment_layout top_k values sum to {layout_len}, "
                f"but seq_len is {seq_len}"
            )
        if not 0 <= valid_samples <= max_samples:
            raise ValueError(
                f"valid_samples ({valid_samples}) must lie in "
                f"[0, max_samples={max_samples}]"
            )

        self.valid_samples = valid_samples
        # Read-only memory maps over the pre-allocated (max_samples-row) files.
        self.inputs = np.memmap(
            input_path, dtype="int64", mode="r", shape=(max_samples, seq_len)
        )
        self.labels = np.memmap(
            label_path, dtype="float32", mode="r", shape=(max_samples, num_classes)
        )
        self.has_context = context_path is not None
        if self.has_context:
            self.contexts = np.memmap(
                context_path, dtype="int64", mode="r", shape=(max_samples,)
            )
        self.cl_mode = cl_mode
        self.mrp_mode = mrp_mode
        self.layout = segment_layout

        # Pre-compute contiguous [start, end) boundaries for each segment.
        self._bounds: List[Tuple[int, int]] = []
        offset = 0
        for seg in self.layout:
            k = seg["top_k"]
            self._bounds.append((offset, offset + k))
            offset += k

    # ── Augmentation ──────────────────────────────────────────
    def _dropout(self, x: np.ndarray) -> np.ndarray:
        """Return a copy of *x* with CL dropout applied segment by segment.

        Each segment uses its own ``cl_dropout_rate`` from the layout.
        """
        parts = [
            _dropout_segment(x[s:e], seg["cl_dropout_rate"])
            for seg, (s, e) in zip(self.layout, self._bounds)
        ]
        return np.concatenate(parts)

    def _mask(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Apply MRP masking segment by segment.

        Each segment uses its own ``mrp_mask_rate`` and ``mask_id``.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            The masked sequence, a boolean array of masked positions, and
            an unmodified copy of *x* (returned as ``target_x``).
        """
        target = x.copy()
        m_parts, p_parts = [], []
        for seg, (s, e) in zip(self.layout, self._bounds):
            m, p = _mask_segment(x[s:e], seg["mrp_mask_rate"], seg["mask_id"])
            m_parts.append(m)
            p_parts.append(p)
        return np.concatenate(m_parts), np.concatenate(p_parts), target

    # ── Interface ─────────────────────────────────────────────
    def __len__(self) -> int:
        return self.valid_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Fetch and augment a single sample.

        Parameters
        ----------
        idx : int
            Sample index. Negative values count back from the last *valid*
            sample.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing 'x' and 'y'. Depending on initialization,
            it may also contain 'context', 'view_a', 'view_b', 'masked_x',
            'mask_pos', and 'target_x'.

        Raises
        ------
        IndexError
            If *idx* is outside ``[-len(self), len(self))``.
        """
        n = self.valid_samples
        # Resolve negative indices against the valid length. Passing a
        # negative index straight to the memmap would address the last
        # *allocated* row (max_samples - 1), which may never have been written.
        row = idx + n if idx < 0 else idx
        if not 0 <= row < n:
            raise IndexError(f"index {idx} is out of range for {n} samples")

        x = self.inputs[row].copy()
        y = torch.from_numpy(self.labels[row].copy())
        item = {"x": torch.from_numpy(x), "y": y}
        if self.has_context:
            item["context"] = torch.tensor(self.contexts[row], dtype=torch.long)
        if self.cl_mode:
            # Two independent random dropout draws form the contrastive pair.
            item["view_a"] = torch.from_numpy(self._dropout(x))
            item["view_b"] = torch.from_numpy(self._dropout(x))
        if self.mrp_mode:
            masked_x, mask_pos, target_x = self._mask(x)
            item["masked_x"] = torch.from_numpy(masked_x)
            item["mask_pos"] = torch.from_numpy(mask_pos)
            item["target_x"] = torch.from_numpy(target_x)
        return item
class InferenceDataset(Dataset):
    """Minimal in-memory dataset designed for inference.

    Parameters
    ----------
    tokens : np.ndarray
        Integer token matrix produced by ``tokenize_batch``, shape (N, seq_len).
        Any array-like accepted by ``np.asarray`` (e.g. nested lists) also
        works. Values are converted to ``torch.long``.
    context_ids : np.ndarray, optional
        Optional context / state IDs per sample, shape (N,).

    Raises
    ------
    ValueError
        If *context_ids* is given and its leading dimension differs from
        that of *tokens*.
    """

    def __init__(
        self,
        tokens: np.ndarray,
        context_ids: Optional[np.ndarray] = None,
    ):
        # np.asarray is a no-op for ndarrays and lets plain sequences through.
        self.tokens = torch.from_numpy(np.asarray(tokens)).long()
        self.context_ids = (
            torch.from_numpy(np.asarray(context_ids)).long()
            if context_ids is not None
            else None
        )
        # Fail fast on misaligned inputs rather than pairing tokens with the
        # wrong context (or raising an IndexError mid-inference).
        if (
            self.context_ids is not None
            and self.context_ids.shape[:1] != self.tokens.shape[:1]
        ):
            raise ValueError(
                f"context_ids has shape {tuple(self.context_ids.shape)}, "
                f"expected ({self.tokens.shape[0]},) to match tokens"
            )

    def __len__(self) -> int:
        return len(self.tokens)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get token and context for inference.

        Parameters
        ----------
        idx : int
            Sample index.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple of (tokens, context). If context was not provided,
            defaults to a 0-tensor.
        """
        if self.context_ids is not None:
            # Clone so the caller receives its own tensor, not a view.
            context = self.context_ids[idx].clone()
        else:
            context = torch.tensor(0, dtype=torch.long)
        return self.tokens[idx], context