{ "cells": [ { "cell_type": "markdown", "id": "71b1d70e", "metadata": {}, "source": [ "### Human Thymus Multimodal Deconvolution\n", "\n", "This tutorial demonstrates multimodal deconvolution on simulated human thymus spatial data derived from a single-cell multiomic (CITE-seq) dataset with paired RNA and protein (ADT) measurements.\n", "\n", "The single-cell multiomic dataset contains 11 batches, which are paired to form 11 batch-pair datasets. Within each pair, one batch is used as the reference and the other is used to simulate spatial data. The simulation scheme follows the idea described in *[Spatial transcriptomics deconvolution methods generalize well to spatial chromatin accessibility data](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i314/8199385)*. \n", "\n", "Datasets are available on [Zenodo](https://zenodo.org/records/19691472).\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0ad9a20d", "metadata": {}, "outputs": [], "source": [ "import muon as mu\n", "import os\n", "from os.path import join\n", "import scanpy as sc\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import glob\n", "\n", "import sparank\n", "from sparank.config import ExpConfig, ModalityConfig, SimulationConfig\n", "from sparank.framework import SpaRank" ] }, { "cell_type": "code", "execution_count": 2, "id": "e6416d14", "metadata": {}, "outputs": [], "source": [ "data_dir = '../data/human_thymus/simulated'\n", "batch_pairs = pd.read_csv(f'{data_dir}/batch_pairs.csv')" ] }, { "cell_type": "markdown", "id": "96527998", "metadata": {}, "source": [ "#### Training\n", "\n", "The cell below shows the training configuration used in our experiments. 
It is commented out by default; uncomment and run it only if you want to retrain the models from scratch.\n", "\n", "A few settings are worth noting:\n", "\n", "- `modalities=[ModalityConfig(name=\"rna\", ...), ModalityConfig(name=\"adt\", ...)]`: separate configurations for the RNA and protein (ADT) modalities.\n", "- `context_key=None`: no context conditioning is used." ] }, { "cell_type": "code", "execution_count": 3, "id": "35a3be6c", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# for (spot_batch, cell_batch) in zip(batch_pairs.spot, batch_pairs.cell):\n", "#     pair_id = f'{spot_batch}-{cell_batch}'\n", "\n", "#     mdata_ref = mu.read_h5mu(\n", "#         glob.glob(join(data_dir, f'{cell_batch}-*', \"ref_data.h5mu\"))[0]\n", "#     )\n", "#     adx_sc_rna = mdata_ref.mod['rna']\n", "#     adx_sc_adt = mdata_ref.mod['adt']\n", "\n", "#     adx_sc_rna.var_names_make_unique()\n", "#     adx_sc_adt.var_names_make_unique()\n", "\n", "#     cfg = ExpConfig(\n", "#         modalities=[\n", "#             ModalityConfig(name=\"rna\", top_k=500, cl_dropout_rate=0.3, mrp_mask_rate=0.3),\n", "#             ModalityConfig(name=\"adt\", top_k=500, cl_dropout_rate=0.3, mrp_mask_rate=0.3)  # top_k will subset to the actual protein panel\n", "#         ],\n", "#         simulation=SimulationConfig(\n", "#             total_samples=500_000,\n", "#             batch_request_size=100_000,\n", "#             batch_key='sample',\n", "#         ),\n", "#         celltype_key='annotation',\n", "#         batch_key='sample',\n", "#         context_key=None,\n", "#         dim=128,\n", "#         depth=2,\n", "#         heads=4,\n", "#         cl_weight=0.5,\n", "#         cl_temperature=0.1,\n", "#         mrp_weight=0.5,\n", "#         cls_weight=1.0,\n", "#         epochs=20,\n", "#         num_workers=0\n", "#     )\n", "\n", "#     model = SpaRank(cfg=cfg, save_dir=f'../outputs/human_thymus/{pair_id}')\n", "#     model.register_modality(\"rna\", adx_sc_rna)\n", "#     model.register_modality(\"adt\", adx_sc_adt)\n", "#     model.prepare()  # use sp gene panel as a prior\n", "#     model.fit()  # train\n", "#     model.save()\n", "\n", "#     break  # only the first batch pair is trained; remove to train all pairs" ] }, { "cell_type": "markdown", "id": "ccfe317d", "metadata": {}, "source": 
[ "#### Evaluation\n", "\n", "After training, we load the checkpoint saved at each epoch and apply it to the corresponding simulated spatial section. The predicted cell-type proportions are then compared against the ground-truth proportions recorded by the simulation.\n", "\n", "We report two metrics: the Jensen–Shannon distance (JSD, computed with `scipy.spatial.distance.jensenshannon` in base 2, i.e. the square root of the Jensen–Shannon divergence; lower is better) and the Pearson correlation coefficient (PCC; higher is better). Each metric is averaged per spot (`*-row`, comparing each spot's cell-type composition) and per cell type (`*-col`, comparing each cell type's distribution across spots). We also track the modality gate scores to examine how the model balances RNA and ADT signals." ] }, { "cell_type": "code", "execution_count": 4, "id": "a5d22529", "metadata": {}, "outputs": [], "source": [ "from typing import Union\n", "\n", "import numpy as np\n", "import pandas as pd\n", "from scipy.spatial.distance import jensenshannon\n", "from scipy.stats import pearsonr\n", "\n", "def _to_array(x):\n", "    return x.values if isinstance(x, pd.DataFrame) else np.asarray(x)\n", "\n", "def jsd(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> float:\n", "    # mean JSD over columns; columns where either input sums to zero are NaN and ignored\n", "    jsd_vals = jsd_per_col(true, predicted)\n", "    return float(np.nanmean(jsd_vals))\n", "\n", "def jsd_per_col(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:\n", "    t = _to_array(true)\n", "    p = _to_array(predicted)\n", "\n", "    n_cols = t.shape[1]\n", "    vals = np.full(n_cols, np.nan, dtype=float)\n", "    for i in range(n_cols):\n", "        a = t[:, i].astype(float)\n", "        b = p[:, i].astype(float)\n", "        sa = a.sum()\n", "        sb = b.sum()\n", "        if sa == 0 or sb == 0:\n", "            vals[i] = np.nan\n", "            continue\n", "        # normalize each column to a probability distribution before comparing\n", "        pa = a / sa\n", "        pb = b / sb\n", "        vals[i] = float(jensenshannon(pa, pb, axis=0, base=2))\n", "    return vals\n", "\n", "\n", "def pcc(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> float:\n", "    r_vals = pcc_per_col(true, predicted)\n", "    return float(np.nanmean(r_vals))\n", "\n", "\n", "def pcc_per_col(true: Union[pd.DataFrame, np.ndarray], predicted: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:\n", "    t = _to_array(true)\n", "    p = _to_array(predicted)\n", "\n", "    n_cols = t.shape[1]\n", "    vals = np.full(n_cols, np.nan, dtype=float)\n", "    for i in range(n_cols):\n", "        x = t[:, i]\n", "        y = p[:, i]\n", "        # Pearson r is undefined for constant vectors\n", "        if np.std(x) == 0 or np.std(y) == 0:\n", "            vals[i] = np.nan\n", "            continue\n", "        vals[i], _ = pearsonr(x, y)\n", "    return vals\n", "\n", "\n", "def eval_pred(targets, pred):\n", "    \"\"\"Mean JSD and PCC computed per spot (rows) and per cell type (columns).\"\"\"\n", "    jsd_mean_row = jsd(targets.T, pred.T)  # per spot (rows)\n", "    pcc_mean_row = pcc(targets.T, pred.T)\n", "    jsd_mean_col = jsd(targets, pred)  # per cell type (columns)\n", "    pcc_mean_col = pcc(targets, pred)\n", "\n", "    d = {\n", "        'jsd-mean-row': jsd_mean_row,\n", "        'pcc-mean-row': pcc_mean_row,\n", "        'jsd-mean-col': jsd_mean_col,\n", "        'pcc-mean-col': pcc_mean_col,\n", "    }\n", "    return d" ] }, { "cell_type": "markdown", "id": "47ab8d0a", "metadata": {}, "source": [ "##### Metrics" ] }, { "cell_type": "code", "execution_count": 5, "id": "72e409c5", "metadata": { "scrolled": true }, "outputs": [], "source": [ "import torch\n", "\n", "rs = []\n", "gate_scores = {}\n", "for (spot_batch, cell_batch) in zip(batch_pairs.spot, batch_pairs.cell):\n", "    pair_id = f'{spot_batch}-{cell_batch}'\n", "\n", "    model = SpaRank.load(f'../outputs/human_thymus/{pair_id}', device=\"cuda:0\")\n", "    mdata_sp = mu.read_h5mu(\n", "        glob.glob(join(data_dir, f'{spot_batch}-*', \"sp_data.h5mu\"))[0]\n", "    )\n", "\n", "    adx_sp_rna = mdata_sp.mod['rna']\n", "    adx_sp_adt = mdata_sp.mod['adt']\n", "\n", "    adx_sp_rna.var_names_make_unique()\n", "    adx_sp_adt.var_names_make_unique()\n", "\n", "    gate_scores[pair_id] = []\n", "    for ep in np.arange(model.cfg.epochs):\n", "        # load the checkpoint saved for epoch ep + 1\n", "        ckpt = join(model.save_dir, f\"model_epoch{ep+1}.pth\")\n", "        model.model.load_state_dict(torch.load(ckpt, map_location=model.device))\n", "\n", "        preds, gates = model.predict(mod_adatas={'rna':adx_sp_rna, 'adt':adx_sp_adt}, return_gate_scores=True)\n", "        df_true = pd.DataFrame(adx_sp_rna.obsm['proportions'],\n", "                               columns=adx_sp_rna.uns['proportion_names']).copy()\n", "        df_pred = preds.reindex(columns=df_true.columns, fill_value=0)\n", "\n", "        # renormalize both tables so that each spot's proportions sum to 1\n", "        df_true = df_true.div(df_true.sum(axis=1), axis=0)\n", "        df_pred = df_pred.div(df_pred.sum(axis=1), axis=0)\n", "        rna_r = eval_pred(df_true, df_pred)\n", "        rna_r['dataset'] = pair_id\n", "        rna_r['ep'] = ep + 1\n", "\n", "        # print(ep, rna_r)\n", "        rs.append(rna_r)\n", "        gate_scores[pair_id].append(gates.mean(axis=0)[0])\n", "\n", "    break  # only the first batch pair is evaluated; remove to evaluate all pairs" ] }, { "cell_type": "code", "execution_count": 6, "id": "02e4c58f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | jsd-mean-row | pcc-mean-row | jsd-mean-col | pcc-mean-col | dataset | ep |
|---|---|---|---|---|---|---|
| 0 | 0.620966 | 0.422434 | 0.547996 | 0.498087 | TT-CITE-1-TT-CITE-5 | 1 |
| 1 | 0.585704 | 0.506616 | 0.538801 | 0.522315 | TT-CITE-1-TT-CITE-5 | 2 |
| 2 | 0.542583 | 0.595207 | 0.524391 | 0.572617 | TT-CITE-1-TT-CITE-5 | 3 |
| 3 | 0.530953 | 0.616906 | 0.513516 | 0.593068 | TT-CITE-1-TT-CITE-5 | 4 |
| 4 | 0.521424 | 0.630698 | 0.506310 | 0.619039 | TT-CITE-1-TT-CITE-5 | 5 |
| 5 | 0.508718 | 0.642302 | 0.501077 | 0.636270 | TT-CITE-1-TT-CITE-5 | 6 |
| 6 | 0.507621 | 0.642506 | 0.494738 | 0.651287 | TT-CITE-1-TT-CITE-5 | 7 |
| 7 | 0.521434 | 0.614370 | 0.496680 | 0.651511 | TT-CITE-1-TT-CITE-5 | 8 |
| 8 | 0.511970 | 0.618444 | 0.488771 | 0.661643 | TT-CITE-1-TT-CITE-5 | 9 |
| 9 | 0.495957 | 0.649375 | 0.484811 | 0.663655 | TT-CITE-1-TT-CITE-5 | 10 |
| 10 | 0.505268 | 0.631040 | 0.484866 | 0.667103 | TT-CITE-1-TT-CITE-5 | 11 |
| 11 | 0.497875 | 0.639952 | 0.484147 | 0.666498 | TT-CITE-1-TT-CITE-5 | 12 |
| 12 | 0.515348 | 0.606359 | 0.494060 | 0.645583 | TT-CITE-1-TT-CITE-5 | 13 |
| 13 | 0.507942 | 0.630888 | 0.487754 | 0.666721 | TT-CITE-1-TT-CITE-5 | 14 |
| 14 | 0.499108 | 0.633821 | 0.484229 | 0.662645 | TT-CITE-1-TT-CITE-5 | 15 |
| 15 | 0.504156 | 0.623252 | 0.484261 | 0.663488 | TT-CITE-1-TT-CITE-5 | 16 |
| 16 | 0.498450 | 0.632995 | 0.485791 | 0.663203 | TT-CITE-1-TT-CITE-5 | 17 |
| 17 | 0.497066 | 0.632111 | 0.484906 | 0.657270 | TT-CITE-1-TT-CITE-5 | 18 |
| 18 | 0.509916 | 0.616092 | 0.486829 | 0.661512 | TT-CITE-1-TT-CITE-5 | 19 |
| 19 | 0.497279 | 0.638984 | 0.485762 | 0.666644 | TT-CITE-1-TT-CITE-5 | 20 |