Paired integration and query-to-reference mapping
NetworkVI is a variational autoencoder (VAE) for the paired and mosaic integration of multimodal single-cell data. Each modality is learned by its own encoder, and the modality-specific latent spaces are then aligned into a joint representation. In this notebook we demonstrate how to use NetworkVI for paired integration and for query-to-reference mapping, using a publicly available CITE-seq dataset created by [Luecken et al., 2021].
If you use NetworkVI, please consider citing:
Arnoldt, L., Upmeier zu Belzen, J., Herrmann, L., Nguyen, K., Theis, F.J., Wild, B., Eils, R., “Biologically Guided Variational Inference for Interpretable Multimodal Single-Cell Integration”, bioRxiv, June 2025.
import os
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip install networkvi
import numpy as np
import requests
# Allows importing NetworkVI from a local source checkout of the repository
sys.path.append("../../../src")
import networkvi
from networkvi.model import NETWORKVI
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
import sklearn
from sklearn.preprocessing import normalize
import seaborn as sns
import pandas as pd
%%capture
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
resources_dir = "../../../resources/"
obo_path = os.path.join(resources_dir, "go-basic.obo")
if not os.path.isfile(obo_path):
    # Download the basic version of the Gene Ontology
    os.makedirs(resources_dir, exist_ok=True)
    r = requests.get("http://purl.obolibrary.org/obo/go/go-basic.obo", allow_redirects=True)
    r.raise_for_status()
    with open(obo_path, "wb") as f:
        f.write(r.content)
if not os.path.isfile("../../../gene_interactions.csv"):
    import gdown
    gdown.download("https://drive.google.com/uc?export=download&id=1MwAuqw2JVl6L9xfsWGfUH7rB02LsQ_Vv", "../../../gene_interactions.csv")
Data loading
We provide a subsetted and preprocessed version of the CITE-seq dataset for this tutorial: we filtered low-quality cells and lowly expressed features, subsampled the data to 5000 cells, and selected 4000 highly variable genes. NetworkVI accepts raw counts for all modalities. For details on the preprocessing, please refer to the manuscript and to the code in the NetworkVI reproducibility repository, which also provides scripts for downloading and processing the datasets used in the manuscript. The original CITE-seq dataset contains 90,261 bone marrow mononuclear cells profiled at four different sites, with 13,953 genes and 143 proteins.
rna_cite_path = "neurips2021_cite_bmmc_luecken2021.h5ad"
try:
    rna_cite = sc.read_h5ad(rna_cite_path)
except OSError:
    # Download the preprocessed dataset if it is not available locally
    import gdown
    gdown.download("https://drive.google.com/uc?export=download&id=1A9o8wZgWS6udFMXJ-ybXre1dCcCdsKt5", rna_cite_path)
    rna_cite = sc.read_h5ad(rna_cite_path)
rna_cite
AnnData object with n_obs × n_vars = 5000 × 4000
obs: 'cell_type', 'Site', 'DonorID', 'DonorSmoker', 'DonorBMI', 'DonorAge'
var: 'gene_stable_id', 'modality'
uns: 'protein_expression'
obsm: 'protein_expression'
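The preprocessing described above was done with the authors' own pipeline (see the reproducibility repository). Purely as an illustration, the numpy sketch below walks through the same four steps on a raw count matrix. The function name `preprocess_counts`, its thresholds, and the variance-based gene ranking (a simple stand-in for a dedicated highly-variable-gene method such as scanpy's) are our assumptions, not the authors' settings. The point it shows is that normalization is only used to rank genes, while the returned selection still indexes raw counts, which is what NetworkVI expects.

```python
import numpy as np


def preprocess_counts(counts, min_genes=200, min_cells=3, n_cells=5000, n_hvg=4000, seed=0):
    """Return (cell_idx, gene_idx) selecting QC-passing cells and top-variance genes.

    counts: (cells x genes) array of raw counts. Thresholds are hypothetical defaults.
    Indexing counts[np.ix_(cell_idx, gene_idx)] keeps the data as raw counts.
    """
    rng = np.random.default_rng(seed)
    # 1) Filter low-quality cells: too few detected genes per cell.
    cell_idx = np.flatnonzero((counts > 0).sum(axis=1) >= min_genes)
    # 2) Filter lowly expressed features: detected in too few of the remaining cells.
    gene_idx = np.flatnonzero((counts[cell_idx] > 0).sum(axis=0) >= min_cells)
    # 3) Subsample cells without replacement.
    if cell_idx.size > n_cells:
        cell_idx = np.sort(rng.choice(cell_idx, size=n_cells, replace=False))
    sub = counts[np.ix_(cell_idx, gene_idx)].astype(float)
    # 4) Rank genes by the variance of library-size-normalized log expression;
    #    the normalized values are only used for ranking, not returned.
    lib = np.maximum(sub.sum(axis=1, keepdims=True), 1.0)
    lognorm = np.log1p(sub / lib * 1e4)
    top = np.argsort(lognorm.var(axis=0))[::-1][:n_hvg]
    return cell_idx, np.sort(gene_idx[top])


# Toy usage on simulated counts (300 cells x 500 genes)
rng = np.random.default_rng(1)
counts = rng.poisson(0.5, size=(300, 500))
cells, genes = preprocess_counts(counts, min_genes=50, min_cells=3, n_cells=100, n_hvg=50)
X = counts[np.ix_(cells, genes)]
print(X.shape)
```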
Data setup
To perform query-to-reference mapping, we hold out the cells profiled at one of the four sites (site1) as the query dataset and train NetworkVI on the cells from the remaining sites.
adata = rna_cite
query = adata[adata.obs["Site"] == "site1"].copy()
adata = adata[adata.obs["Site"] != "site1"].copy()
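Because label transfer later relies on a classifier trained on the reference, it can be worth checking the split before training. The sketch below uses a hypothetical helper, `check_query_split`, that verifies reference and query share no cells and lists query cell types absent from the reference; a reference-trained classifier can only output reference labels, so such cells cannot be labeled correctly.

```python
def check_query_split(ref_ids, query_ids, ref_labels, query_labels):
    """Sanity checks for a site-based reference/query split (hypothetical helper).

    Raises if any cell ID appears in both sets, and returns the sorted query
    cell types that never occur in the reference.
    """
    overlap = set(ref_ids) & set(query_ids)
    if overlap:
        raise ValueError(f"{len(overlap)} cells appear in both reference and query")
    return sorted(set(query_labels) - set(ref_labels))


# Toy usage; in the notebook one would pass adata.obs_names, query.obs_names,
# adata.obs["cell_type"] and query.obs["cell_type"].
print(check_query_split(["c1", "c2"], ["c3"], ["T", "B"], ["B", "NK"]))
```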
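We register the fields and set up the AnnData object with a convenience function, as shown below. For mosaic data, NetworkVI, like MultiVI, expects the modality to be registered as the batch_key; further batch information can be registered via categorical_covariate_keys. Since all cells in this example are fully paired, we register the site as the batch_key.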
NETWORKVI.setup_anndata(
    adata,
    batch_key="Site",
    protein_expression_obsm_key="protein_expression",
)
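Registering a categorical batch_key means the model receives the site as a covariate. The numpy sketch below shows the common way such a covariate is represented: integer codes turned into a one-hot vector that is concatenated to a layer's input. The helper names are hypothetical, and we assume, without having checked NetworkVI's source, that it follows this scvi-tools-style convention, which the encode_covariates and deeply_inject_covariates options used in the model setup suggest. The sketch also makes visible why the held-out site1 needs special handling: it is not among the reference categories.

```python
import numpy as np


def encode_batch(batch_labels):
    """Map categorical batch labels to sorted categories, integer codes and a one-hot matrix."""
    categories, codes = np.unique(np.asarray(batch_labels), return_inverse=True)
    codes = codes.ravel()
    one_hot = np.eye(len(categories))[codes]
    return categories, codes, one_hot


def inject_covariates(hidden, one_hot):
    """Concatenate the one-hot covariate to a layer input (covariate injection)."""
    return np.concatenate([hidden, one_hot], axis=1)


sites = ["site2", "site3", "site2", "site4"]  # reference sites only; site1 is held out
categories, codes, one_hot = encode_batch(sites)
print(categories.tolist(), codes.tolist())
h = np.zeros((len(sites), 8))  # stand-in for a hidden-layer activation
print(inject_covariates(h, one_hot).shape)
```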
Model setup and training
Next, we set up and train NetworkVI using the Gene Ontology (GO) as the encoder structure. Below we define a convenience function that stores the joint latent representation and computes a UMAP embedding of it for visualization. Like MultiVI, NetworkVI requires the number of features per modality (n_genes, n_proteins, n_regions).
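def get_plot_latent_representation(vae, adata):
    """Store the joint latent space in adata.obsm["X_NetworkVI"] and compute a 2D UMAP of it."""
    latent_representation = vae.get_latent_representation(modality="joint")
    adata.obsm["X_NetworkVI"] = latent_representation
    # Build the kNN graph on the latent space and embed it with UMAP for plotting
    sc.pp.neighbors(adata, use_rep="X_NetworkVI")
    sc.tl.umap(adata, n_components=2)
    return adata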
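Using the Gene Ontology as the encoder structure means that hidden units correspond to GO terms and receive input only from the features annotated to them. A minimal numpy sketch of such a sparsely connected layer, implemented as a dense weight matrix multiplied by a fixed binary mask, is shown below. It illustrates the general idea only: the helper names and toy identifiers are hypothetical, and NetworkVI's actual GO layers are more involved (its training log, for example, reports attention at several GO levels).

```python
import numpy as np


def go_mask(gene_ids, term_ids, annotations):
    """Binary (genes x terms) mask: 1 where a gene is annotated to a GO term."""
    gene_pos = {g: i for i, g in enumerate(gene_ids)}
    term_pos = {t: j for j, t in enumerate(term_ids)}
    mask = np.zeros((len(gene_ids), len(term_ids)))
    for gene, terms in annotations.items():
        for term in terms:
            if gene in gene_pos and term in term_pos:
                mask[gene_pos[gene], term_pos[term]] = 1.0
    return mask


def masked_linear(x, weight, mask, bias):
    """Linear map in which each term unit only receives input from its annotated genes."""
    return x @ (weight * mask) + bias


genes = ["GENE_A", "GENE_B", "GENE_C"]  # hypothetical gene IDs
terms = ["TERM_1", "TERM_2"]            # hypothetical GO terms
annotations = {"GENE_A": {"TERM_1"}, "GENE_B": {"TERM_1", "TERM_2"}}
mask = go_mask(genes, terms, annotations)
out = masked_linear(np.ones((1, 3)), np.ones((3, 2)), mask, np.zeros(2))
print(out)  # GENE_C is unannotated and contributes to no term
```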
NetworkVI
Mapping the features to the gene layer and to the Gene Ontology requires Ensembl IDs, which we provide via ensembl_ids_genes and ensembl_ids_proteins. We also provide the Gene Ontology and a gene-to-GO mapping file via obo_file and map_ensembl_go. Please do not use the provided mapping file in production, since it is an artificially subsampled version of a real mapping file. Please refer to the scripts in https://github.com/Larnoldt/networkvi_reproducibility or to https://geneontology.org/docs/go-annotation-file-gaf-format-2.1/ to obtain full mapping files.
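The gene-to-GO mapping file has a .gaf extension, so it presumably follows the GAF format linked above. To illustrate what such a file encodes, here is a minimal pure-Python sketch that reads GAF 2.1 lines into a {gene ID: set of GO IDs} dictionary for one GO aspect, skipping comment lines and NOT-qualified annotations. It assumes, without our having verified it for ensembl2go.gaf, that the gene identifiers sit in the DB Object ID column (column 2). Restricting to biological process (aspect P) is our choice, motivated only by the `_biological_process_` file name in the training log. `read_gaf`, `gaf_row` and all IDs are hypothetical placeholders.

```python
import io
from collections import defaultdict


def read_gaf(handle, aspect="P"):
    """Return {DB Object ID: set of GO IDs} for one GO aspect from a GAF 2.x stream."""
    mapping = defaultdict(set)
    for line in handle:
        if line.startswith("!") or not line.strip():
            continue  # header/comment or blank line
        cols = line.rstrip("\n").split("\t")
        if len(cols) < 15:
            continue  # skip malformed lines
        object_id, qualifier, go_id, go_aspect = cols[1], cols[3], cols[4], cols[8]
        if go_aspect != aspect or "NOT" in qualifier.split("|"):
            continue
        mapping[object_id].add(go_id)
    return dict(mapping)


def gaf_row(obj_id, qualifier, go_id, aspect):
    """Build one 17-column GAF 2.1 line with placeholder values in unused columns."""
    cols = ["DB", obj_id, "SYM", qualifier, go_id, "REF", "IEA", "", aspect,
            "", "", "gene", "taxon:9606", "20250101", "DB", "", ""]
    return "\t".join(cols)


gaf = "\n".join([
    "!gaf-version: 2.1",
    gaf_row("GENE_A", "involved_in", "GO:TERM_P", "P"),
    gaf_row("GENE_A", "enables", "GO:TERM_F", "F"),      # other aspect: skipped
    gaf_row("GENE_B", "NOT|involved_in", "GO:TERM_P", "P"),  # negated: skipped
]) + "\n"
print(read_gaf(io.StringIO(gaf)))
```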
vae = NETWORKVI(
    adata,
    n_genes=len(adata.var[adata.var["modality"] == "Gene Expression"]),
    n_proteins=adata.obsm["protein_expression"].shape[1],
    ensembl_ids_genes=np.array(adata.var[adata.var["modality"] == "Gene Expression"]["gene_stable_id"]),
    ensembl_ids_proteins=np.array(adata.uns["protein_expression"]["var"]["gene_stable_id"]),
    gene_layer_interaction_source="../../../gene_interactions.csv",
    expression_gene_layer_type="interaction",
    protein_gene_layer_type="interaction",
    obo_file="../../../resources/go-basic.obo",
    map_ensembl_go=["../../../resources/ensembl2go.gaf"],
    layers_encoder_type="go",
    encode_covariates=True,
    deeply_inject_covariates=True,
    fully_paired=True,
    standard_gene_size=5,
    standard_go_size=2,
)
vae.train(max_epochs=10, adversarial_mixing=False, save_best=False)
Generated block structure with len 4000 (0.00025 blocks).
Generated block structure with len 7342 (0.000458875 blocks).
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
_biological_process_go-basic_ensembl2go.v2_processed.pkl
Using attention in GO model at level -1.
Generated block structure with len 4565 (0.07608333333333334 blocks).
Generated block structure with len 4565 (0.07608333333333334 blocks).
Using attention in GO model at level 1.
Generated block structure with len 18 (0.13333333333333333 blocks).
Generated block structure with len 18 (0.13333333333333333 blocks).
Generated block structure with len 297 (0.00825 blocks).
Generated block structure with len 297 (0.00825 blocks).
Using attention in GO model at level 0.
Generated block structure with len 1 (0.06666666666666667 blocks).
Generated block structure with len 1 (0.06666666666666667 blocks).
Generated block structure with len 8 (0.8888888888888888 blocks).
Generated block structure with len 8 (0.8888888888888888 blocks).
Generated block structure with len 134 (0.0078125 blocks).
Generated block structure with len 154 (0.0093994140625 blocks).
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
Using attention in GO model at level -1.
Generated block structure with len 182 (0.109375 blocks).
Generated block structure with len 182 (0.109375 blocks).
Using attention in GO model at level 1.
Generated block structure with len 16 (0.15384615384615385 blocks).
Generated block structure with len 16 (0.15384615384615385 blocks).
Generated block structure with len 27 (0.0263671875 blocks).
Generated block structure with len 27 (0.0263671875 blocks).
Using attention in GO model at level 0.
Generated block structure with len 1 (0.07692307692307693 blocks).
Generated block structure with len 1 (0.07692307692307693 blocks).
Generated block structure with len 7 (0.875 blocks).
Generated block structure with len 7 (0.875 blocks).
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=10` reached.
Epoch 10/10: 100%|███████████████████████████████████████████████████| 10/10 [00:13<00:00, 1.37s/it, v_num=1, train_loss_step=3.65e+3, train_loss_epoch=4.42e+3]
We can visualize the generated latent space:
adata = get_plot_latent_representation(vae, adata)
sc.pl.umap(adata, color=["cell_type", "Site"], frameon=False, ncols=1)
Query-to-reference mapping
NetworkVI can be used to train a reference model onto which new query data are mapped. It adapts the scArches approach introduced by [Lotfollahi et al., 2022]: for a query dataset, the reference VAE is fine-tuned with all GO layers frozen, while the covariate embeddings are trained. Here we demonstrate query-to-reference mapping with NetworkVI. The query dataset can contain the same modalities as the reference dataset or any subset of them; NetworkVI, however, expects the query to be defined on the same (union) feature set as the reference. We demonstrate label transfer with a random forest classifier trained on the latent space of the reference dataset, which is then used to predict labels for the query dataset. Any other classifier could be trained on the latent space instead.
latent_representation = adata.obsm["X_NetworkVI"]
clf = RandomForestClassifier(
    random_state=1,
    class_weight="balanced_subsample",
    verbose=1,
    n_jobs=-1,
)
clf.fit(latent_representation, adata.obs["cell_type"])
vae.latent_space_classifer_ = clf
[Parallel(n_jobs=-1)]: Done 49 tasks | elapsed: 0.8s
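Since any classifier can be trained on the latent space, here is a minimal numpy sketch of a k-nearest-neighbor alternative to the random forest. This is a swapped-in technique, not what this notebook uses, and the function `knn_label_transfer` and the default k=15 are hypothetical. Once the query model below has been trained, it could be called as `knn_label_transfer(adata.obsm["X_NetworkVI"], adata.obs["cell_type"], vae_query.get_latent_representation(modality="joint"))`.

```python
import numpy as np


def knn_label_transfer(ref_latent, ref_labels, query_latent, k=15):
    """Predict query labels by majority vote among the k nearest reference cells."""
    ref_latent = np.asarray(ref_latent, dtype=float)
    query_latent = np.asarray(query_latent, dtype=float)
    ref_labels = np.asarray(ref_labels)
    # Squared Euclidean distances (n_query x n_ref) without a 3-D intermediate array.
    d2 = (
        (query_latent ** 2).sum(axis=1)[:, None]
        + (ref_latent ** 2).sum(axis=1)[None, :]
        - 2.0 * query_latent @ ref_latent.T
    )
    nearest = np.argsort(d2, axis=1)[:, :k]
    predictions = []
    for neighbor_labels in ref_labels[nearest]:
        values, counts = np.unique(neighbor_labels, return_counts=True)
        # np.unique sorts labels, so ties go to the alphabetically first label
        predictions.append(values[np.argmax(counts)])
    return np.array(predictions)


# Toy usage: two well-separated reference clusters labeled "A" and "B"
rng = np.random.default_rng(0)
ref = np.vstack([rng.normal(0, 1, (20, 2)), rng.normal(10, 1, (20, 2))])
labels = ["A"] * 20 + ["B"] * 20
pred = knn_label_transfer(ref, labels, np.array([[0.1, 0.1], [9.9, 9.9]]), k=5)
print(pred)
```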
NetworkVI loads the query data into a model initialized from the reference and then trains it on the query dataset. Here, we freeze all parts of the network except the covariate embeddings in the encoders and the first layers of the decoders (unfreeze_first_layers=True).
networkvi.model.NETWORKVI.prepare_query_anndata(query, vae)
vae_query = NETWORKVI.load_query_data(query, vae, freeze=True, unfreeze_first_layers=True)
vae_query.train(max_epochs=10, adversarial_mixing=False, save_best=False)
Generated block structure with len 4000 (0.00025 blocks).
Generated block structure with len 7342 (0.000458875 blocks).
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
Using attention in GO model at level -1.
Generated block structure with len 4565 (0.07608333333333334 blocks).
Generated block structure with len 4565 (0.07608333333333334 blocks).
Using attention in GO model at level 1.
Generated block structure with len 18 (0.13333333333333333 blocks).
Generated block structure with len 18 (0.13333333333333333 blocks).
Generated block structure with len 297 (0.00825 blocks).
Generated block structure with len 297 (0.00825 blocks).
Using attention in GO model at level 0.
Generated block structure with len 1 (0.06666666666666667 blocks).
Generated block structure with len 1 (0.06666666666666667 blocks).
Generated block structure with len 8 (0.8888888888888888 blocks).
Generated block structure with len 8 (0.8888888888888888 blocks).
Generated block structure with len 134 (0.0078125 blocks).
Generated block structure with len 154 (0.0093994140625 blocks).
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
Using attention in GO model at level -1.
Generated block structure with len 182 (0.109375 blocks).
Generated block structure with len 182 (0.109375 blocks).
Using attention in GO model at level 1.
Generated block structure with len 16 (0.15384615384615385 blocks).
Generated block structure with len 16 (0.15384615384615385 blocks).
Generated block structure with len 27 (0.0263671875 blocks).
Generated block structure with len 27 (0.0263671875 blocks).
Using attention in GO model at level 0.
Generated block structure with len 1 (0.07692307692307693 blocks).
Generated block structure with len 1 (0.07692307692307693 blocks).
Generated block structure with len 7 (0.875 blocks).
Generated block structure with len 7 (0.875 blocks).
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=10` reached.
Epoch 10/10: 100%|███████████████████████████████████████████████████| 10/10 [00:03<00:00, 3.27it/s, v_num=1, train_loss_step=6.73e+3, train_loss_epoch=7.62e+3]
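The freezing strategy described above follows the scArches idea of "architecture surgery": new covariate entries are added for the unseen query batch, and only selected weights are updated. The numpy sketch below illustrates that general idea for a single first layer with one-hot batch inputs; it is not NetworkVI's implementation. The helper names and the zero initialization of the new columns are our assumptions. In the notebook's setting, unfreeze_first_layers=True additionally keeps the first decoder layers trainable, which in this picture corresponds to a mask of ones for those layers.

```python
import numpy as np


def extend_first_layer(weight, n_new_batches):
    """Add input columns for new (query) batch categories to a first-layer weight.

    weight: (n_hidden, n_inputs), where the last inputs are the reference one-hot batches.
    Returns the extended weight and a gradient mask that is 1 only on the new columns.
    """
    n_hidden, n_inputs = weight.shape
    new_columns = np.zeros((n_hidden, n_new_batches))  # zero init is an arbitrary choice here
    extended = np.concatenate([weight, new_columns], axis=1)
    grad_mask = np.zeros_like(extended)
    grad_mask[:, n_inputs:] = 1.0
    return extended, grad_mask


def masked_sgd_step(weight, grad, grad_mask, lr=0.1):
    """Plain gradient step that leaves frozen (mask == 0) entries untouched."""
    return weight - lr * grad * grad_mask


w_ref = np.ones((4, 5))  # e.g. 2 features + 3 reference sites
w_query, mask = extend_first_layer(w_ref, n_new_batches=1)  # one new query site
w_query = masked_sgd_step(w_query, np.ones_like(w_query), mask)
print(w_query[0])  # reference weights unchanged, only the new column moved
```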
We can now predict the cell types of the query cells from their joint latent representation and compare them with the annotated cell types in a confusion matrix:
predicted_labels = clf.predict(vae_query.get_latent_representation(modality="joint"))
[Parallel(n_jobs=1)]: Done 49 tasks | elapsed: 0.0s
from sklearn.metrics import confusion_matrix

# Reference cell types define the label order on both axes
labels = sorted(np.unique(adata.obs["cell_type"]))
plt.figure(figsize=(10, 8))
cm = confusion_matrix(query.obs["cell_type"], predicted_labels, labels=labels)
# L1-normalize each column, so that the values for every predicted label sum to 1
norm_cm = normalize(cm, norm="l1", axis=0)
ax = sns.heatmap(data=norm_cm, xticklabels=labels, yticklabels=labels, cbar=True)
ax.set_xlabel("predicted labels", fontsize=10)
ax.set_ylabel("true labels", fontsize=10)
ax.tick_params(axis='x', labelsize=6)
ax.tick_params(axis='y', labelsize=6)
plt.show()
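How the heatmap should be read depends on which axis of the confusion matrix is normalized and with which norm. Note that sklearn.preprocessing.normalize uses the L2 norm unless norm="l1" is passed, and sklearn.metrics.confusion_matrix also accepts normalize="true", "pred" or "all". The numpy sketch below, with a hypothetical helper `normalize_confusion`, makes the two common L1 choices explicit on a toy 2x2 matrix.

```python
import numpy as np


def normalize_confusion(cm, by="true"):
    """L1-normalize a confusion matrix whose rows are true and columns predicted labels.

    by="true": every row sums to 1 (diagonal = per-class recall).
    by="pred": every column sums to 1 (diagonal = per-class precision).
    Rows or columns without any cells stay 0 instead of dividing by zero.
    """
    if by not in ("true", "pred"):
        raise ValueError("by must be 'true' or 'pred'")
    cm = np.asarray(cm, dtype=float)
    totals = cm.sum(axis=1 if by == "true" else 0, keepdims=True)
    return np.divide(cm, totals, out=np.zeros_like(cm), where=totals > 0)


cm = np.array([[8, 2],
               [1, 9]])
print(normalize_confusion(cm, by="true"))
print(np.round(normalize_confusion(cm, by="pred"), 3))
```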