# Source code for sparank.data.simulation

"""Pseudo-spot simulation with batch-wise memmap writing.

Generates synthetic spatial transcriptomics spots by mixing single-cell
profiles from a pre-concatenated multi-modal AnnData. The caller is
expected to prefix feature names (e.g., ``rna-GAPDH``, ``adt-CD3``) and
concatenate modalities along axis-1 before invoking :func:`simulate`.

Numba-accelerated parallel sampling reproduces the logic of the
original SPACEL (https://github.com/QuKunLab/SPACEL/tree/main/SPACEL) ``sample_cell`` kernel.
"""

from __future__ import annotations

import gc
from os.path import join
from typing import Dict, List, Optional, Tuple, Any

import numba as nb
import numpy as np
import pandas as pd
from anndata import AnnData
from numba import jit
from scipy.sparse import issparse
from tqdm import tqdm

from sparank.config import SimulationConfig
from sparank.data.tokenizer import tokenize_batch


# ──────────────────────────────────────────────────────────────────────────────
# Numba helpers
# ──────────────────────────────────────────────────────────────────────────────

@nb.njit
def _nb_apply_along_axis(func1d, axis, arr):
    """Apply a 1-D reducer across one axis of a 2-D array (numba-compatible).

    A minimal stand-in for ``np.apply_along_axis`` usable in nopython mode.

    Parameters
    ----------
    func1d : callable
        Numba-compatible function mapping a 1-D array to a scalar.
    axis : int
        ``0`` applies *func1d* to every column; ``1`` applies it to every row.
    arr : np.ndarray
        Two-dimensional input array.

    Returns
    -------
    np.ndarray
        1-D array with one result per column (``axis=0``) or per row
        (``axis=1``), allocated with the same dtype as *arr*.
    """
    assert arr.ndim == 2
    assert axis in (0, 1)

    if axis == 1:
        n_rows = arr.shape[0]
        out = np.empty(n_rows, dtype=arr.dtype)
        for r in range(n_rows):
            out[r] = func1d(arr[r, :])
        return out

    n_cols = arr.shape[1]
    out = np.empty(n_cols, dtype=arr.dtype)
    for c in range(n_cols):
        out[c] = func1d(arr[:, c])
    return out


@nb.njit
def _nb_sum(array, axis):
    """Sum a 2-D array along one axis from inside numba nopython code.

    Thin wrapper that passes ``np.sum`` as the reducer to
    :func:`_nb_apply_along_axis`.

    Parameters
    ----------
    array : np.ndarray
        The 2-D input array.
    axis : int
        ``0`` sums down each column (one value per column); ``1`` sums
        across each row (one value per row).

    Returns
    -------
    np.ndarray
        1-D array of sums. The helper allocates its output with
        ``array.dtype``, so sums are stored in the input dtype (relevant
        for small integer dtypes).
    """
    return _nb_apply_along_axis(np.sum, axis, array)


# ──────────────────────────────────────────────────────────────────────────────
# Numba-jitted core sampling kernel
# ──────────────────────────────────────────────────────────────────────────────

