"""Vocabulary construction and expression-to-token conversion.
All functions are N-modality generic. A "unimodal" workflow is simply
one entry in the vocabs / modality_names dicts.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Any
import numpy as np
import pandas as pd
import scipy.sparse as sps
from anndata import AnnData
# -----------------------------------------------------------------------
# Vocabulary builder (works for 1 to N modalities)
# -----------------------------------------------------------------------
def build_vocab(
    modality_features: Dict[str, List[str]],
    cell_types: List[str],
    context_categories: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build per-modality vocabularies and label/context lookup tables.

    Every modality gets its own token space: id 0 is ``<PAD>``, id 1 is
    ``<UNK>``, and the unique feature names follow in sorted order starting
    at id 2. The mask id of a modality equals the size of its vocabulary,
    i.e. the first id that is not assigned to any token.

    Parameters
    ----------
    modality_features : Dict[str, List[str]]
        A dictionary mapping modality names to lists of prefixed feature names.
        E.g., ``{"rna": ["rna-GAPDH", ...], "adt": ["adt-CD3", ...]}``.
        For unimodal workflows without prefixes, use any key (e.g., ``"rna"``).
        Repeated names within a list are counted once.
    cell_types : List[str]
        A list of sorted cell-type labels. Ids follow the order of this list.
    context_categories : List[str], optional
        A list of context labels. ``None`` indicates no context vocabulary
        should be generated.

    Returns
    -------
    Dict[str, Any]
        A dictionary containing the generated mappings:

        - ``vocabs``: ``{mod_name: {feature: id, ...}}``
        - ``mask_ids``: ``{mod_name: int}``
        - ``type2id``: ``{cell_type: int}``
        - ``cell_types``: Original list of cell types.
        - ``context2id``: ``{context: int}`` (only if *context_categories* is given).
    """
    reserved = {"<PAD>": 0, "<UNK>": 1}
    vocabs: Dict[str, Dict[str, int]] = {}
    mask_ids: Dict[str, int] = {}
    for mod_name, feats in modality_features.items():
        # Deduplicate before numbering. Enumerating a list with repeats
        # leaves gaps in the id range, so ``len(vocab)`` (the mask id) could
        # coincide with the id of a real feature. Reserved names are dropped
        # for the same reason: they must keep ids 0 and 1.
        unique_feats = sorted(set(feats).difference(reserved))
        vocab = dict(reserved)
        vocab.update(
            {feat: idx for idx, feat in enumerate(unique_feats, start=len(reserved))}
        )
        vocabs[mod_name] = vocab
        # Ids are contiguous in [0, len(vocab)), so this value is unused.
        mask_ids[mod_name] = len(vocab)

    result: Dict[str, Any] = {
        "vocabs": vocabs,
        "mask_ids": mask_ids,
        "type2id": {ct: i for i, ct in enumerate(cell_types)},
        "cell_types": cell_types,
    }
    if context_categories is not None:
        result["context2id"] = {c: i for i, c in enumerate(context_categories)}
    return result
