Mouse Isocortex Deconvolution with SpaRank

This tutorial shows how a single SpaRank model, trained on a mouse isocortex single-cell reference, can be directly applied to 27 spatial sections at spot resolution without any per-section retraining.

Both the single-cell reference and the spatial data come from the mouse isocortex. The spatial dataset was generated from single-cell resolution MERFISH slides by overlaying a regular grid to create pseudo-spots, resulting in a total of 27 sections. Datasets are available on Zenodo.
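To make the pseudo-spot construction concrete, here is a minimal sketch that bins single cells into a square grid and computes per-spot cell-type proportions. It is not the pipeline that produced the sp_data.h5mu files. The helper `grid_pseudospots` is hypothetical, and it assumes (the source does not state this) that the `window=0.12` value in the data paths is the grid side length, in the same units as the cell coordinates.

```python
import numpy as np
import pandas as pd


def grid_pseudospots(xy, cell_types, window):
    """Bin single cells into square pseudo-spots of side `window` (hypothetical helper).

    Returns per-spot cell-type proportions (rows: spots, columns: cell types)
    and the number of cells in each spot.
    """
    xy = np.asarray(xy, dtype=float)
    # Integer grid index of every cell along x and y, anchored at the min corner
    ij = np.floor((xy - xy.min(axis=0)) / window).astype(int)
    spot_id = [f"{i}_{j}" for i, j in ij]
    # Count cells of each type inside each grid square
    counts = pd.crosstab(pd.Series(spot_id, name="spot"),
                         pd.Series(list(cell_types), name="cell_type"))
    n_cells = counts.sum(axis=1)
    proportions = counts.div(n_cells, axis=0)
    return proportions, n_cells


# Toy example: four cells with made-up coordinates and two cell types
xy = np.array([[0.00, 0.00], [0.05, 0.10], [0.15, 0.02], [0.20, 0.05]])
props, n_cells = grid_pseudospots(xy, ["A", "B", "A", "A"], window=0.12)
print(props)
```

The proportions in each row sum to 1, which matches the form of the ground-truth class proportions used in the evaluation later on.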

[ ]:
import muon as mu
import os
from os.path import join
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import sparank
from sparank.config import ExpConfig, ModalityConfig, SimulationConfig
from sparank.framework import SpaRank

1. Data

We load the single-cell reference and one representative spatial section. The spatial section is used to obtain the MERFISH gene panel, which is passed to prepare() as prior knowledge. Since this panel was designed to include genes informative for mouse brain cell types, it provides a biologically grounded constraint for marker selection.
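Only panel genes that are also measured in the reference can serve as features. A quick check before calling prepare() could look like the sketch below. The helper `split_panel` is hypothetical, and the source does not say whether SpaRank already performs this intersection internally, so treat it as a precaution.

```python
def split_panel(panel, reference_genes):
    """Split a spatial gene panel into genes found / not found in the reference."""
    ref = set(reference_genes)
    shared = [g for g in panel if g in ref]        # keeps the panel's order
    missing = [g for g in panel if g not in ref]   # unusable as features
    return shared, missing


# Toy example; "GeneX" stands in for a panel gene absent from the reference
shared, missing = split_panel(["Slc17a7", "Gad1", "GeneX"],
                              ["Gad1", "Slc17a7", "Snap25"])
print(len(shared), missing)
```

With the objects loaded in the next cell, the call would be `split_panel(gene_panel, adx_sc_rna.var_names)`.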

[ ]:
data_dir = '../data/mouse_isocortex'

# single-cell reference data
adx_sc_rna = mu.read_h5mu(join(data_dir, 'sc/isocortex.h5mu')).mod['rna']
adx_sc_rna.var_names_make_unique()

# spatial data
_section = 'Zhuang-ABCA-2.016'  # taking one to extract the gene panel
adx_sp_rna = mu.read_h5mu(join(data_dir, f'sp/simulated/{_section}/Isocortex/window=0.12/sp_data.h5mu')).mod['rna']
adx_sp_rna.var_names_make_unique()

gene_panel = adx_sp_rna.var_names.to_list()
len(gene_panel)

2. Training

The cell below shows the training configuration used in the experiments reported in our paper. It is commented out by default; uncomment and run it only if you want to reproduce the model from scratch. Otherwise, skip to section 3 and load the pre-trained checkpoint. Training takes about 17 minutes on a single RTX 4090 GPU.

A few settings are specific to this dataset:

  • cell_sample_method="lognormal" : lognormal sampling better matches the long-tailed distribution of cell counts per spot

  • context_key=None : no context conditioning is used

  • marker_features={"rna": gene_panel} : input features are restricted to genes in the MERFISH panel

  • cl_weight=0, mrp_weight=0 : contrastive learning and reconstruction losses are disabled in this setting
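The cells_mean=15 and cells_std=10 values in the configuration below can be converted into lognormal parameters by moment matching: for X ~ LogNormal(μ, σ), σ² = ln(1 + std²/mean²) and μ = ln(mean) - σ²/2. The sketch assumes this mapping; how SpaRank actually interprets cells_mean and cells_std, and how it rounds or truncates counts, is not stated here. `lognormal_cell_counts` is a hypothetical helper.

```python
import numpy as np


def lognormal_cell_counts(n, mean, std, rng=None):
    """Draw n per-spot cell counts from a lognormal matched to `mean` and `std`.

    Moment matching: E[X] = exp(mu + sigma^2 / 2) and
    Var[X] = (exp(sigma^2) - 1) * E[X]^2.
    """
    rng = np.random.default_rng(rng)
    sigma2 = np.log1p((std / mean) ** 2)
    mu = np.log(mean) - sigma2 / 2
    x = rng.lognormal(mean=mu, sigma=np.sqrt(sigma2), size=n)
    # Round to whole cells and require at least one cell per spot
    return np.maximum(np.rint(x), 1).astype(int)


counts = lognormal_cell_counts(100_000, mean=15, std=10, rng=0)
print(counts.mean(), np.median(counts), counts.max())
```

The median of the draws (near exp(μ), about 12.5) sits below the mean of 15. Most spots therefore hold fewer cells than average, while a tail of spots holds many more, which is the long-tailed shape the first bullet refers to.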

[3]:
# cfg = ExpConfig(
#     modalities=[
#         ModalityConfig(name="rna", top_k=500)
#     ],
#     simulation=SimulationConfig(
#         total_samples=1_000_000,
#         batch_request_size=100_000,
#         cells_mean=15,
#         cells_std=10,
#         cell_sample_method='lognormal',
#         batch_key='batch',
#     ),
#     celltype_key='class',
#     batch_key='batch',
#     context_key=None,
#     dim=128,
#     depth=2,
#     heads=4,
#     cl_weight=0.,
#     cl_temperature=0.1,
#     mrp_weight=0.,
#     cls_weight=1.0,
#     epochs=4,
#     num_workers=4,
#     batch_size=128
# )

# model = SpaRank(cfg=cfg, save_dir='../outputs/isocortex')
# model.register_modality("rna", adx_sc_rna)
# model.prepare(marker_features={'rna':gene_panel})    # use sp gene panel as a prior
# model.fit()              # train
# model.save()
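For orientation, training data of the kind a SimulationConfig describes (a number of simulated samples, each with a sampled cell count) can be sketched as random mixtures of reference cells, with each mixture's cell-type proportion vector serving as its label. This is a generic illustration of simulated-mixture training for deconvolution under our own assumptions, not SpaRank's implementation. In particular, it ignores batch_key (SpaRank may draw the cells of a sample from a single batch), and the function `simulate_spots` is hypothetical.

```python
import numpy as np


def simulate_spots(X, labels, n_spots, cell_counts, rng=None):
    """Build pseudo-spots by summing randomly drawn reference cells.

    X: (cells, genes) expression matrix; labels: (cells,) cell-type names.
    Returns summed profiles (n_spots, genes), proportions (n_spots, types),
    and the sorted cell-type names that label the proportion columns.
    """
    rng = np.random.default_rng(rng)
    X = np.asarray(X, dtype=float)
    types, codes = np.unique(labels, return_inverse=True)
    profiles = np.zeros((n_spots, X.shape[1]))
    props = np.zeros((n_spots, len(types)))
    for s in range(n_spots):
        # Draw the cells of this spot (with replacement) from the reference
        idx = rng.choice(X.shape[0], size=cell_counts[s], replace=True)
        profiles[s] = X[idx].sum(axis=0)
        props[s] = np.bincount(codes[idx], minlength=len(types)) / cell_counts[s]
    return profiles, props, types


# Toy reference: 50 cells x 10 genes with three made-up cell types
rng = np.random.default_rng(0)
X = rng.poisson(2.0, size=(50, 10))
labels = rng.choice(["A", "B", "C"], size=50)
counts = np.maximum(rng.poisson(15, size=8), 1)
profiles, props, types = simulate_spots(X, labels, n_spots=8, cell_counts=counts, rng=1)
print(types, props.round(2))
```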

3. Load pre-trained checkpoint

[4]:
model = SpaRank.load('../outputs/isocortex', device="cpu")

4. Quantitative evaluation

We use two metrics for evaluation:

  • JSD (Jensen-Shannon distance, i.e. the square root of the base-2 Jensen-Shannon divergence, as returned by SciPy's jensenshannon) : lower is better

  • PCC (Pearson correlation coefficient) : higher is better
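For reference, let P and Q be two probability vectors and M their average. With base-2 logarithms the Jensen-Shannon divergence lies in [0, 1], and terms with P_i = 0 (or Q_i = 0) contribute 0. scipy.spatial.distance.jensenshannon, used in the evaluation code, returns the distance d_JS, the square root of the divergence. The Pearson correlation of two vectors x and y is given on the second line.

```latex
M = \tfrac{1}{2}(P + Q), \qquad
\mathrm{JSD}(P \,\|\, Q) = \tfrac{1}{2} \sum_i P_i \log_2 \frac{P_i}{M_i} + \tfrac{1}{2} \sum_i Q_i \log_2 \frac{Q_i}{M_i}, \qquad
d_{\mathrm{JS}}(P, Q) = \sqrt{\mathrm{JSD}(P \,\|\, Q)}

r(x, y) = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sqrt{\sum_i (x_i - \bar{x})^2} \, \sqrt{\sum_i (y_i - \bar{y})^2}}
```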

Both metrics are computed in two orientations:

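  • row (spots) : how well each spot’s composition is predicted (the metric is computed across cell types for each spot, then averaged over spots)

  • col (cell types) : how well each cell type’s spatial pattern is recovered (the metric is computed across spots for each cell type, then averaged over cell types)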
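The two orientations differ only in which axis is treated as the variable. The sketch below uses a toy 4 spots x 3 cell types matrix of made-up proportions and mirrors the transpose trick used by eval_pred in the next cell: a per-column metric on the matrices gives one value per cell type, and the same metric on their transposes gives one value per spot. `pearson_per_column` is a hypothetical NumPy-only stand-in for pcc_per_col.

```python
import numpy as np


def pearson_per_column(t, p):
    """Pearson r between matching columns of t and p (NaN if either column is constant)."""
    t = np.asarray(t, dtype=float)
    p = np.asarray(p, dtype=float)
    out = np.full(t.shape[1], np.nan)
    for i in range(t.shape[1]):
        if t[:, i].std() > 0 and p[:, i].std() > 0:
            out[i] = np.corrcoef(t[:, i], p[:, i])[0, 1]
    return out


# Toy proportions: rows are spots, columns are cell types
true = np.array([[0.6, 0.3, 0.1],
                 [0.2, 0.5, 0.3],
                 [0.1, 0.1, 0.8],
                 [0.4, 0.4, 0.2]])
pred = np.array([[0.5, 0.4, 0.1],
                 [0.3, 0.4, 0.3],
                 [0.1, 0.2, 0.7],
                 [0.4, 0.3, 0.3]])

per_type = pearson_per_column(true, pred)      # "col": one r per cell type, across spots
per_spot = pearson_per_column(true.T, pred.T)  # "row": one r per spot, across cell types
print(per_type.round(3), per_spot.round(3))
```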

[6]:
from typing import Union

import numpy as np
import pandas as pd
from scipy.spatial.distance import jensenshannon
from scipy.stats import pearsonr

def _to_array(x):
    return x.values if isinstance(x, pd.DataFrame) else np.asarray(x)

def jsd(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> float:
    # Mean Jensen-Shannon distance over columns, skipping NaN (all-zero) columns
    jsd_vals = jsd_per_col(true, predicted)
    return float(np.nanmean(jsd_vals))

def jsd_per_col(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
    t = _to_array(true)
    p = _to_array(predicted)

    n_cols = t.shape[1]
    vals = np.full(n_cols, np.nan, dtype=float)
    for i in range(n_cols):
        a = t[:, i].astype(float)
        b = p[:, i].astype(float)
        sa = a.sum()
        sb = b.sum()
        if sa == 0 or sb == 0:
            vals[i] = np.nan
            continue
        pa = a / sa
        pb = b / sb
        vals[i] = float(jensenshannon(pa, pb, axis=0, base=2))
    return vals


def pcc(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> float:
    r_vals = pcc_per_col(true, predicted)
    return float(np.nanmean(r_vals))


def pcc_per_col(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
    t = _to_array(true)
    p = _to_array(predicted)

    n_cols = t.shape[1]
    vals = np.full(n_cols, np.nan, dtype=float)
    for i in range(n_cols):
        x = t[:, i]
        y = p[:, i]
        if np.std(x) == 0 or np.std(y) == 0:
            vals[i] = np.nan
            continue
        vals[i], _ = pearsonr(x, y)
    return vals


def eval_pred(targets, pred):

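    jsd_mean_row = jsd(targets.T, pred.T)  # per spot (rows), then averaged
    pcc_mean_row = pcc(targets.T, pred.T)  # per spot (rows), then averaged
    jsd_mean_col = jsd(targets, pred)      # per cell type (columns), then averaged
    pcc_mean_col = pcc(targets, pred)      # per cell type (columns), then averaged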

    d = {
        'jsd-mean-row': jsd_mean_row,
        'pcc-mean-row': pcc_mean_row,
        'jsd-mean-col': jsd_mean_col,
        'pcc-mean-col': pcc_mean_col,
    }
    return d
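jsd_per_col loops over columns and calls SciPy once per column. A vectorised NumPy variant is sketched below as a hypothetical drop-in (`js_distance_per_col`). It computes the same base-2 Jensen-Shannon distance for all columns at once and, like the original, returns NaN for any column whose truth or prediction sums to zero. Passing the transposed matrices gives the per-spot values.

```python
import numpy as np


def js_distance_per_col(true, pred):
    """Base-2 Jensen-Shannon distance between matching columns; NaN if a column sums to 0."""
    t = np.asarray(true, dtype=float)
    p = np.asarray(pred, dtype=float)
    st, sp = t.sum(axis=0), p.sum(axis=0)
    valid = (st > 0) & (sp > 0)
    # Normalise each column to a probability vector (invalid columns are masked below)
    with np.errstate(divide="ignore", invalid="ignore"):
        t = t / st
        p = p / sp
    m = 0.5 * (t + p)

    def kl(a, b):
        # Column-wise sum of a_i * log2(a_i / b_i), using the convention 0 * log(0) = 0
        with np.errstate(divide="ignore", invalid="ignore"):
            terms = np.where(a > 0, a * np.log2(a / b), 0.0)
        return terms.sum(axis=0)

    jsd = 0.5 * kl(t, m) + 0.5 * kl(p, m)
    out = np.sqrt(np.clip(jsd, 0.0, None))  # clip tiny negative rounding error
    out[~valid] = np.nan
    return out


# Column 0: disjoint supports; column 1: identical; column 2: all-zero truth
true = np.array([[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0]])
pred = np.array([[0.0, 0.0, 1.0],
                 [1.0, 1.0, 1.0]])
print(js_distance_per_col(true, pred))
```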

5. Deploy to all 27 sections

The trained model is applied to each of the 27 spatial sections in turn. For each section, predicted proportions are aligned to the ground-truth column order and both are re-normalised to sum to 1 before metric computation.
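The alignment and re-normalisation steps can be checked on toy frames. In this sketch, `align_and_normalise` is a hypothetical helper wrapping the same pandas calls as the loop below, and the column names are made up. A cell type the model never predicts is filled with 0, a predicted type absent from the ground truth is dropped, and each row is then rescaled to sum to 1.

```python
import pandas as pd


def align_and_normalise(trues, preds):
    """Match predicted columns to the ground-truth order and rescale rows to sum to 1."""
    # Types missing from preds become 0; extra predicted types are dropped
    preds = preds.reindex(columns=trues.columns, fill_value=0)
    trues = trues.div(trues.sum(axis=1), axis=0)
    preds = preds.div(preds.sum(axis=1), axis=0)
    return trues, preds


trues = pd.DataFrame([[2, 1, 1], [0, 3, 1]], columns=["A", "B", "C"])
preds = pd.DataFrame([[0.5, 0.3, 0.2], [0.6, 0.1, 0.3]], columns=["B", "A", "D"])
t, p = align_and_normalise(trues, preds)
print(p)
```

Because the dropped column's mass is redistributed by the row rescaling, predictions for cell types outside the ground-truth set do not count against the spot. A spot whose values sum to zero would turn into a row of NaN, so such spots would need to be filtered first.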

[ ]:
rs = []
for target_section in os.listdir(join(data_dir, 'sp/simulated')):
    adx_sp_rna = mu.read_h5mu(join(data_dir, f'sp/simulated/{target_section}/Isocortex/window=0.12/sp_data.h5mu')).mod['rna']
    # Apply the trained model as-is: no per-section retraining
    preds = model.predict(mod_adatas={'rna': adx_sp_rna})

    # Ground-truth class proportions stored with each simulated section
    trues = pd.DataFrame(adx_sp_rna.obsm['class_proportions'],
                         columns=adx_sp_rna.uns['class_proportion_names'])
    # Align predicted columns to the ground-truth order (missing classes -> 0)
    preds = preds.reindex(columns=trues.columns, fill_value=0)

    # Re-normalise each spot (row) to sum to 1
    trues = trues.div(trues.sum(axis=1), axis=0)
    preds = preds.div(preds.sum(axis=1), axis=0)
    _r = eval_pred(trues, preds)
    _r['dataset'] = target_section

    rs.append(_r)
[8]:
pd.DataFrame(rs).head()
[8]:
   jsd-mean-row  pcc-mean-row  jsd-mean-col  pcc-mean-col  dataset
0      0.234484      0.881491      0.371106      0.768223  Zhuang-ABCA-2.018
1      0.232949      0.867011      0.326393      0.818452  Zhuang-ABCA-2.037
2      0.239517      0.876566      0.354756      0.828769  Zhuang-ABCA-2.050
3      0.239212      0.877254      0.299845      0.798714  Zhuang-ABCA-2.045
4      0.248456      0.891715      0.374022      0.731692  Zhuang-ABCA-2.047
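Only the first five sections are shown above. To summarise all 27, the per-section dictionaries in rs can be aggregated into a mean and standard deviation per metric. A minimal pandas sketch follows, with `summarise_sections` a hypothetical helper. The demo input reuses only the five rows displayed above, so its numbers describe those sections alone; call summarise_sections(rs) for the full set.

```python
import pandas as pd


def summarise_sections(results):
    """Mean and standard deviation of each metric across sections."""
    df = pd.DataFrame(results).set_index("dataset")
    return df.agg(["mean", "std"]).T  # rows: metrics, columns: mean / std


# The five rows printed by pd.DataFrame(rs).head() above
rows = [
    {"jsd-mean-row": 0.234484, "pcc-mean-row": 0.881491, "jsd-mean-col": 0.371106, "pcc-mean-col": 0.768223, "dataset": "Zhuang-ABCA-2.018"},
    {"jsd-mean-row": 0.232949, "pcc-mean-row": 0.867011, "jsd-mean-col": 0.326393, "pcc-mean-col": 0.818452, "dataset": "Zhuang-ABCA-2.037"},
    {"jsd-mean-row": 0.239517, "pcc-mean-row": 0.876566, "jsd-mean-col": 0.354756, "pcc-mean-col": 0.828769, "dataset": "Zhuang-ABCA-2.050"},
    {"jsd-mean-row": 0.239212, "pcc-mean-row": 0.877254, "jsd-mean-col": 0.299845, "pcc-mean-col": 0.798714, "dataset": "Zhuang-ABCA-2.045"},
    {"jsd-mean-row": 0.248456, "pcc-mean-row": 0.891715, "jsd-mean-col": 0.374022, "pcc-mean-col": 0.731692, "dataset": "Zhuang-ABCA-2.047"},
]
summary = summarise_sections(rows)
print(summary)
```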