# Source code for sparank.data.tokenizer

"""Vocabulary construction and expression-to-token conversion.

All functions are N-modality generic. A "unimodal" workflow is simply
one entry in the vocabs / modality_names dicts.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import pandas as pd
import scipy.sparse as sps
from anndata import AnnData


# -----------------------------------------------------------------------
#  Vocabulary builder  (works for 1 to N modalities)
# -----------------------------------------------------------------------

def build_vocab(
    modality_features: Dict[str, List[str]],
    cell_types: List[str],
    context_categories: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build per-modality vocabularies.

    Every modality vocabulary reserves ID 0 for ``"<PAD>"`` and ID 1 for
    ``"<UNK>"``. The distinct feature names are then numbered from 2 in
    sorted order. The modality's mask ID is one past the largest assigned
    ID, so it never coincides with a real token.

    Parameters
    ----------
    modality_features : Dict[str, List[str]]
        A dictionary mapping modality names to lists of prefixed feature
        names, e.g. ``{"rna": ["rna-GAPDH", ...], "adt": ["adt-CD3", ...]}``.
        For unimodal workflows without prefixes, use any key (e.g.
        ``"rna"``). Duplicate names are collapsed. Names equal to
        ``"<PAD>"`` or ``"<UNK>"`` are ignored so the reserved IDs cannot
        be overwritten.
    cell_types : List[str]
        Cell-type labels. IDs follow the given order; the list is not
        re-sorted here.
    context_categories : List[str], optional
        A list of context labels. ``None`` means no context vocabulary is
        generated.

    Returns
    -------
    Dict[str, Any]
        A dictionary containing the generated mappings:

        - ``vocabs``: ``{mod_name: {feature: id, ...}}``
        - ``mask_ids``: ``{mod_name: int}``
        - ``type2id``: ``{cell_type: int}``
        - ``cell_types``: the *cell_types* list that was passed in.
        - ``context2id``: ``{context: int}`` (only if *context_categories*
          is given).
    """
    specials = {"<PAD>": 0, "<UNK>": 1}
    vocabs: Dict[str, Dict[str, int]] = {}
    mask_ids: Dict[str, int] = {}
    for mod_name, feats in modality_features.items():
        # De-duplicate and drop reserved names before numbering. Otherwise
        # repeated names leave gaps in the ID range (or overwrite the
        # specials), and len(vocab) -- used as the mask ID -- collides with
        # a real token ID.
        unique_feats = sorted(set(feats).difference(specials))
        vocab = dict(specials)
        vocab.update(
            {feat: i for i, feat in enumerate(unique_feats, start=len(specials))}
        )
        vocabs[mod_name] = vocab
        # IDs are contiguous 0..len(vocab)-1, so len(vocab) is the first free ID.
        mask_ids[mod_name] = len(vocab)

    type2id = {ct: i for i, ct in enumerate(cell_types)}

    result: Dict[str, Any] = dict(
        vocabs=vocabs,
        mask_ids=mask_ids,
        type2id=type2id,
        cell_types=cell_types,
    )
    if context_categories is not None:
        result["context2id"] = {c: i for i, c in enumerate(context_categories)}
    return result