# -----------------------------------------------------------------------
# Top-k helper
# -----------------------------------------------------------------------
def _vectorized_top_k(X: np.ndarray, vocab_ids: np.ndarray, top_k: int) -> np.ndarray:
    """Extract top-k highly expressed features per cell using vectorized operations.

    For every row of *X* the ``min(n, top_k)`` largest values are selected,
    ordered from highest to lowest, and translated to token ids through
    *vocab_ids*. Positions whose selected value is ``<= 0`` are written as 0,
    as are the trailing positions when fewer than *top_k* features exist.

    Parameters
    ----------
    X : np.ndarray
        Expression matrix of shape (B, n) where B is batch size and n is features.
    vocab_ids : np.ndarray
        Vocabulary token IDs corresponding to the columns of X, shape (n,).
    top_k : int
        Number of top features to retain per cell. Must be non-negative.

    Returns
    -------
    np.ndarray
        Token matrix of shape (B, top_k) and dtype ``int64`` populated with
        token IDs. Entries without positive expression are 0.

    Raises
    ------
    ValueError
        If *top_k* is negative.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    n_cells, n_feats = X.shape
    keep = min(n_feats, top_k)
    tokens = np.zeros((n_cells, top_k), dtype=np.int64)
    # Nothing to select when there are no features or no slots. Returning
    # here also avoids ``[:, -0:]`` below, which would select *every* column
    # and then fail when written into the zero-width output.
    if keep == 0:
        return tokens

    if n_feats > keep:
        # Partial sort: the last ``keep`` positions hold the largest values
        # of each row, in arbitrary order.
        idx = np.argpartition(X, -keep, axis=1)[:, -keep:]
    else:
        # Every feature is kept; use all column indices for each row.
        idx = np.tile(np.arange(n_feats), (n_cells, 1))

    vals = np.take_along_axis(X, idx, axis=1)
    # Order the candidates from highest to lowest expression.
    order = np.argsort(vals, axis=1)[:, ::-1]
    col_sorted = np.take_along_axis(idx, order, axis=1)
    val_sorted = np.take_along_axis(vals, order, axis=1)

    # Fancy indexing returns a fresh array, so masking does not touch vocab_ids.
    tok = np.asarray(vocab_ids)[col_sorted]
    # Mask out tokens where expression is <= 0
    tok[val_sorted <= 0] = 0
    tokens[:, :keep] = tok
    return tokens
# -----------------------------------------------------------------------
# Unified tokeniser
# -----------------------------------------------------------------------
def tokenize_batch(
    adata: AnnData,
    vocabs: Dict[str, Dict[str, int]],
    modality_names: List[str],
    top_ks: Dict[str, int],
    cell_types: List[str],
    context2id: Optional[Dict[str, int]] = None,
    context_key: Optional[str] = None,
    mode: str = "train",
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Convert an AnnData batch into token, label and context arrays.

    A feature of *adata* belongs to a modality when its name is a key of that
    modality's vocabulary with an id greater than 1 (ids 0 and 1 are skipped).
    For each modality in *modality_names* the matching columns are reduced to
    ``top_ks[modality]`` tokens, and the per-modality segments are joined
    side by side in the order given. Feature names are matched exactly, so
    prefixed vocabularies (e.g. ``"rna-GAPDH"``) require prefixed
    ``adata.var_names``.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix containing the batch to tokenise.
    vocabs : Dict[str, Dict[str, int]]
        Per-modality vocabulary dictionaries mapping feature names to token IDs.
    modality_names : List[str]
        List of modality names to process.
    top_ks : Dict[str, int]
        Dictionary mapping modality names to their target sequence lengths.
    cell_types : List[str]
        List of all possible cell-type labels; defines the label column order
        when ``adata.obsm["label"]`` is a DataFrame.
    context2id : Dict[str, int], optional
        Mapping of context labels to token IDs.
    context_key : str, optional
        Key in ``adata.obs`` denoting context.
    mode : str, default "train"
        Processing mode. Labels are only read when this is ``"train"``.

    Returns
    -------
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]
        A tuple containing:

        - `token_matrix`: np.ndarray of shape (N, total_seq_len) or None.
        - `label_matrix`: np.ndarray of shape (N, C) or None.
        - `context_ids`: np.ndarray of shape (N,) or None.

        All three are ``None`` when *modality_names* is empty.
    """
    # ---- labels: only in training mode and only if the batch carries them
    labels = None
    if mode == "train" and "label" in adata.obsm:
        stored = adata.obsm["label"]
        if not isinstance(stored, pd.DataFrame):
            labels = np.asarray(stored, dtype=np.float32)
        else:
            # Re-order columns to follow cell_types; absent types stay zero.
            labels = np.zeros((adata.n_obs, len(cell_types)), dtype=np.float32)
            for col, cell_type in enumerate(cell_types):
                if cell_type in stored.columns:
                    labels[:, col] = stored[cell_type].values

    # ---- context ids: need both the mapping and the obs key to be supplied
    contexts = None
    if context2id is not None and context_key is not None:
        if context_key in adata.obs.columns:
            raw_contexts = adata.obs[context_key].values
        else:
            # No per-cell column: use one label for the whole batch, taken
            # from adata.uns if present, otherwise the first known context.
            if "_context_label" in adata.uns:
                batch_label = adata.uns["_context_label"]
            else:
                batch_label = list(context2id.keys())[0]
            raw_contexts = np.full(adata.n_obs, batch_label, dtype=object)
        # Labels are compared as strings; unknown labels map to id 0.
        contexts = np.array(
            [context2id.get(str(label), 0) for label in raw_contexts],
            dtype=np.int64,
        )

    # ---- tokens: one fixed-width segment per modality
    matrix = adata.X
    if sps.issparse(matrix):
        matrix = matrix.toarray()
    feature_names = adata.var_names.tolist()
    n_cells = matrix.shape[0]

    segments = []
    for mod_name in modality_names:
        vocab = vocabs[mod_name]
        top_k = top_ks[mod_name]
        # (column index, token id) for each feature this modality knows.
        matches = [
            (col, vocab[name])
            for col, name in enumerate(feature_names)
            if vocab.get(name, 0) > 1
        ]
        if not matches:
            # Modality absent from this batch: emit an all-padding segment.
            segments.append(np.zeros((n_cells, top_k), dtype=np.int64))
            continue
        columns, token_ids = zip(*matches)
        segments.append(
            _vectorized_top_k(
                matrix[:, list(columns)],
                np.array(token_ids, dtype=np.int64),
                top_k,
            )
        )

    if not segments:
        return None, None, None
    return np.hstack(segments), labels, contexts