"""Gene-level preprocessing: normalization and marker gene detection."""
from __future__ import annotations
from typing import List, Optional, Union
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
[docs]
def normalize_rna(
    adata: AnnData,
    target_sum: float = 1e4,
    layer_key: str = "log1p",
) -> AnnData:
    """Library-size normalise ``adata.X`` and store its log1p as a layer.

    ``adata.X`` itself is not changed; only ``adata.layers[layer_key]`` is
    written (an existing layer with that name is overwritten).

    Parameters
    ----------
    adata : AnnData
        AnnData object; modified in place by adding the new layer.
    target_sum : float, default 1e4
        Target total counts per cell before the log-transform.
    layer_key : str, default "log1p"
        Name of the layer where the normalised, log-transformed values are
        stored.

    Returns
    -------
    AnnData
        The same ``adata`` object with ``adata.layers[layer_key]`` populated.
    """
    # Normalise a copy of X only. Copying the whole AnnData (every layer,
    # obsm, obsp and uns entry) just to transform X wastes memory on large
    # datasets; ``inplace=False`` returns the normalised matrix under "X".
    normalized = sc.pp.normalize_total(
        adata, target_sum=target_sum, inplace=False
    )["X"]
    # ``normalized`` is a fresh matrix owned by this function, so letting
    # log1p transform it in place cannot touch adata.X.
    adata.layers[layer_key] = sc.pp.log1p(normalized)
    return adata
[docs]
def find_sc_markers(
    adata: AnnData,
    celltype_key: str,
    batch_key: Optional[str] = None,
    layer: str = "log1p",
    deg_method: str = "wilcoxon",
    log2fc_min: float = 0.5,
    pval_cutoff: float = 0.01,
    n_top_markers: int = 200,
    pct_diff: Optional[float] = None,
    pct_min: float = 0.1,
) -> np.ndarray:
    """Batch-aware marker gene detection using scanpy's rank_genes_groups.

    When `batch_key` is given, differential expression is run independently
    in each batch and the union of per-batch marker sets is returned.
    Batches with fewer than 10 cells or fewer than 2 testable cell types are
    skipped, and cell types with a single cell are dropped before testing.

    Parameters
    ----------
    adata : AnnData
        Annotated single-cell reference dataset. Not modified: DE runs on
        copies, so ``adata.uns`` is left untouched.
    celltype_key : str
        Column in ``adata.obs`` storing cell-type labels.
    batch_key : str, optional
        Optional column for batch-aware DE. ``None`` indicates global mode.
    layer : str, default "log1p"
        Layer in `adata` used as expression input.
    deg_method : str, default "wilcoxon"
        Statistical method forwarded to ``sc.tl.rank_genes_groups``
        (a falsy value falls back to "wilcoxon").
    log2fc_min : float, default 0.5
        Keep genes with log2 fold-change strictly greater than this.
    pval_cutoff : float, default 0.01
        Keep genes with adjusted p-value strictly below this.
    n_top_markers : int, default 200
        Maximum number of markers retained per cell type per batch
        (highest log fold-change first).
    pct_diff : float, optional
        If set, additionally require ``(pct_group - pct_rest) > pct_diff``.
    pct_min : float, default 0.1
        Keep genes expressed in strictly more than this fraction of the
        group's cells.

    Returns
    -------
    np.ndarray
        Sorted array of unique marker gene names across all batches (empty
        object array when nothing passes the filters).
    """
    print(
        f"### Finding marker genes | mode={'batch-wise' if batch_key else 'global'}"
    )
    batch_list: List[Union[str, int]] = (
        adata.obs[batch_key].unique().tolist() if batch_key else ["all_data"]
    )
    if batch_key:
        print(f" Batches ({len(batch_list)}): {batch_list}")

    all_dfs: List[pd.DataFrame] = []
    for bid in batch_list:
        # Always work on a copy so rank_genes_groups never writes into the
        # caller's adata.uns.
        sub = (
            adata[adata.obs[batch_key] == bid].copy()
            if batch_key
            else adata.copy()
        )
        sub = _prepare_batch(sub, celltype_key, bid)
        if sub is None:
            continue
        try:
            sc.tl.rank_genes_groups(
                sub,
                groupby=celltype_key,
                pts=True,
                layer=layer,
                use_raw=False,
                method=deg_method or "wilcoxon",
            )
        except Exception as exc:
            # Best-effort by design: one failing batch must not abort the rest.
            print(f" [batch {bid}] rank_genes_groups error: {exc}")
            continue
        for ct in sub.obs[celltype_key].unique():
            df = _group_markers(
                sub,
                ct,
                celltype_key,
                bid,
                log2fc_min=log2fc_min,
                pval_cutoff=pval_cutoff,
                n_top_markers=n_top_markers,
                pct_diff=pct_diff,
                pct_min=pct_min,
            )
            if df is not None and not df.empty:
                all_dfs.append(df)

    if not all_dfs:
        print("!!! No markers found.")
        # object dtype matches the gene-name array returned on success.
        return np.array([], dtype=object)
    merged = pd.concat(all_dfs, axis=0)
    genes = np.unique(merged["names"].values)
    print(f" Total unique markers: {len(genes)}")
    return genes


def _prepare_batch(
    sub: AnnData, celltype_key: str, bid: Union[str, int]
) -> Optional[AnnData]:
    """Return `sub` ready for DE testing, or None if the batch must be skipped."""
    if sub.n_obs < 10:
        print(f" skip batch {bid}: {sub.n_obs} cells")
        return None
    if sub.obs[celltype_key].nunique() < 2:
        print(f" skip batch {bid}: <2 cell types")
        return None
    # rank_genes_groups cannot compute statistics for groups of one cell.
    counts = sub.obs[celltype_key].value_counts()
    rare = counts[counts <= 1].index.tolist()
    if rare:
        sub = sub[~sub.obs[celltype_key].isin(rare)].copy()
        if sub.obs[celltype_key].nunique() < 2:
            print(f" skip batch {bid}: <2 cell types after dropping singletons")
            return None
    return sub


def _group_markers(
    sub: AnnData,
    ct: object,
    celltype_key: str,
    bid: Union[str, int],
    *,
    log2fc_min: float,
    pval_cutoff: float,
    n_top_markers: int,
    pct_diff: Optional[float],
    pct_min: float,
) -> Optional[pd.DataFrame]:
    """Filter and rank the DE results of one cell type; None if unavailable."""
    # rank_genes_groups stores results under stringified group names, so a
    # non-string label (e.g. an integer category) must be converted first.
    group = str(ct)
    try:
        df = sc.get.rank_genes_groups_df(
            sub, group=group, pval_cutoff=pval_cutoff, log2fc_min=log2fc_min
        )
    except KeyError:
        return None
    if df.empty:
        return None
    # Provenance columns, useful when inspecting the per-batch tables.
    df[celltype_key] = ct
    df["batch_source"] = str(bid)
    if pct_diff is not None:
        results = sub.uns["rank_genes_groups"]
        gap = results["pts"][group] - results["pts_rest"][group]
        passing = gap.index[gap.to_numpy() > pct_diff]
        df = df[df["names"].isin(passing)]
    if "pct_nz_group" in df.columns:
        df = df[df["pct_nz_group"] > pct_min]
    return df.sort_values("logfoldchanges", ascending=False).iloc[:n_top_markers]