@jit(nopython=True, parallel=True)
def _sample_spots_numba(
    param_list,
    cluster_p,
    clusters,
    cluster_id,
    sample_exp,
    sample_cluster,
    cell_p_balanced,
):
    """Generate pseudo-spots by mixing single cells (parallelised).

    Reproduces the sampling logic of the original SPACEL ``sample_cell``
    kernel. For every spot ``i``:

    1. Draw ``param_list[i, 1]`` cluster IDs (with replacement) from
       *cluster_p* by inverse-CDF sampling.
    2. Keep the cells whose cluster was drawn and renormalise their
       *cell_p_balanced* weights.
    3. Draw ``param_list[i, 0]`` cells (with replacement) from those weights,
       then sum their expression rows and their one-hot cluster rows.

    The expression matrix may contain features from arbitrarily many
    concatenated modalities — the kernel is agnostic to modality boundaries.
    Random draws come from ``np.random`` inside the parallel loop.

    Parameters
    ----------
    param_list : np.ndarray
        Shape ``(n_spots, 2)``. ``[:, 0]`` is the number of cells per spot,
        and ``[:, 1]`` is the number of cluster draws per spot.
    cluster_p : np.ndarray
        Shape ``(n_cluster_types,)``. Cluster-level sampling probabilities.
    clusters : np.ndarray
        Shape ``(n_cluster_types,)``. Unique integer cluster IDs (ordered by frequency).
    cluster_id : np.ndarray
        Shape ``(n_cells,)``. Per-cell integer cluster label; used to index
        rows of *sample_cluster*.
    sample_exp : np.ndarray
        Shape ``(n_cells, total_features)``. Expression matrix
        (all modalities concatenated along axis-1).
    sample_cluster : np.ndarray
        Shape ``(n_cluster_types, n_cluster_types)``. Identity matrix used
        to produce one-hot composition vectors.
    cell_p_balanced : np.ndarray
        Shape ``(n_cells,)``. Per-cell balanced sampling weights.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        - **exp** (*np.ndarray*): Summed expression per pseudo-spot, shape ``(n_spots, total_features)``.
        - **density** (*np.ndarray*): Cell-type composition counts per pseudo-spot, shape ``(n_spots, n_cluster_types)``.
    """
    n_spots = len(param_list)
    exp = np.empty((n_spots, sample_exp.shape[1]), dtype=np.float32)
    density = np.empty((n_spots, sample_cluster.shape[1]), dtype=np.float32)

    # The cluster CDF is the same for every spot: build it once here rather
    # than once per iteration of the parallel loop.
    cum_cluster_p = np.cumsum(cluster_p)

    for i in nb.prange(n_spots):
        num_cell = param_list[i, 0]
        num_cluster = param_list[i, 1]

        # -- sample clusters ----------------------------------------------
        # searchsorted(side="right") maps each uniform draw u to the first
        # index whose cumulative probability exceeds u (inverse-CDF draw).
        raw_idx = np.searchsorted(
            cum_cluster_p,
            np.random.rand(num_cluster),
            side="right",
        )
        # Float rounding can leave the last CDF value just below 1.0, which
        # would let the index run one past the end; clamp it.
        raw_idx = np.minimum(raw_idx, len(clusters) - 1)
        used_clusters = clusters[raw_idx]

        # Mark every cell whose cluster was drawn for this spot.
        cluster_mask = np.zeros(len(cluster_id), dtype=np.bool_)
        for c in used_clusters:
            cluster_mask = (cluster_id == c) | cluster_mask

        used_cell_ind = np.where(cluster_mask)[0]

        # No eligible cells (e.g. zero cluster draws): emit an all-zero spot.
        if len(used_cell_ind) == 0:
            exp[i, :]     = np.zeros(sample_exp.shape[1],    dtype=np.float32)
            density[i, :] = np.zeros(sample_cluster.shape[1], dtype=np.float32)
            continue

        # -- sample cells -------------------------------------------------
        used_cell_p = cell_p_balanced[cluster_mask]
        used_cell_p = used_cell_p / used_cell_p.sum()

        cum_cell_p = np.cumsum(used_cell_p)
        cell_idx = np.searchsorted(
            cum_cell_p,
            np.random.rand(num_cell),
            side="right",
        )
        cell_idx = np.minimum(cell_idx, len(used_cell_ind) - 1)
        sampled_cells = used_cell_ind[cell_idx.astype(np.int64)]

        # Expression: sum of the sampled cells' rows.
        combined_exp = _nb_sum(sample_exp[sampled_cells, :], axis=0).astype(np.float32)
        # Composition: sum of one-hot rows = per-cluster cell counts.
        combined_clusters = _nb_sum(
            sample_cluster[cluster_id[sampled_cells]], axis=0
        ).astype(np.float32)

        exp[i, :]     = combined_exp
        density[i, :] = combined_clusters

    return exp, density


# ──────────────────────────────────────────────────────────────────────────────
# Sampling-parameter generators
# ──────────────────────────────────────────────────────────────────────────────

