{ "cells": [ { "cell_type": "markdown", "id": "e45d2fe6", "metadata": {}, "source": [ "### Mouse Isocortex Deconvolution with SpaRank\n", "\n", "This tutorial shows how a single SpaRank model, trained once on a mouse isocortex single-cell reference, can be applied directly to 27 spatial sections at spot resolution without any per-section retraining.\n", "\n", "Both the single-cell reference and the spatial data come from the mouse isocortex. The spatial data consist of 27 MERFISH sections measured at single-cell resolution. Each section was converted to spot resolution by overlaying a regular grid, so that every grid window becomes a pseudo-spot whose true cell-type composition is known. Datasets are available on [Zenodo](https://zenodo.org/records/19691472)." ] }, { "cell_type": "code", "execution_count": null, "id": "ce568d72", "metadata": {}, "outputs": [], "source": [ "import muon as mu\n", "import os\n", "from os.path import join\n", "import scanpy as sc\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "import sparank\n", "from sparank.config import ExpConfig, ModalityConfig, SimulationConfig\n", "from sparank.framework import SpaRank" ] }, { "cell_type": "markdown", "id": "51564cf8", "metadata": {}, "source": [ "#### 1. Data\n", "\n", "We load the single-cell reference and one representative spatial section. The spatial section is used only to obtain the MERFISH gene panel, which is later passed to `prepare()` through `marker_features` as prior knowledge. Because this panel was designed to include genes informative for mouse brain cell types, it provides a biologically grounded constraint on marker selection."
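, "\n", "\n", "This tutorial reads pre-built pseudo-spot sections from disk rather than regenerating them. To make the construction concrete, the next cell is a minimal sketch of grid-based pseudo-spot aggregation on random toy data. It is **not** the script used to build the released sections: the helper name `make_pseudo_spots`, the input layout (x/y coordinates, a counts matrix and per-cell labels), summing counts within each window and deriving proportions from cell labels are all illustrative assumptions." ] },
{ "cell_type": "code", "execution_count": null, "id": "a7c31e52", "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch only (assumption, not the code used for the released data):\n", "# bin single cells into square grid windows to form pseudo-spots.\n", "import numpy as np\n", "import pandas as pd\n", "\n", "\n", "def make_pseudo_spots(coords, counts, labels, window):\n", "    # coords: (n_cells, 2) x/y positions; counts: (n_cells, n_genes); labels: (n_cells,)\n", "    # integer grid index of every cell along x and y\n", "    ij = np.floor((coords - coords.min(axis=0)) / window).astype(int)\n", "    spot_id = np.array([f'{i}_{j}' for i, j in ij])\n", "    # pseudo-spot expression: sum of the counts of the cells in the window\n", "    expr = pd.DataFrame(counts).groupby(spot_id).sum()\n", "    # ground-truth composition: fraction of each cell type in the window\n", "    props = pd.crosstab(spot_id, labels, normalize='index')\n", "    return expr, props.loc[expr.index]\n", "\n", "\n", "# toy data: 200 random cells, 5 genes, 3 cell types\n", "rng = np.random.default_rng(0)\n", "coords = rng.uniform(0, 1, size=(200, 2))\n", "counts = rng.poisson(2, size=(200, 5))\n", "labels = rng.choice(['Glut', 'GABA', 'Astro'], size=200)\n", "\n", "expr, props = make_pseudo_spots(coords, counts, labels, window=0.25)\n", "print(expr.shape, props.shape)\n", "print(np.allclose(props.sum(axis=1), 1.0))" ] },
{ "cell_type": "markdown", "id": "b5d94f08", "metadata": {}, "source": [ "Each toy pseudo-spot gets one row in `expr` (summed counts) and one row in `props` (cell-type proportions summing to 1), analogous to the spot expression matrix and the `class_proportions` ground truth used for evaluation later. The data paths below contain `window=0.12`, which presumably refers to the grid window size used for the released sections; its units are not stated in this tutorial."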
] }, { "cell_type": "code", "execution_count": null, "id": "dd2700ac", "metadata": {}, "outputs": [], "source": [ "data_dir = '../data/mouse_isocortex'\n", "\n", "# single-cell reference data\n", "adx_sc_rna = mu.read_h5mu(join(data_dir, 'sc/isocortex.h5mu')).mod['rna']\n", "adx_sc_rna.var_names_make_unique()\n", "\n", "# spatial data: one section is loaded only to read the MERFISH gene panel\n", "_section = 'Zhuang-ABCA-2.016'\n", "adx_sp_rna = mu.read_h5mu(join(data_dir, f'sp/simulated/{_section}/Isocortex/window=0.12/sp_data.h5mu')).mod['rna']\n", "adx_sp_rna.var_names_make_unique()\n", "\n", "gene_panel = adx_sp_rna.var_names.to_list()\n", "len(gene_panel)  # number of genes in the panel" ] }, { "cell_type": "markdown", "id": "badd7b90", "metadata": {}, "source": [ "#### 2. Training\n", "\n", "The cell below shows the training configuration used for the experiments reported in our paper. It is commented out by default; run it only if you want to retrain the model from scratch. Training takes about 17 minutes on a single RTX 4090 GPU.\n", "\n", "A few settings are specific to this dataset:\n", "- `cell_sample_method=\"lognormal\"`: the number of cells per simulated spot (parameterised by `cells_mean` and `cells_std`) is drawn from a lognormal distribution, which better matches the long-tailed distribution of cell counts per spot\n", "- `context_key=None`: no context conditioning is used\n", "- `marker_features={\"rna\": gene_panel}` (passed to `prepare()`): input features are restricted to genes in the MERFISH panel\n", "- `cl_weight=0, mrp_weight=0`: the contrastive-learning and reconstruction losses are disabled, so among the loss weights set here only `cls_weight` is non-zero" ] }, { "cell_type": "code", "execution_count": 3, "id": "6e85b109", "metadata": {}, "outputs": [], "source": [ "# cfg = ExpConfig(\n", "#     modalities=[\n", "#         ModalityConfig(name=\"rna\", top_k=500)\n", "#     ],\n", "#     simulation=SimulationConfig(\n", "#         total_samples=1_000_000,\n", "#         batch_request_size=100_000,\n", "#         cells_mean=15,\n", "#         cells_std=10,\n", "#         cell_sample_method='lognormal',\n", "#         batch_key='batch',\n", "#     ),\n", "#     celltype_key='class',\n", "#     batch_key='batch',\n", "#     context_key=None,\n", "# 
    dim=128,\n", "#     depth=2,\n", "#     heads=4,\n", "#     cl_weight=0.,\n", "#     cl_temperature=0.1,\n", "#     mrp_weight=0.,\n", "#     cls_weight=1.0,\n", "#     epochs=4,\n", "#     num_workers=4,\n", "#     batch_size=128\n", "# )\n", "\n", "# model = SpaRank(cfg=cfg, save_dir='../outputs/isocortex')\n", "# model.register_modality(\"rna\", adx_sc_rna)\n", "# model.prepare(marker_features={'rna': gene_panel})  # use the spatial gene panel as a prior\n", "# model.fit()  # train\n", "# model.save()" ] }, { "cell_type": "markdown", "id": "5dafb4a9", "metadata": {}, "source": [ "#### 3. Load pre-trained checkpoint\n", "\n", "The checkpoint is loaded from `../outputs/isocortex`, the `save_dir` used in the training cell above, and placed on the CPU." ] }, { "cell_type": "code", "execution_count": 4, "id": "cc28528a", "metadata": {}, "outputs": [], "source": [ "model = SpaRank.load('../outputs/isocortex', device=\"cpu\")" ] }, { "cell_type": "markdown", "id": "1c53a359", "metadata": {}, "source": [ "#### 4. Quantitative evaluation\n", "\n", "We use two metrics:\n", "\n", "- **JSD**: the Jensen-Shannon distance returned by `scipy.spatial.distance.jensenshannon` with `base=2`, i.e. the square root of the Jensen-Shannon divergence, bounded in [0, 1]; lower is better\n", "- **PCC** (Pearson correlation coefficient): higher is better\n", "\n", "Both metrics are computed in two orientations and then averaged:\n", "- **row (spots)**: one score per spot, measuring how well that spot's cell-type composition is predicted\n", "- **col (cell types)**: one score per cell type, measuring how well its spatial pattern across spots is recovered\n", "\n", "Scores that are undefined (an all-zero vector for JSD, or a constant vector for PCC) are left as NaN and skipped when averaging." ] }, { "cell_type": "code", "execution_count": 6, "id": "ce0a73e4", "metadata": {}, "outputs": [], "source": [ "from typing import Union\n", "\n", "import numpy as np\n", "import pandas as pd\n", "from scipy.spatial.distance import jensenshannon\n", "from scipy.stats import pearsonr\n", "\n", "\n", "def _to_array(x):\n", "    return x.values if isinstance(x, pd.DataFrame) else np.asarray(x)\n", "\n", "\n", "def jsd(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> float:\n", "    # mean Jensen-Shannon distance over columns, ignoring NaN (undefined) columns\n", "    jsd_vals = jsd_per_col(true, predicted)\n", "    return float(np.nanmean(jsd_vals))\n", "\n", "\n", "def jsd_per_col(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:\n", "    t = _to_array(true)\n", "    p = _to_array(predicted)\n", "\n", "    n_cols = t.shape[1]\n", "    vals = np.full(n_cols, np.nan, dtype=float)\n", "    for i in range(n_cols):\n", "        a = t[:, i].astype(float)\n", "        b = p[:, i].astype(float)\n", "        sa = a.sum()\n", "        sb = b.sum()\n", "        if sa == 0 or sb == 0:\n", "            # an all-zero column cannot be normalised: leave it as NaN\n", "            continue\n", "        pa = a / sa\n", "        pb = b / sb\n", "        # scipy returns the JS distance (square root of the JS divergence)\n", "        vals[i] = float(jensenshannon(pa, pb, axis=0, base=2))\n", "    return vals\n", "\n", "\n", "def pcc(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> float:\n", "    # mean Pearson correlation over columns, ignoring NaN (undefined) columns\n", "    r_vals = pcc_per_col(true, predicted)\n", "    return float(np.nanmean(r_vals))\n", "\n", "\n", "def pcc_per_col(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:\n", "    t = _to_array(true)\n", "    p = _to_array(predicted)\n", "\n", "    n_cols = t.shape[1]\n", "    vals = np.full(n_cols, np.nan, dtype=float)\n", "    for i in range(n_cols):\n", "        x = t[:, i]\n", "        y = p[:, i]\n", "        if np.std(x) == 0 or np.std(y) == 0:\n", "            # correlation is undefined for a constant column: leave it as NaN\n", "            continue\n", "        vals[i], _ = pearsonr(x, y)\n", "    return vals\n", "\n", "\n", "def eval_pred(targets, pred):\n", "    # row orientation: transpose so that each column is a spot\n", "    jsd_mean_row = jsd(targets.T, pred.T)\n", "    pcc_mean_row = pcc(targets.T, pred.T)\n", "    # col orientation: each column is a cell type\n", "    jsd_mean_col = jsd(targets, pred)\n", "    pcc_mean_col = pcc(targets, pred)\n", "\n", "    d = {\n", "        'jsd-mean-row': jsd_mean_row,\n", "        'pcc-mean-row': pcc_mean_row,\n", "        'jsd-mean-col': jsd_mean_col,\n", "        'pcc-mean-col': pcc_mean_col,\n", "    }\n", "    return d" ] }, { "cell_type": "markdown", "id": "06fbea19", "metadata": {}, "source": [ "##### Deploy to all 27 sections\n", "\n", "The trained model is applied to each of the 27 spatial sections in turn. For each section, the predicted proportions are reordered to match the ground-truth cell-type columns (cell types absent from the prediction are filled with 0), and both tables are re-normalised so that each spot sums to 1 before the metrics are computed."
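, "\n", "\n", "The next cell walks through the alignment and re-normalisation steps on a toy example before the full loop is run. The cell-type names and values are made up for illustration; they are not SpaRank predictions." ] },
{ "cell_type": "code", "execution_count": null, "id": "c8e02b71", "metadata": {}, "outputs": [], "source": [ "# Toy example with made-up values: align predicted columns to the ground truth,\n", "# then re-normalise every spot (row) to sum to 1.\n", "import numpy as np\n", "import pandas as pd\n", "\n", "trues = pd.DataFrame({'Glut': [0.6, 0.2], 'GABA': [0.3, 0.5], 'Astro': [0.1, 0.3]})\n", "# hypothetical prediction: different column order and no 'Astro' column\n", "preds = pd.DataFrame({'GABA': [0.25, 0.6], 'Glut': [0.5, 0.3]})\n", "\n", "preds = preds.reindex(columns=trues.columns, fill_value=0)  # same order as trues; 'Astro' filled with 0\n", "trues = trues.div(trues.sum(axis=1), axis=0)\n", "preds = preds.div(preds.sum(axis=1), axis=0)  # each row now sums to 1\n", "\n", "print(list(preds.columns))\n", "print(preds.round(3))\n", "print(np.allclose(trues.sum(axis=1), 1.0), np.allclose(preds.sum(axis=1), 1.0))" ] },
{ "cell_type": "markdown", "id": "d1f6a93c", "metadata": {}, "source": [ "Zero-filling makes the two tables column-compatible. Note that with the metric functions defined above, a cell type whose predicted proportion is 0 in every spot yields NaN in the column orientation (it cannot be normalised for JSD and has zero variance for PCC) and is therefore skipped by `np.nanmean`, although it still affects the per-spot (row) scores."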
] }, { "cell_type": "code", "execution_count": null, "id": "9cb3b499", "metadata": { "scrolled": true }, "outputs": [], "source": [ "rs = []\n", "for target_section in os.listdir(join(data_dir, 'sp/simulated')):\n", "    adx_sp_rna = mu.read_h5mu(join(data_dir, f'sp/simulated/{target_section}/Isocortex/window=0.12/sp_data.h5mu')).mod['rna']\n", "    preds = model.predict(mod_adatas={'rna': adx_sp_rna})\n", "\n", "    # ground-truth cell-type proportions of each pseudo-spot\n", "    trues = pd.DataFrame(adx_sp_rna.obsm['class_proportions'],\n", "                         columns=adx_sp_rna.uns['class_proportion_names']).copy()\n", "    # match the ground-truth column order; cell types missing from preds get 0\n", "    preds = preds.reindex(columns=trues.columns, fill_value=0)\n", "\n", "    # re-normalise so that every spot sums to 1\n", "    trues = trues.div(trues.sum(axis=1), axis=0)\n", "    preds = preds.div(preds.sum(axis=1), axis=0)\n", "\n", "    _r = eval_pred(trues, preds)\n", "    _r['dataset'] = target_section\n", "    rs.append(_r)" ] },
{ "cell_type": "code", "execution_count": 8, "id": "30e9f5ed-3621-4331-bf4a-e823b5cb279f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", "    .dataframe tbody tr th:only-of-type {\n", "        vertical-align: middle;\n", "    }\n", "\n", "    .dataframe tbody tr th {\n", "        vertical-align: top;\n", "    }\n", "\n", "    .dataframe thead th {\n", "        text-align: right;\n", "    }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>jsd-mean-row</th>\n", "      <th>pcc-mean-row</th>\n", "      <th>jsd-mean-col</th>\n", "      <th>pcc-mean-col</th>\n", "      <th>dataset</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>0.234484</td>\n", "      <td>0.881491</td>\n", "      <td>0.371106</td>\n", "      <td>0.768223</td>\n", "      <td>Zhuang-ABCA-2.018</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>0.232949</td>\n", "      <td>0.867011</td>\n", "      <td>0.326393</td>\n", "      <td>0.818452</td>\n", "      <td>Zhuang-ABCA-2.037</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>0.239517</td>\n", "      <td>0.876566</td>\n", "      <td>0.354756</td>\n", "      <td>0.828769</td>\n", "      <td>Zhuang-ABCA-2.050</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>0.239212</td>\n", "      <td>0.877254</td>\n", "      <td>0.299845</td>\n", "      <td>0.798714</td>\n", "      <td>Zhuang-ABCA-2.045</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>0.248456</td>\n", "      <td>0.891715</td>\n", "      <td>0.374022</td>\n", "      <td>0.731692</td>\n", "      <td>Zhuang-ABCA-2.047</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ "   jsd-mean-row  pcc-mean-row  jsd-mean-col  pcc-mean-col            dataset\n", "0      0.234484      0.881491      0.371106      0.768223  Zhuang-ABCA-2.018\n", "1      0.232949      0.867011      0.326393      0.818452  Zhuang-ABCA-2.037\n", "2      0.239517      0.876566      0.354756      0.828769  Zhuang-ABCA-2.050\n", "3      0.239212      0.877254      0.299845      0.798714  Zhuang-ABCA-2.045\n", "4      0.248456      0.891715      0.374022      0.731692  Zhuang-ABCA-2.047" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(rs).head()" ] },
{ "cell_type": "code", "execution_count": null, "id": "d4b9ae0e-c615-4232-9c06-6973f1edc01d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python (scvi)", "language": "python", "name": "scvi-env" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.11" } }, "nbformat": 4, "nbformat_minor": 5 }