"""Pseudo-spot simulation with batch-wise memmap writing.
Generates synthetic spatial transcriptomics spots by mixing single-cell
profiles from a pre-concatenated multi-modal AnnData. The caller is
expected to prefix feature names (e.g., ``rna-GAPDH``, ``adt-CD3``) and
concatenate modalities along axis-1 before invoking :func:`simulate`.
Numba-accelerated parallel sampling reproduces the logic of the
original SPACEL (https://github.com/QuKunLab/SPACEL/tree/main/SPACEL) ``sample_cell`` kernel.
"""
from __future__ import annotations
import gc
from os.path import join
from typing import Dict, List, Optional, Tuple, Any
import numba as nb
import numpy as np
import pandas as pd
from anndata import AnnData
from numba import jit
from scipy.sparse import issparse
from tqdm import tqdm
from sparank.config import SimulationConfig
from sparank.data.tokenizer import tokenize_batch
# ──────────────────────────────────────────────────────────────────────────────
# Numba helpers
# ──────────────────────────────────────────────────────────────────────────────
@nb.njit
def _nb_apply_along_axis(func1d, axis, arr):
    """Numba-compatible stand-in for ``np.apply_along_axis`` on 2-D input.

    Parameters
    ----------
    func1d : callable
        Numba-compatible function mapping a 1-D array to a scalar.
    axis : int
        ``0`` applies *func1d* to every column, ``1`` to every row.
    arr : np.ndarray
        Two-dimensional input array.

    Returns
    -------
    np.ndarray
        1-D array with one entry per column (``axis=0``) or per row
        (``axis=1``), allocated with ``arr.dtype``.
    """
    assert arr.ndim == 2
    assert axis in (0, 1)
    # Output length is the size of the axis that is *not* reduced.
    n_out = arr.shape[1] if axis == 0 else arr.shape[0]
    out = np.empty(n_out, dtype=arr.dtype)
    for k in range(n_out):
        if axis == 0:
            out[k] = func1d(arr[:, k])
        else:
            out[k] = func1d(arr[k, :])
    return out
@nb.njit
def _nb_sum(array, axis):
    """Sum a 2-D array along *axis* inside numba-compiled code.

    Parameters
    ----------
    array : np.ndarray
        Two-dimensional input array.
    axis : int
        ``0`` produces one sum per column, ``1`` one sum per row.

    Returns
    -------
    np.ndarray
        1-D array of sums with the same dtype as *array*.
    """
    reduced = _nb_apply_along_axis(np.sum, axis, array)
    return reduced