def _sample_params_gaussian(
    n: int,
    cells_mean: float,
    cells_std: float,
    cells_min: int,
    cells_max: int,
    clusters_mean: float,
    clusters_std: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Sample per-spot cell and cluster counts from clipped normal distributions.

    Cell counts are drawn from ``N(cells_mean, cells_std)``, clipped to
    ``[cells_min, cells_max]`` and rounded up. Cluster counts are then drawn
    from ``N(clusters_mean, clusters_std)``, clipped to ``[1, cell_count]``
    of the same spot and rounded up.

    Parameters
    ----------
    n : int
        Number of spots to parameterise.
    cells_mean, cells_std : float
        Mean and standard deviation of the cell-count normal.
    cells_min, cells_max : int
        Clipping bounds for cell counts.
    clusters_mean, clusters_std : float
        Mean and standard deviation of the cluster-count normal.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(cell_counts, cluster_counts)``, integer arrays of shape ``(n,)``.
    """
    raw_cells = np.random.normal(cells_mean, cells_std, size=n)
    n_cells = np.ceil(np.clip(raw_cells, cells_min, cells_max)).astype(int)

    # A spot cannot hold more cell types than cells: cap per spot.
    raw_types = np.random.normal(clusters_mean, clusters_std, size=n)
    n_types = np.ceil(np.clip(raw_types, 1, n_cells)).astype(int)

    return n_cells, n_types


def _sample_params_uniform(
    n: int,
    cells_min: int,
    cells_max: int,
    clusters_min: int,
    clusters_max: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Sample per-spot cell and cluster counts from uniform distributions.

    Cell counts are ``ceil(U[cells_min, cells_max))`` (no further clipping).
    Cluster counts are ``ceil(U[clusters_min, clusters_max))`` clipped to
    ``[1, cell_count]`` of the same spot.

    Parameters
    ----------
    n : int
        Number of spots to parameterise.
    cells_min, cells_max : int
        Bounds of the cell-count uniform.
    clusters_min, clusters_max : int
        Bounds of the cluster-count uniform.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(cell_counts, cluster_counts)``, integer arrays of shape ``(n,)``.
    """
    raw_cells = np.random.uniform(cells_min, cells_max, size=n)
    n_cells = np.ceil(raw_cells).astype(int)

    raw_types = np.random.uniform(clusters_min, clusters_max, size=n)
    n_types = np.ceil(np.clip(raw_types, 1, n_cells)).astype(int)

    return n_cells, n_types


def _sample_params_exponential(
    n: int,
    cells_mean: float,
    cells_min: int,
    cells_max: int,
    clusters_mean: float,
    clusters_std: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Sample cell counts from an exponential and cluster counts from a normal.

    Cell counts are drawn from ``Exp(scale=cells_mean)``, clipped to
    ``[cells_min, cells_max]`` and rounded up. Cluster counts are drawn from
    ``N(clusters_mean, clusters_std)``, clipped to ``[1, cell_count]`` of the
    same spot and rounded up.

    Parameters
    ----------
    n : int
        Number of spots to parameterise.
    cells_mean : float
        Scale (mean) of the cell-count exponential.
    cells_min, cells_max : int
        Clipping bounds for cell counts.
    clusters_mean, clusters_std : float
        Mean and standard deviation of the cluster-count normal.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(cell_counts, cluster_counts)``, integer arrays of shape ``(n,)``.
    """
    raw_cells = np.random.exponential(scale=cells_mean, size=n)
    n_cells = np.ceil(np.clip(raw_cells, cells_min, cells_max)).astype(int)

    raw_types = np.random.normal(clusters_mean, clusters_std, size=n)
    n_types = np.ceil(np.clip(raw_types, 1, n_cells)).astype(int)

    return n_cells, n_types


def _sample_params_lognormal(
    n: int,
    cells_mean: float,
    cells_std: float,
    cells_min: int,
    cells_max: int,
    clusters_mean: float,
    clusters_std: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Sample cell counts from a log-normal and cluster counts from a normal.

    The log-normal's underlying ``(mu, sigma)`` are chosen by moment matching
    so that its mean is ``max(cells_mean, 1e-6)`` and its standard deviation
    is ``cells_std``:

    ``mu = log(m^2 / sqrt(s^2 + m^2))``, ``sigma = sqrt(log(1 + s^2 / m^2))``.

    Cell counts are clipped to ``[cells_min, cells_max]`` and rounded up;
    cluster counts are drawn from ``N(clusters_mean, clusters_std)``, clipped
    to ``[1, cell_count]`` of the same spot and rounded up.

    Parameters
    ----------
    n : int
        Number of spots to parameterise.
    cells_mean, cells_std : float
        Target mean and standard deviation of the cell-count log-normal.
    cells_min, cells_max : int
        Clipping bounds for cell counts.
    clusters_mean, clusters_std : float
        Mean and standard deviation of the cluster-count normal.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(cell_counts, cluster_counts)``, integer arrays of shape ``(n,)``.
    """
    # Floor the mean to keep the logs finite.
    mean = max(cells_mean, 1e-6)
    var = cells_std ** 2
    phi = np.sqrt(var + mean ** 2)
    mu_param = np.log(mean ** 2 / phi)
    sigma_param = np.sqrt(np.log(phi ** 2 / mean ** 2))

    raw_cells = np.random.lognormal(mean=mu_param, sigma=sigma_param, size=n)
    n_cells = np.ceil(np.clip(raw_cells, cells_min, cells_max)).astype(int)

    raw_types = np.random.normal(clusters_mean, clusters_std, size=n)
    n_types = np.ceil(np.clip(raw_types, 1, n_cells)).astype(int)

    return n_cells, n_types


def _get_sample_params(
    n: int,
    method: str,
    cells_mean: float,
    cells_std: float,
    cells_min: int,
    cells_max: int,
    n_clusters: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Dispatch to the correct sampling-parameter generator.

    Cluster-count parameters are derived from the cell-count ones: mean and
    standard deviation are halved, the minimum is ``max(cells_min, 1)`` and
    the maximum is ``min(cells_max // 2, n_clusters)`` (the last two are only
    used by the ``"uniform"`` method).

    Parameters
    ----------
    n : int
        Number of parameter pairs to draw.
    method : str
        One of ``{"gaussian", "uniform", "exponential", "lognormal"}``.
    cells_mean : float
        Control parameter for the mean of cell counts.
    cells_std : float
        Control parameter for the standard deviation of cell counts.
    cells_min : int
        Control parameter for the minimum cell counts.
    cells_max : int
        Control parameter for the maximum cell counts.
    n_clusters : int
        Total number of cell types (used to cap ``clusters_max``).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Integer arrays of ``(cell_counts, cluster_counts)`` of shape ``(n,)``.

    Raises
    ------
    ValueError
        If the chosen sampling method is unrecognized.
    """
    clusters_mean = cells_mean / 2
    clusters_std = cells_std / 2
    clusters_min = max(cells_min, 1)
    clusters_max = min(cells_max // 2, n_clusters)

    if method == "gaussian":
        return _sample_params_gaussian(
            n, cells_mean, cells_std, cells_min, cells_max,
            clusters_mean, clusters_std,
        )
    if method == "uniform":
        return _sample_params_uniform(
            n, cells_min, cells_max, clusters_min, clusters_max,
        )
    if method == "exponential":
        return _sample_params_exponential(
            n, cells_mean, cells_min, cells_max,
            clusters_mean, clusters_std,
        )
    if method == "lognormal":
        return _sample_params_lognormal(
            n, cells_mean, cells_std, cells_min, cells_max,
            clusters_mean, clusters_std,
        )
    # NOTE: only the first literal is an f-string; the second is a plain
    # string, so its braces must be single (doubling them would be shown
    # verbatim as "{{...}}").
    raise ValueError(
        f"Unknown cell_sample_method '{method}'. "
        "Choose from {gaussian, uniform, exponential, lognormal}."
    )


# ──────────────────────────────────────────────────────────────────────────────
# Cell-type probability initialisation
# ──────────────────────────────────────────────────────────────────────────────

def _init_sample_prob(
    obs: pd.DataFrame,
    celltype_key: str,
) -> Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray], np.ndarray, np.ndarray]:
    """Compute cluster-level and cell-level sampling probabilities.

    Three cluster-level modes are prepared, following the original
    ``init_sample_prob`` implementation:

    * **unbalance** – proportional to actual cell-type frequencies.
    * **sqrt** – proportional to sqrt of frequencies (mild balancing).
    * **balance** – uniform across cell types.

    Parameters
    ----------
    obs : pd.DataFrame
        Cell-level metadata dataframe. Must contain the *celltype_key* column.
    celltype_key : str
        Column in ``obs`` storing cell-type labels.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray], np.ndarray, np.ndarray]
        A tuple containing:
        - `cluster_ordered`: Integer cluster IDs ``0..n_types-1`` (ID 0 is the most frequent type).
        - `cluster_id`: Per-cell integer cluster ID, shape ``(n_cells,)``.
        - `cluster_p_dict`: Dictionary mapping mode strings to probability arrays.
        - `cell_p_balanced`: Per-cell balanced sampling weight, shape ``(n_cells,)``.
        - `cluster_names`: Cell-type names ordered to match the integer encoding.

    Raises
    ------
    ValueError
        If any cell has a missing (NaN/None) cell-type label.
    """
    labels = obs[celltype_key]

    # value_counts() drops missing labels, so such cells would have no ID
    # below; fail with a clear message instead of a bare ``KeyError: nan``.
    n_missing = int(labels.isna().sum())
    if n_missing:
        raise ValueError(
            f"{n_missing} cell(s) have a missing '{celltype_key}' label; "
            "annotate or drop them before simulation."
        )

    unique_cts = labels.value_counts()  # descending by count
    unique_cts = unique_cts[unique_cts > 0]  # drop unused categorical levels
    cluster_names = unique_cts.index.values
    ct2id = {ct: i for i, ct in enumerate(cluster_names)}

    cluster_id = np.array([ct2id[c] for c in labels.values], dtype=np.int32)
    cluster_ordered = np.arange(len(cluster_names), dtype=np.int32)

    counts = unique_cts.values.astype(np.float64)
    p_unbalance = counts / counts.sum()
    sqrt_counts = np.sqrt(counts)
    p_sqrt = sqrt_counts / sqrt_counts.sum()
    p_balance = np.ones(len(counts)) / len(counts)

    cluster_p_dict: Dict[str, np.ndarray] = {
        "unbalance": p_unbalance,
        "sqrt": p_sqrt,
        "balance": p_balance,
    }

    # Balanced cell weights: inverse of the cell's cluster frequency, so each
    # cluster's cells carry the same total weight. Vectorised lookup gives
    # the same float64 values as a per-cell loop.
    cell_p = 1.0 / p_unbalance[cluster_id]
    cell_p = (cell_p / cell_p.sum()).astype(np.float64)

    return cluster_ordered, cluster_id, cluster_p_dict, cell_p, cluster_names


# ──────────────────────────────────────────────────────────────────────────────
# Expression matrix helper
# ──────────────────────────────────────────────────────────────────────────────

def _to_dense_f32(x: Any) -> np.ndarray:
    """Coerce a sparse or dense matrix to a C-contiguous float32 ndarray.

    Sparse input is cast to float32 while still sparse and then densified
    with ``toarray()``. The dense buffer is therefore allocated once,
    directly in float32, never in a wider source dtype such as float64, and
    no ``np.matrix`` intermediate is created.

    Parameters
    ----------
    x : Union[np.ndarray, scipy.sparse.spmatrix]
        The input expression matrix.

    Returns
    -------
    np.ndarray
        A dense, C-contiguous array of type float32.
    """
    if issparse(x):
        # Only the stored non-zeros are converted; toarray() yields C order.
        return np.ascontiguousarray(x.astype(np.float32).toarray())
    return np.ascontiguousarray(x, dtype=np.float32)


# ──────────────────────────────────────────────────────────────────────────────
# High-level chunk generator
# ──────────────────────────────────────────────────────────────────────────────

def _generate_chunk(
    sub_obs: pd.DataFrame,
    sub_exp: np.ndarray,
    used_features: np.ndarray,
    n_request: int,
    celltype_key: str,
    cfg: SimulationConfig,
) -> Optional[AnnData]:
    """Generate one chunk of simulated pseudo-spots.

    The requested spots are split evenly across three cluster-sampling modes
    (``unbalance``, ``sqrt``, ``balance``), with any remainder going to the
    first mode. One table of ``(cell_count, cluster_count)`` parameters is
    drawn, and every mode uses its leading rows.

    The returned AnnData stores:

    * ``.X`` — summed expression across all concatenated modalities.
    * ``.var.index`` — prefixed feature names (e.g. ``rna-GAPDH``).
    * ``.obsm["label"]`` — cell-type proportion DataFrame whose columns are
      this batch's cell-type names, most frequent first.

    Parameters
    ----------
    sub_obs : pd.DataFrame
        Cell-level metadata for one batch (must contain *celltype_key*).
    sub_exp : np.ndarray
        Dense float32 expression matrix for the batch,
        shape ``(n_cells_in_batch, n_total_features)``.
    used_features : np.ndarray
        Ordered array of all (prefixed) feature names.
    n_request : int
        Number of pseudo-spots to generate in this chunk.
    celltype_key : str
        Cell-type annotation column in *sub_obs*.
    cfg : SimulationConfig
        Configuration object controlling sampling behaviour.

    Returns
    -------
    Optional[AnnData]
        The simulated AnnData object, or ``None`` if the batch contains
        fewer than 3 cells or ``n_request`` is not positive.
    """
    # Too few cells to mix, or nothing requested. Without the n_request
    # check, every mode below would be skipped and np.concatenate([]) would
    # raise instead of honouring the ``None`` contract.
    if len(sub_obs) < 3 or n_request <= 0:
        return None

    # -- Probabilities -----------------------------------------------------
    (
        cluster_ordered,
        cluster_id,
        cluster_p_dict,
        cell_p_balanced,
        cluster_names,
    ) = _init_sample_prob(sub_obs, celltype_key)
    n_clusters = len(cluster_names)
    # Row k of the identity is the one-hot vector of cluster k; the kernel
    # sums these rows to obtain per-spot cell-type counts.
    onehot = np.eye(n_clusters, dtype=np.float32)

    # -- Balance modes: unbalance / sqrt / balance -------------------------
    balance_modes = ["unbalance", "sqrt", "balance"]
    n_per_mode, remainder = divmod(n_request, len(balance_modes))

    # Enough rows for the largest mode; every mode slices the leading rows.
    max_n = n_per_mode + remainder
    cell_counts, cluster_counts = _get_sample_params(
        n=max_n,
        method=cfg.cell_sample_method,
        cells_mean=cfg.cells_mean,
        cells_std=cfg.cells_std,
        cells_min=cfg.cells_min,
        cells_max=cfg.cells_max,
        n_clusters=n_clusters,
    )
    params_shared = np.column_stack([cell_counts, cluster_counts]).astype(np.int64)

    exp_parts: List[np.ndarray] = []
    label_parts: List[np.ndarray] = []

    for idx, mode in enumerate(balance_modes):
        n_this = n_per_mode + (remainder if idx == 0 else 0)
        if n_this <= 0:
            continue

        exp_chunk, labels_chunk = _sample_spots_numba(
            param_list=params_shared[:n_this],
            cluster_p=cluster_p_dict[mode],
            clusters=cluster_ordered,
            cluster_id=cluster_id,
            sample_exp=sub_exp,
            sample_cluster=onehot,
            cell_p_balanced=cell_p_balanced,
        )

        exp_parts.append(exp_chunk)
        label_parts.append(labels_chunk)

    # -- Concatenate across balance modes ----------------------------------
    sim_exp = np.concatenate(exp_parts, axis=0)
    sim_labels = np.concatenate(label_parts, axis=0)

    # -- Normalise labels to proportions -----------------------------------
    # All-zero spots keep all-zero proportions instead of dividing by zero.
    row_sums = sim_labels.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0
    sim_proportions = sim_labels / row_sums

    # -- Assemble AnnData --------------------------------------------------
    sm_ad = AnnData(sim_exp)
    sm_ad.var.index = pd.Index(used_features)
    sm_ad.obsm["label"] = pd.DataFrame(
        sim_proportions,
        columns=cluster_names,
        index=sm_ad.obs_names,
    )

    return sm_ad


# ──────────────────────────────────────────────────────────────────────────────
# Public entry point
# ──────────────────────────────────────────────────────────────────────────────

def _plan_batches(
    obs: pd.DataFrame,
    batch_labels: pd.Series,
    total: int,
    use_context: bool,
    context_key: Optional[str],
) -> List[dict]:
    """Split the *total* spot budget across batches in proportion to their size.

    Parameters
    ----------
    obs : pd.DataFrame
        Cell-level metadata, one row per cell, in the same order as *batch_labels*.
    batch_labels : pd.Series
        Per-cell batch identifier.
    total : int
        Total number of pseudo-spots to allocate.
    use_context : bool
        Whether to attach each batch's most frequent context label.
    context_key : str, optional
        Column in *obs* holding context labels (read only if *use_context*).

    Returns
    -------
    List[dict]
        One plan per batch with a positive allocation, largest batch first.
        Each plan has ``"id"`` and ``"n_target"`` and, if *use_context*, also
        ``"context"``. Allocations are floored, and the rounding remainder is
        added to the first plan.
    """
    n_obs = len(obs)
    plans: List[dict] = []
    for bid, cnt in batch_labels.value_counts().items():
        n_target = int((cnt / n_obs) * total)
        if n_target <= 0:
            continue
        plan: dict = {"id": bid, "n_target": n_target}
        if use_context:
            in_batch = (batch_labels == bid).to_numpy()
            contexts = obs.loc[in_batch, context_key].astype(str)
            # The majority context label represents the whole batch.
            plan["context"] = contexts.value_counts().idxmax()
        plans.append(plan)

    # Redistribute the rounding remainder to the first batch.
    diff = total - sum(p["n_target"] for p in plans)
    if diff > 0 and plans:
        plans[0]["n_target"] += diff
    return plans


def simulate(
    adata: AnnData,
    vocabs: Dict[str, Dict[str, int]],
    modality_names: List[str],
    top_ks: Dict[str, int],
    cell_types: List[str],
    *,
    celltype_key: str,
    save_dir: str,
    cfg: SimulationConfig,
    all_features: Optional[List[str]] = None,
    context2id: Optional[Dict[str, int]] = None,
    sim_batch_key: Optional[str] = None,
    context_key: Optional[str] = None,
    max_empty_retries: int = 10,
) -> Tuple[int, str, str, Optional[str]]:
    """Batch-proportional pseudo-spot simulation with memmap writing.

    The input *adata* is a **pre-concatenated** multi-modal AnnData whose
    ``.X`` contains features from all modalities (columns prefixed by
    modality name, e.g. ``rna-GAPDH``, ``adt-CD3``). Pseudo-spots are
    generated independently per batch, with the number of spots proportional
    to each batch's share of total cells. Results are tokenised on-the-fly
    and streamed to memory-mapped files.

    Parameters
    ----------
    adata : AnnData
        Concatenated multi-modal single-cell reference. ``.obs`` must contain
        *celltype_key* (and optionally *sim_batch_key*, *context_key*).
        ``.X`` has shape ``(n_cells, sum_of_features_across_modalities)``.
        ``adata.obs`` is not modified.
    vocabs : Dict[str, Dict[str, int]]
        Per-modality token vocabularies mapping prefixed feature names to integers.
    modality_names : List[str]
        Ordered list of modalities, e.g. ``["rna"]`` or ``["rna", "adt"]``.
    top_ks : Dict[str, int]
        Per-modality number of top-ranked features to keep per spot.
    cell_types : List[str]
        Ordered list of cell-type names for the label vector.
    celltype_key : str
        Column name in ``adata.obs`` storing the cell-type labels.
    save_dir : str
        Directory path where the output memmap files will be saved.
    cfg : SimulationConfig
        Configuration object containing the simulation hyperparameters.
    all_features : List[str], optional
        Flat list of **all** prefixed feature names across modalities. If
        provided, *adata* is subset to these features up front.
    context2id : Dict[str, int], optional
        Mapping from context labels to integer IDs.
    sim_batch_key : str, optional
        Column in ``adata.obs`` with identifiers used to split the simulation
        into batches. If ``None``, all cells form a single batch.
    context_key : str, optional
        Column in ``adata.obs`` storing context labels.
    max_empty_retries : int, default 10
        Number of consecutive chunks for one batch that may yield no
        tokenised samples before that batch is abandoned. This keeps the
        loop from spinning forever when ``tokenize_batch`` produces nothing.

    Returns
    -------
    Tuple[int, str, str, Optional[str]]
        A tuple containing:
        - `real_total`: Number of valid samples actually written.
        - `inp_path`: Path to the tokenised-input memmap file.
        - `lbl_path`: Path to the label memmap file.
        - `ctx_path`: Path to the context memmap file, or ``None`` if unused.
    """
    total = cfg.total_samples
    seq_len = sum(top_ks[m] for m in modality_names)
    use_context = context2id is not None and context_key is not None

    inp_path = join(save_dir, "train_inputs.mmap")
    lbl_path = join(save_dir, "train_labels.mmap")
    ctx_path = join(save_dir, "train_contexts.mmap") if use_context else None

    fp_inp = np.memmap(inp_path, dtype="int64", mode="w+", shape=(total, seq_len))
    fp_lbl = np.memmap(
        lbl_path, dtype="float32", mode="w+", shape=(total, len(cell_types)),
    )
    fp_ctx = (
        np.memmap(ctx_path, dtype="int64", mode="w+", shape=(total,))
        if use_context
        else None
    )

    # -- Global feature subsetting (performance optimisation) --------------
    if all_features is not None:
        used_features = np.array(all_features)
        adata = adata[:, used_features]
    else:
        used_features = np.array(adata.var_names)

    # -- Batch labels: kept local so the caller's obs is left untouched ----
    if sim_batch_key is None:
        batch_labels = pd.Series("sim_batch_1", index=adata.obs_names)
    else:
        batch_labels = adata.obs[sim_batch_key]

    plans = _plan_batches(adata.obs, batch_labels, total, use_context, context_key)

    # -- Iterate batches ---------------------------------------------------
    ptr = 0
    pbar = tqdm(total=total, desc="Simulating")
    try:
        for plan in plans:
            batch_mask = (batch_labels == plan["id"]).to_numpy()
            sub_obs = adata.obs.loc[batch_mask].copy()
            if len(sub_obs) < 3:
                # Too few cells to mix; skip before densifying expression.
                pbar.update(plan["n_target"])
                continue
            sub_exp = _to_dense_f32(adata[batch_mask].X)

            done = 0
            empty_streak = 0
            while done < plan["n_target"] and ptr < total:
                remaining = min(plan["n_target"] - done, cfg.batch_request_size)
                try:
                    chunk = _generate_chunk(
                        sub_obs=sub_obs,
                        sub_exp=sub_exp,
                        used_features=used_features,
                        n_request=remaining + cfg.sim_buffer,
                        celltype_key=celltype_key,
                        cfg=cfg,
                    )
                except Exception as exc:
                    # Best-effort: report and move on to the next batch.
                    print(f" Error batch={plan['id']}: {exc}")
                    break
                if chunk is None:
                    break

                if use_context:
                    chunk.uns["_context_label"] = plan.get("context")

                tok, lab, sid = tokenize_batch(
                    chunk,
                    vocabs,
                    modality_names,
                    top_ks,
                    cell_types,
                    context2id=context2id,
                    context_key=context_key,
                    mode="train",
                )

                vl = 0
                if tok is not None:
                    vl = min(tok.shape[0], plan["n_target"] - done, total - ptr)
                if vl <= 0:
                    # Nothing usable: `done` would not advance, so retry only
                    # a bounded number of times to avoid an endless loop.
                    del chunk, tok, lab, sid
                    gc.collect()
                    empty_streak += 1
                    if empty_streak >= max_empty_retries:
                        print(
                            f" No samples produced for batch={plan['id']} "
                            f"after {empty_streak} attempts; skipping."
                        )
                        break
                    continue
                empty_streak = 0

                fp_inp[ptr : ptr + vl] = tok[:vl]
                fp_lbl[ptr : ptr + vl] = lab[:vl].astype(np.float32)
                if use_context and sid is not None:
                    fp_ctx[ptr : ptr + vl] = sid[:vl]

                ptr += vl
                done += vl
                pbar.update(vl)

                del chunk, tok, lab, sid
                gc.collect()

            del sub_obs, sub_exp
            gc.collect()
            if ptr >= total:
                break
    finally:
        # -- Flush & cleanup (also on error, so partial output persists) --
        pbar.close()
        fp_inp.flush()
        fp_lbl.flush()
        if fp_ctx is not None:
            fp_ctx.flush()
        del fp_inp, fp_lbl, fp_ctx

    print(f" Simulation done — {ptr} valid samples written.")
    return ptr, inp_path, lbl_path, ctx_path