# ----------------------------------------------------------------------- # Top-k helper # ----------------------------------------------------------------------- def _vectorized_top_k(X: np.ndarray, vocab_ids: np.ndarray, top_k: int) -> np.ndarray: """Extract top-k highly expressed features per cell using vectorized operations. Parameters ---------- X : np.ndarray Expression matrix of shape (B, n) where B is batch size and n is features. vocab_ids : np.ndarray Vocabulary token IDs corresponding to the columns of X, shape (n,). top_k : int Number of top features to retain per cell. Returns ------- np.ndarray Token matrix of shape (B, top_k) populated with token IDs. Cells with zero expression values are padded with 0. """ B, n = X.shape actual_k = min(n, top_k) tokens = np.zeros((B, top_k), dtype=np.int64) if n == 0: return tokens if n > actual_k: idx = np.argpartition(X, -actual_k, axis=1)[:, -actual_k:] else: idx = np.arange(n).reshape(1, -1).repeat(B, axis=0) vals = np.take_along_axis(X, idx, axis=1) order = np.argsort(vals, axis=1)[:, ::-1] col_sorted = np.take_along_axis(idx, order, axis=1) tok = vocab_ids[col_sorted] val_sorted = np.take_along_axis(vals, order, axis=1) # Mask out tokens where expression is <= 0 tok[val_sorted <= 0] = 0 tokens[:, :actual_k] = tok return tokens # ----------------------------------------------------------------------- # Unified tokeniser # -----------------------------------------------------------------------
def tokenize_batch(
    adata: AnnData,
    vocabs: Dict[str, Dict[str, int]],
    modality_names: List[str],
    top_ks: Dict[str, int],
    cell_types: List[str],
    context2id: Optional[Dict[str, int]] = None,
    context_key: Optional[str] = None,
    mode: str = "train",
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Tokenise an AnnData batch across 1-to-N modalities.

    Features in *adata* must generally be prefixed (e.g., ``"rna-GAPDH"``,
    ``"adt-CD3"``). For unimodal workflows without a prefix, the vocab dict
    keys must match ``adata.var_names`` exactly.

    For each modality in *modality_names*, the candidate columns of
    ``adata.X`` are those whose name is in that modality's vocabulary with
    a token ID > 1. IDs 0/1 are ``<PAD>``/``<UNK>`` in vocabularies built
    by ``build_vocab``. The ``top_ks[mod]`` highest-valued candidates per
    cell are emitted as token IDs in descending order of value. Non-positive
    values and unused slots are 0. Per-modality segments are concatenated
    in *modality_names* order.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix containing the batch to tokenise.
    vocabs : Dict[str, Dict[str, int]]
        Per-modality vocabulary dictionaries mapping feature names to token
        IDs.
    modality_names : List[str]
        List of modality names to process.
    top_ks : Dict[str, int]
        Dictionary mapping modality names to their segment lengths.
    cell_types : List[str]
        List of all possible cell-type labels; defines the label column
        order when ``adata.obsm["label"]`` is a DataFrame.
    context2id : Dict[str, int], optional
        Mapping of context labels to IDs. Unknown labels map to 0.
    context_key : str, optional
        Column in ``adata.obs`` holding the context. If absent, the whole
        batch uses ``adata.uns["_context_label"]``, or else the first key of
        *context2id*.
    mode : str, default "train"
        Processing mode. Labels are only read when ``mode == "train"`` and
        ``adata.obsm`` has a ``"label"`` entry.

    Returns
    -------
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]
        A tuple containing:

        - `token_matrix`: int64 array of shape
          (N, sum of ``top_ks[m]`` over *modality_names*).
        - `label_matrix`: float32 array (N, C), or None.
        - `context_ids`: int64 array of shape (N,), or None.

        If *modality_names* is empty, ``(None, None, None)`` is returned.
    """
    # ---- Labels (training only) ----
    label_mat = None
    if mode == "train" and "label" in adata.obsm:
        raw_labels = adata.obsm["label"]
        if isinstance(raw_labels, pd.DataFrame):
            # Align columns to cell_types; types missing from the frame stay 0.
            label_mat = np.zeros((adata.n_obs, len(cell_types)), dtype=np.float32)
            for i, ct in enumerate(cell_types):
                if ct in raw_labels.columns:
                    label_mat[:, i] = raw_labels[ct].values
        else:
            if sps.issparse(raw_labels):
                # np.asarray does not densify scipy sparse matrices.
                raw_labels = raw_labels.toarray()
            label_mat = np.asarray(raw_labels, dtype=np.float32)

    # ---- Context ----
    context_ids = None
    if context2id is not None and context_key is not None:
        if context_key in adata.obs.columns:
            raw_ctx = adata.obs[context_key].values
        else:
            if "_context_label" in adata.uns:
                fill = adata.uns["_context_label"]
            else:
                # First vocabulary entry; None for an empty vocabulary, which
                # then maps to 0 through the .get default below.
                fill = next(iter(context2id), None)
            raw_ctx = np.full(adata.n_obs, fill, dtype=object)
        context_ids = np.array(
            [context2id.get(str(c), 0) for c in raw_ctx], dtype=np.int64
        )

    # ---- Tokenise per modality ----
    X = adata.X
    if sps.issparse(X) and X.format not in ("csr", "csc"):
        # Formats such as COO do not support column indexing.
        X = X.tocsr()
    n_cells = X.shape[0]
    all_vars = adata.var_names.tolist()

    segments = []
    for mod_name in modality_names:
        vocab = vocabs[mod_name]
        top_k = top_ks[mod_name]
        # Exact name match against this modality's vocabulary. IDs <= 1 are
        # skipped so special tokens are never emitted as features.
        matches = [
            (j, vocab[feat])
            for j, feat in enumerate(all_vars)
            if feat in vocab and vocab[feat] > 1
        ]
        if matches:
            col_idx = [j for j, _ in matches]
            voc_ids = np.array([tid for _, tid in matches], dtype=np.int64)
            # Select columns first and densify only them, instead of
            # materialising the full (cells x all-features) matrix.
            sub = X[:, col_idx]
            sub = sub.toarray() if sps.issparse(sub) else np.asarray(sub)
            seg = _vectorized_top_k(sub, voc_ids, top_k)
        else:
            seg = np.zeros((n_cells, top_k), dtype=np.int64)
        segments.append(seg)

    if not segments:
        # No modalities requested: everything is returned as None.
        return None, None, None

    return np.hstack(segments), label_mat, context_ids