# ──────────────────────────────────────────────────────────────────────────────
# Numba-jitted core sampling kernel
# ──────────────────────────────────────────────────────────────────────────────
@jit(nopython=True, parallel=True)
def _sample_spots_numba(
    param_list,
    cluster_p,
    clusters,
    cluster_id,
    sample_exp,
    sample_cluster,
    cell_p_balanced,
):
    """Generate pseudo-spots by mixing single cells (parallelised over spots).

    Follows the sampling logic of the original SPACEL ``sample_cell``
    kernel. For each spot it does three things:

    1. Draws ``num_cluster`` cluster IDs (with replacement) from
       ``cluster_p`` by inverse-CDF sampling.
    2. Restricts the candidate cells to those belonging to any drawn
       cluster.
    3. Draws ``num_cell`` cells (with replacement) from the candidates,
       weighted by ``cell_p_balanced`` renormalised over the candidates.

    Two safeguards sit on top of that logic. Sampled indices are clamped to
    the last valid position. A spot whose candidate-cell set is empty is
    filled with zeros.

    The expression matrix may contain features from arbitrarily many
    concatenated modalities — the kernel is agnostic to modality boundaries.

    Parameters
    ----------
    param_list : np.ndarray
        Shape ``(n_spots, 2)``. ``[:, 0]`` is the number of cells per spot,
        and ``[:, 1]`` is the number of cluster types per spot.
    cluster_p : np.ndarray
        Shape ``(n_cluster_types,)``. Cluster-level sampling probabilities.
    clusters : np.ndarray
        Shape ``(n_cluster_types,)``. Unique integer cluster IDs (ordered by frequency).
    cluster_id : np.ndarray
        Shape ``(n_cells,)``. Per-cell integer cluster label.
    sample_exp : np.ndarray
        Shape ``(n_cells, total_features)``. Expression matrix
        (all modalities concatenated along axis-1).
    sample_cluster : np.ndarray
        Shape ``(n_cluster_types, n_cluster_types)``. Its rows are indexed by
        ``cluster_id`` and summed to form composition counts; callers pass an
        identity matrix so each row acts as a one-hot vector.
    cell_p_balanced : np.ndarray
        Shape ``(n_cells,)``. Per-cell balanced sampling weights.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        - **exp** (*np.ndarray*): Summed expression per pseudo-spot, shape ``(n_spots, total_features)``, float32.
        - **density** (*np.ndarray*): Cell-type composition counts per pseudo-spot, shape ``(n_spots, n_cluster_types)``, float32.

    Notes
    -----
    NOTE(review): ``np.random`` inside nopython code uses numba's own
    generator, which (per numba docs) is not seeded by ``np.random.seed``
    called from regular Python. Confirm this if reproducibility matters.
    """
    n_spots = len(param_list)
    # Outputs are float32 irrespective of the input dtypes.
    exp = np.empty((n_spots, sample_exp.shape[1]), dtype=np.float32)
    density = np.empty((n_spots, sample_cluster.shape[1]), dtype=np.float32)
    for i in nb.prange(n_spots):
        num_cell = param_list[i, 0]
        num_cluster = param_list[i, 1]
        # sampling clusters (inverse-CDF: uniform draws located in the cumulative distribution)
        # NOTE(review): this cumsum does not depend on ``i`` and could be hoisted out of the loop.
        cum_cluster_p = np.cumsum(cluster_p)
        raw_idx = np.searchsorted(
            cum_cluster_p,
            np.random.rand(num_cluster),
            side="right",
        )
        # searchsorted(side="right") returns len(cum) when a draw is >= the
        # final cumulative value (e.g. float round-off leaves it just below
        # 1.0); clamp to the last valid index.
        raw_idx = np.minimum(raw_idx, len(clusters) - 1)
        used_clusters = clusters[raw_idx]
        # Union of cells belonging to any drawn cluster (duplicates in
        # used_clusters are harmless for the OR-accumulation).
        cluster_mask = np.zeros(len(cluster_id), dtype=np.bool_)
        for c in used_clusters:
            cluster_mask = (cluster_id == c) | cluster_mask
        used_cell_ind = np.where(cluster_mask)[0]
        # in case all-zero spot
        if len(used_cell_ind) == 0:
            exp[i, :] = np.zeros(sample_exp.shape[1], dtype=np.float32)
            density[i, :] = np.zeros(sample_cluster.shape[1], dtype=np.float32)
            continue
        # sampling cells: renormalise the balanced weights over the candidate cells
        used_cell_p = cell_p_balanced[cluster_mask]
        used_cell_p = used_cell_p / used_cell_p.sum()
        cum_cell_p = np.cumsum(used_cell_p)
        cell_idx = np.searchsorted(
            cum_cell_p,
            np.random.rand(num_cell),
            side="right",
        )
        # Same round-off guard as for the cluster draw above.
        cell_idx = np.minimum(cell_idx, len(used_cell_ind) - 1)
        sampled_cells = used_cell_ind[cell_idx.astype(np.int64)]
        # Spot expression = column-wise sum over the sampled cells' rows.
        combined_exp = _nb_sum(sample_exp[sampled_cells, :], axis=0).astype(np.float32)
        # Composition counts = sum of the sampled cells' sample_cluster rows.
        combined_clusters = _nb_sum(
            sample_cluster[cluster_id[sampled_cells]], axis=0
        ).astype(np.float32)
        exp[i, :] = combined_exp
        density[i, :] = combined_clusters
    return exp, density
# ──────────────────────────────────────────────────────────────────────────────
# Sampling-parameter generators
# ──────────────────────────────────────────────────────────────────────────────
def _sample_params_gaussian(
    n: int,
    cells_mean: float,
    cells_std: float,
    cells_min: int,
    cells_max: int,
    clusters_mean: float,
    clusters_std: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Draw per-spot ``(cell_count, cluster_count)`` from clipped Gaussians.

    Cell counts are drawn from ``N(cells_mean, cells_std)``, clipped to
    ``[cells_min, cells_max]`` and rounded up. Cluster counts are drawn from
    ``N(clusters_mean, clusters_std)``, clipped to ``[1, cell_count]`` of the
    same spot and rounded up.

    Parameters
    ----------
    n : int
        Number of spots (parameter pairs) to draw.
    cells_mean, cells_std : float
        Mean and standard deviation of the cell-count Gaussian.
    cells_min, cells_max : int
        Clipping bounds for cell counts.
    clusters_mean, clusters_std : float
        Mean and standard deviation of the cluster-count Gaussian.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Integer arrays ``(cell_counts, cluster_counts)``, each of shape ``(n,)``.
    """
    cells_raw = np.random.normal(cells_mean, cells_std, size=n)
    cell_counts = np.ceil(np.clip(cells_raw, cells_min, cells_max)).astype(int)

    # A spot can never hold more cluster types than cells.
    clusters_raw = np.random.normal(clusters_mean, clusters_std, size=n)
    cluster_counts = np.ceil(np.clip(clusters_raw, 1, cell_counts)).astype(int)
    return cell_counts, cluster_counts
def _sample_params_uniform(
    n: int,
    cells_min: int,
    cells_max: int,
    clusters_min: int,
    clusters_max: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Draw per-spot ``(cell_count, cluster_count)`` from Uniform distributions.

    Cell counts come from ``U(cells_min, cells_max)`` rounded up (not
    clipped). Cluster counts come from ``U(clusters_min, clusters_max)``,
    clipped to ``[1, cell_count]`` of the same spot and rounded up.

    Parameters
    ----------
    n : int
        Number of spots (parameter pairs) to draw.
    cells_min, cells_max : int
        Bounds of the cell-count uniform distribution.
    clusters_min, clusters_max : int
        Bounds of the cluster-count uniform distribution.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Integer arrays ``(cell_counts, cluster_counts)``, each of shape ``(n,)``.
    """

    def _ceil_to_int(values: np.ndarray) -> np.ndarray:
        # Round fractional draws up to whole counts.
        return np.ceil(values).astype(int)

    cell_counts = _ceil_to_int(np.random.uniform(cells_min, cells_max, size=n))
    cluster_counts = _ceil_to_int(
        np.clip(np.random.uniform(clusters_min, clusters_max, size=n), 1, cell_counts)
    )
    return cell_counts, cluster_counts
def _sample_params_exponential(
    n: int,
    cells_mean: float,
    cells_min: int,
    cells_max: int,
    clusters_mean: float,
    clusters_std: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Draw cell counts from an Exponential and cluster counts from a Gaussian.

    Cell counts come from ``Exp(scale=cells_mean)``, clipped to
    ``[cells_min, cells_max]`` and rounded up. Cluster counts come from
    ``N(clusters_mean, clusters_std)``, clipped to ``[1, cell_count]`` and
    rounded up.

    Parameters
    ----------
    n : int
        Number of spots (parameter pairs) to draw.
    cells_mean : float
        Scale (mean) of the exponential cell-count distribution.
    cells_min, cells_max : int
        Clipping bounds for cell counts.
    clusters_mean, clusters_std : float
        Mean and standard deviation of the cluster-count Gaussian.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Integer arrays ``(cell_counts, cluster_counts)``, each of shape ``(n,)``.
    """
    draws = np.random.exponential(scale=cells_mean, size=n)
    cell_counts = np.ceil(np.clip(draws, cells_min, cells_max)).astype(int)

    type_draws = np.random.normal(clusters_mean, clusters_std, size=n)
    cluster_counts = np.ceil(np.clip(type_draws, 1, cell_counts)).astype(int)
    return cell_counts, cluster_counts
def _sample_params_lognormal(
    n: int,
    cells_mean: float,
    cells_std: float,
    cells_min: int,
    cells_max: int,
    clusters_mean: float,
    clusters_std: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Draw cell counts from a Log-Normal and cluster counts from a Gaussian.

    ``cells_mean`` / ``cells_std`` describe the desired mean and standard
    deviation of the cell counts themselves. They are converted to the
    underlying normal's ``mu`` / ``sigma`` before sampling. Cell counts are
    clipped to ``[cells_min, cells_max]`` and rounded up. Cluster counts come
    from ``N(clusters_mean, clusters_std)``, clipped to ``[1, cell_count]``
    and rounded up.

    Parameters
    ----------
    n : int
        Number of spots (parameter pairs) to draw.
    cells_mean, cells_std : float
        Target mean and standard deviation of the log-normal cell counts.
    cells_min, cells_max : int
        Clipping bounds for cell counts.
    clusters_mean, clusters_std : float
        Mean and standard deviation of the cluster-count Gaussian.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Integer arrays ``(cell_counts, cluster_counts)``, each of shape ``(n,)``.
    """
    # Guard against a non-positive mean, which would break the log transform.
    safe_mean = max(cells_mean, 1e-6)
    variance = cells_std ** 2
    phi = np.sqrt(variance + safe_mean ** 2)
    mu = np.log(safe_mean ** 2 / phi)
    sigma = np.sqrt(np.log(phi ** 2 / safe_mean ** 2))

    draws = np.random.lognormal(mean=mu, sigma=sigma, size=n)
    cell_counts = np.ceil(np.clip(draws, cells_min, cells_max)).astype(int)

    type_draws = np.random.normal(clusters_mean, clusters_std, size=n)
    cluster_counts = np.ceil(np.clip(type_draws, 1, cell_counts)).astype(int)
    return cell_counts, cluster_counts
def _get_sample_params(
    n: int,
    method: str,
    cells_mean: float,
    cells_std: float,
    cells_min: int,
    cells_max: int,
    n_clusters: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Dispatch to the correct sampling-parameter generator.

    Cluster-count hyper-parameters are derived from the cell-count ones:

    * ``clusters_mean`` / ``clusters_std`` are half the cell-count values.
    * ``clusters_min`` is ``max(cells_min, 1)``.
    * ``clusters_max`` is ``min(cells_max // 2, n_clusters)``.

    Only the ``"uniform"`` method consumes ``clusters_min`` / ``clusters_max``.

    Parameters
    ----------
    n : int
        Number of parameter pairs to draw.
    method : str
        One of ``{"gaussian", "uniform", "exponential", "lognormal"}``.
    cells_mean : float
        Control parameter for the mean of cell counts.
    cells_std : float
        Control parameter for the standard deviation of cell counts.
    cells_min : int
        Control parameter for the minimum cell counts.
    cells_max : int
        Control parameter for the maximum cell counts.
    n_clusters : int
        Total number of cell types (used to cap ``clusters_max``).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Integer arrays of ``(cell_counts, cluster_counts)`` of shape ``(n,)``.

    Raises
    ------
    ValueError
        If the chosen sampling method is unrecognized.
    """
    clusters_mean = cells_mean / 2
    clusters_std = cells_std / 2
    clusters_min = max(cells_min, 1)
    # A spot cannot usefully contain more distinct types than exist.
    clusters_max = min(cells_max // 2, n_clusters)
    if method == "gaussian":
        return _sample_params_gaussian(
            n, cells_mean, cells_std, cells_min, cells_max,
            clusters_mean, clusters_std,
        )
    if method == "uniform":
        return _sample_params_uniform(
            n, cells_min, cells_max, clusters_min, clusters_max,
        )
    if method == "exponential":
        return _sample_params_exponential(
            n, cells_mean, cells_min, cells_max,
            clusters_mean, clusters_std,
        )
    if method == "lognormal":
        return _sample_params_lognormal(
            n, cells_mean, cells_std, cells_min, cells_max,
            clusters_mean, clusters_std,
        )
    # The second literal is a plain string (only the first carries the
    # ``f`` prefix), so braces must NOT be doubled or they print verbatim.
    raise ValueError(
        f"Unknown cell_sample_method '{method}'. "
        "Choose from {gaussian, uniform, exponential, lognormal}."
    )
# ──────────────────────────────────────────────────────────────────────────────
# Cell-type probability initialisation
# ──────────────────────────────────────────────────────────────────────────────
def _init_sample_prob(
    obs: pd.DataFrame,
    celltype_key: str,
) -> Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray], np.ndarray, np.ndarray]:
    """Build cluster-level and cell-level sampling probabilities.

    Cell types are encoded as integers ``0..K-1`` in descending order of
    frequency. Three cluster-level distributions are prepared:

    * **unbalance** – proportional to the observed cell-type frequencies.
    * **sqrt** – proportional to the square root of the frequencies.
    * **balance** – uniform over cell types.

    Each cell additionally receives a weight inversely proportional to its
    type's frequency, so every type carries equal total weight.

    Parameters
    ----------
    obs : pd.DataFrame
        Cell-level metadata; must contain the *celltype_key* column.
    celltype_key : str
        Column in ``obs`` holding cell-type labels.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray], np.ndarray, np.ndarray]
        ``(cluster_ordered, cluster_id, cluster_p_dict, cell_p_balanced,
        cluster_names)``:

        - ``cluster_ordered`` – int32 IDs ``0..K-1``.
        - ``cluster_id`` – int32 per-cell ID, shape ``(n_cells,)``.
        - ``cluster_p_dict`` – mode name → probability array of shape ``(K,)``.
        - ``cell_p_balanced`` – float64 per-cell weight summing to 1.
        - ``cluster_names`` – type labels in encoding order.
    """
    labels = obs[celltype_key]
    freq = labels.value_counts()  # sorted by descending count
    freq = freq[freq > 0]  # drop unused categorical levels
    cluster_names = freq.index.values
    lookup = dict(zip(cluster_names, range(len(cluster_names))))
    cluster_id = np.array([lookup[label] for label in labels.values], dtype=np.int32)
    cluster_ordered = np.arange(len(cluster_names), dtype=np.int32)

    counts = freq.to_numpy().astype(np.float64)
    root = np.sqrt(counts)
    p_unbalance = counts / counts.sum()
    cluster_p_dict: Dict[str, np.ndarray] = {
        "unbalance": p_unbalance,
        "sqrt": root / root.sum(),
        "balance": np.ones(len(counts)) / len(counts),
    }

    # Inverse-frequency weight per cell, normalised to a distribution.
    inv_freq = 1.0 / p_unbalance[cluster_id]
    cell_p = (inv_freq / inv_freq.sum()).astype(np.float64)
    return cluster_ordered, cluster_id, cluster_p_dict, cell_p, cluster_names
# ──────────────────────────────────────────────────────────────────────────────
# Expression matrix helper
# ──────────────────────────────────────────────────────────────────────────────
def _to_dense_f32(x: Any) -> np.ndarray:
    """Coerce a sparse or dense matrix to a C-contiguous float32 ndarray.

    Sparse input is cast to float32 *before* densification. Only a float32
    dense buffer is materialised (rather than a dense copy at the source
    precision followed by a second float32 copy). ``toarray()`` also yields a
    plain ndarray instead of the ``np.matrix`` returned by ``todense()``.

    Parameters
    ----------
    x : Union[np.ndarray, scipy.sparse.spmatrix]
        The input expression matrix.

    Returns
    -------
    np.ndarray
        A dense, C-contiguous array of dtype float32.
    """
    if issparse(x):
        x = x.astype(np.float32, copy=False).toarray()
    return np.ascontiguousarray(x, dtype=np.float32)
# ──────────────────────────────────────────────────────────────────────────────
# High-level chunk generator
# ──────────────────────────────────────────────────────────────────────────────
def _generate_chunk(
    sub_obs: pd.DataFrame,
    sub_exp: np.ndarray,
    used_features: np.ndarray,
    n_request: int,
    celltype_key: str,
    cfg: SimulationConfig,
) -> Optional[AnnData]:
    """Simulate one chunk of pseudo-spots from a single batch of cells.

    The request is split across the three cluster-balance modes
    (``unbalance``, ``sqrt``, ``balance``). The first mode absorbs the
    remainder of the split. All modes share one prefix-sliced set of
    ``(cell_count, cluster_count)`` draws.

    The returned AnnData stores:

    * ``.X`` — summed expression across all concatenated modalities.
    * ``.var.index`` — prefixed feature names (e.g. ``rna-GAPDH``).
    * ``.obsm["label"]`` — cell-type proportion DataFrame (all-zero rows stay zero).

    Parameters
    ----------
    sub_obs : pd.DataFrame
        Cell-level metadata for one batch (must contain *celltype_key*).
    sub_exp : np.ndarray
        Dense float32 expression matrix for the batch,
        shape ``(n_cells_in_batch, n_total_features)``.
    used_features : np.ndarray
        Ordered array of all (prefixed) feature names.
    n_request : int
        Number of pseudo-spots to generate in this chunk.
    celltype_key : str
        Cell-type annotation column in *sub_obs*.
    cfg : SimulationConfig
        Configuration object controlling sampling behaviour.

    Returns
    -------
    Optional[AnnData]
        The simulated AnnData object, or ``None`` if the batch has fewer
        than 3 cells.
    """
    if len(sub_obs) < 3:
        return None

    (
        clusters,
        cell_cluster,
        p_by_mode,
        cell_weights,
        type_names,
    ) = _init_sample_prob(sub_obs, celltype_key)
    # Identity matrix: row k is the one-hot vector for cluster k.
    onehot = np.eye(len(type_names), dtype=np.float32)

    modes = ["unbalance", "sqrt", "balance"]
    per_mode, leftover = divmod(n_request, len(modes))
    sizes = [per_mode + leftover] + [per_mode] * (len(modes) - 1)

    # One draw sized for the largest mode, prefix-sliced by every mode.
    n_cells_draw, n_types_draw = _get_sample_params(
        n=sizes[0],
        method=cfg.cell_sample_method,
        cells_mean=cfg.cells_mean,
        cells_std=cfg.cells_std,
        cells_min=cfg.cells_min,
        cells_max=cfg.cells_max,
        n_clusters=len(type_names),
    )
    spot_params = np.column_stack([n_cells_draw, n_types_draw]).astype(np.int64)

    exp_blocks: List[np.ndarray] = []
    count_blocks: List[np.ndarray] = []
    for mode, size in zip(modes, sizes):
        if size <= 0:
            continue
        block_exp, block_counts = _sample_spots_numba(
            param_list=spot_params[:size],
            cluster_p=p_by_mode[mode],
            clusters=clusters,
            cluster_id=cell_cluster,
            sample_exp=sub_exp,
            sample_cluster=onehot,
            cell_p_balanced=cell_weights,
        )
        exp_blocks.append(block_exp)
        count_blocks.append(block_counts)

    spot_exp = np.concatenate(exp_blocks, axis=0)
    spot_counts = np.concatenate(count_blocks, axis=0)

    # Counts -> proportions; a zero denominator is replaced by 1 so empty
    # spots keep an all-zero label row instead of NaNs.
    denom = spot_counts.sum(axis=1, keepdims=True)
    denom[denom == 0] = 1.0
    proportions = spot_counts / denom

    sim = AnnData(spot_exp)
    sim.var.index = pd.Index(used_features)
    sim.obsm["label"] = pd.DataFrame(
        proportions,
        columns=type_names,
        index=sim.obs_names,
    )
    return sim
# ──────────────────────────────────────────────────────────────────────────────
# Public entry point
# ──────────────────────────────────────────────────────────────────────────────
def _plan_batches(
    batch_labels: pd.Series,
    obs: pd.DataFrame,
    total: int,
    context_key: Optional[str],
) -> List[dict]:
    """Allocate *total* pseudo-spots across batches proportionally to cell counts.

    Parameters
    ----------
    batch_labels : pd.Series
        Per-cell batch identifier, aligned with *obs*.
    obs : pd.DataFrame
        Cell-level metadata (used only to look up *context_key*).
    total : int
        Total number of pseudo-spots requested.
    context_key : str, optional
        When given, each plan also records the batch's most frequent
        context label (as ``str``) under ``"context"``.

    Returns
    -------
    List[dict]
        One dict per batch with a positive allocation, containing ``"id"``
        and ``"n_target"`` (and optionally ``"context"``), in descending
        order of batch size.
    """
    n_obs = len(batch_labels)
    plans: List[dict] = [
        {"id": bid, "n_target": int((cnt / n_obs) * total)}
        for bid, cnt in batch_labels.value_counts().items()
    ]
    # Flooring leaves a non-negative remainder. Hand it to the first
    # (largest) batch *before* dropping empty allocations, so a small
    # ``total`` spread over many batches still produces samples.
    if plans:
        plans[0]["n_target"] += total - sum(p["n_target"] for p in plans)
    plans = [p for p in plans if p["n_target"] > 0]
    if context_key is not None:
        for plan in plans:
            in_batch = (batch_labels == plan["id"]).to_numpy()
            contexts = obs.loc[in_batch, context_key].astype(str)
            plan["context"] = contexts.value_counts().idxmax()
    return plans


def simulate(
    adata: AnnData,
    vocabs: Dict[str, Dict[str, int]],
    modality_names: List[str],
    top_ks: Dict[str, int],
    cell_types: List[str],
    *,
    celltype_key: str,
    save_dir: str,
    cfg: SimulationConfig,
    all_features: Optional[List[str]] = None,
    context2id: Optional[Dict[str, int]] = None,
    sim_batch_key: Optional[str] = None,
    context_key: Optional[str] = None,
    max_empty_retries: int = 3,
) -> Tuple[int, str, str, Optional[str]]:
    """Batch-proportional pseudo-spot simulation with memmap writing.

    The input *adata* is a **pre-concatenated** multi-modal AnnData
    whose ``.X`` contains features from all modalities (columns prefixed
    by modality name, e.g. ``rna-GAPDH``, ``adt-CD3``). Pseudo-spots
    are generated independently per batch, with the number of spots
    proportional to each batch's share of total cells. Results are
    tokenised on-the-fly and streamed to memory-mapped files.
    *adata* itself is not modified.

    Parameters
    ----------
    adata : AnnData
        Concatenated multi-modal single-cell reference. ``.obs`` must
        contain *celltype_key* (and optionally *sim_batch_key*,
        *context_key*). ``.X`` has shape
        ``(n_cells, sum_of_features_across_modalities)``.
    vocabs : Dict[str, Dict[str, int]]
        Per-modality token vocabularies mapping prefixed feature names to integers.
    modality_names : List[str]
        Ordered list of modalities, e.g. ``["rna"]`` or ``["rna", "adt"]``.
    top_ks : Dict[str, int]
        Per-modality number of top-ranked features to keep per spot.
    cell_types : List[str]
        Ordered list of cell-type names for the label vector.
    celltype_key : str
        Column name in ``adata.obs`` storing the cell-type labels.
    save_dir : str
        Directory where the output memmap files are written.
    cfg : SimulationConfig
        Configuration object containing the simulation hyperparameters.
    all_features : List[str], optional
        Flat list of **all** prefixed feature names across modalities.
        If provided, *adata* is subset to these features up front.
    context2id : Dict[str, int], optional
        Mapping from context labels to integer IDs.
    sim_batch_key : str, optional
        Column in ``adata.obs`` whose values define the simulation batches.
        When ``None``, all cells form a single batch.
    context_key : str, optional
        Column in ``adata.obs`` storing context labels.
    max_empty_retries : int, default 3
        Consecutive chunks allowed to yield no valid tokenised samples
        before a batch is abandoned. This prevents an endless loop when
        tokenisation keeps returning nothing.

    Returns
    -------
    Tuple[int, str, str, Optional[str]]
        A tuple containing:

        - `real_total`: Number of valid samples actually written. Rows past
          it in the memmaps are not written by this function.
        - `inp_path`: Path to the tokenised-input memmap file.
        - `lbl_path`: Path to the label memmap file.
        - `ctx_path`: Path to the context memmap file, or ``None`` if unused.
    """
    total = cfg.total_samples
    seq_len = sum(top_ks[m] for m in modality_names)
    use_context = context2id is not None and context_key is not None

    inp_path = join(save_dir, "train_inputs.mmap")
    lbl_path = join(save_dir, "train_labels.mmap")
    ctx_path = join(save_dir, "train_contexts.mmap") if use_context else None
    fp_inp = np.memmap(
        inp_path, dtype="int64", mode="w+", shape=(total, seq_len),
    )
    fp_lbl = np.memmap(
        lbl_path, dtype="float32", mode="w+", shape=(total, len(cell_types)),
    )
    fp_ctx = (
        np.memmap(ctx_path, dtype="int64", mode="w+", shape=(total,))
        if use_context else None
    )

    # -- Global Feature Subsetting (Performance Optimization) --------------
    if all_features is not None:
        used_features = np.array(all_features)
        adata = adata[:, used_features]
    else:
        used_features = np.array(adata.var_names)

    # -- Batch labels ------------------------------------------------------
    # Without a batch key, treat every cell as one batch. Use a local Series
    # instead of writing a helper column into the caller's ``adata.obs``:
    # that mutated the input and, on AnnData views, forced an implicit copy.
    if sim_batch_key is None:
        batch_labels = pd.Series("sim_batch_1", index=adata.obs_names)
    else:
        batch_labels = adata.obs[sim_batch_key]

    # -- Proportional allocation by batch ----------------------------------
    plans = _plan_batches(
        batch_labels, adata.obs, total, context_key if use_context else None,
    )

    # -- Iterate batches ---------------------------------------------------
    ptr = 0
    with tqdm(total=total, desc="Simulating") as pbar:
        for plan in plans:
            batch_mask = (batch_labels == plan["id"]).to_numpy()
            sub_obs = adata.obs.loc[batch_mask].copy()
            if len(sub_obs) < 3:
                # Too few cells to sample from; skip before densifying X.
                pbar.update(plan["n_target"])
                continue
            sub_exp = _to_dense_f32(adata[batch_mask].X)
            done = 0
            empty_streak = 0
            while done < plan["n_target"] and ptr < total:
                remaining = min(plan["n_target"] - done, cfg.batch_request_size)
                try:
                    chunk = _generate_chunk(
                        sub_obs=sub_obs,
                        sub_exp=sub_exp,
                        used_features=used_features,
                        n_request=remaining + cfg.sim_buffer,
                        celltype_key=celltype_key,
                        cfg=cfg,
                    )
                except Exception as exc:
                    # Best-effort: report and move on to the next batch.
                    print(f" Error batch={plan['id']}: {exc}")
                    break
                if chunk is None:
                    break
                if use_context:
                    chunk.uns["_context_label"] = plan.get("context")
                tok, lab, sid = tokenize_batch(
                    chunk,
                    vocabs,
                    modality_names,
                    top_ks,
                    cell_types,
                    context2id=context2id,
                    context_key=context_key,
                    mode="train",
                )
                vl = 0 if tok is None else min(
                    tok.shape[0], plan["n_target"] - done, total - ptr,
                )
                if vl <= 0:
                    # Nothing usable in this chunk. Retry, but bounded:
                    # otherwise ``done``/``ptr`` never advance and the loop
                    # would spin forever.
                    del chunk, tok, lab, sid
                    gc.collect()
                    empty_streak += 1
                    if empty_streak >= max_empty_retries:
                        print(
                            f" Giving up on batch={plan['id']}: "
                            f"{empty_streak} consecutive chunks yielded no valid samples."
                        )
                        break
                    continue
                empty_streak = 0
                fp_inp[ptr : ptr + vl] = tok[:vl]
                fp_lbl[ptr : ptr + vl] = lab[:vl].astype(np.float32)
                if use_context and sid is not None:
                    fp_ctx[ptr : ptr + vl] = sid[:vl]
                ptr += vl
                done += vl
                pbar.update(vl)
                del chunk, tok, lab, sid
                gc.collect()
            del sub_obs, sub_exp
            gc.collect()
            if ptr >= total:
                break

    # -- Flush & cleanup ---------------------------------------------------
    fp_inp.flush()
    fp_lbl.flush()
    if fp_ctx is not None:
        fp_ctx.flush()
    del fp_inp, fp_lbl, fp_ctx
    print(f" Simulation done — {ptr} valid samples written.")
    return ptr, inp_path, lbl_path, ctx_path