Interpretability: Inferring importances for GO and genes
NetworkVI is a Variational Autoencoder for the paired and mosaic integration of multimodal single-cell data. Each single-cell modality is encoded by its own encoder; the latent spaces are then aligned into a joint representation. In this notebook we demonstrate how to use the inherent interpretability of NetworkVI to infer importances for GO terms and genes. We use a publicly available CITE-seq dataset created by [Luecken et al., 2021].
If you use NetworkVI, please consider citing:
Arnoldt, L., Upmeier zu Belzen, J., Herrmann, L., Nguyen, K., Theis, F.J., Wild, B., Eils, R., “Biologically Guided Variational Inference for Interpretable Multimodal Single-Cell Integration”, bioRxiv, June 2025.
import os
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip install networkvi
import numpy as np
import requests
import sys
sys.path.append("../../../src")
import networkvi
from networkvi.model import NETWORKVI
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
import sklearn
from sklearn.preprocessing import normalize
import seaborn as sns
import pandas as pd
from goatools.obo_parser import GODag
import shutil
%%capture
import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=ConvergenceWarning)
%matplotlib inline
r = requests.get("http://purl.obolibrary.org/obo/go/go-basic.obo", allow_redirects=True)
open(os.path.join("../../../resources/", "go-basic.obo"), 'wb').write(r.content)
31321250
if not os.path.isfile("../../../gene_interactions.csv"):
    import gdown
    gdown.download("https://drive.google.com/uc?export=download&id=1MwAuqw2JVl6L9xfsWGfUH7rB02LsQ_Vv", "../../../gene_interactions.csv")
Data loading
We provide a subsetted and preprocessed version of the CITE-seq and Multiome datasets for this tutorial. We filtered low-quality cells and lowly expressed features, subsampled to 5000 cells and selected 4000 highly variable genes and 20000 highly variable ATAC peaks. NetworkVI accepts raw count data for all modalities. For details on the preprocessing, please refer to the manuscript and the code in the NetworkVI reproducibility repository, which also provides scripts for downloading and processing the datasets used in the manuscript. The original CITE-seq dataset contains 90261 bone marrow mononuclear cells measured at 4 different sites with 13953 genes and 143 proteins.
rna_cite_path = "neurips2021_cite_bmmc_luecken2021.h5ad"
try:
    rna_cite = sc.read_h5ad(rna_cite_path)
except OSError:
    # The file is not available locally yet: download it, then read it again.
    import gdown
    gdown.download("https://drive.google.com/uc?export=download&id=1A9o8wZgWS6udFMXJ-ybXre1dCcCdsKt5")
    rna_cite = sc.read_h5ad(rna_cite_path)
rna_cite
AnnData object with n_obs × n_vars = 5000 × 4000
    obs: 'cell_type', 'Site', 'DonorID', 'DonorSmoker', 'DonorBMI', 'DonorAge'
    var: 'gene_stable_id', 'modality'
    uns: 'protein_expression'
    obsm: 'protein_expression'
Data setup
To perform query-to-reference mapping, we select the cells measured at one of the four sites as the query dataset. We train NetworkVI on the rest of the dataset.
adata = rna_cite
query = adata[adata.obs["Site"] == "site1"].copy()
adata = adata[adata.obs["Site"] != "site1"].copy()
We register the fields and set up the AnnData object with a convenience function, as shown below. Like MultiVI, NetworkVI requires the modality to be registered as the batch_key. Further batch information can be registered as categorical_covariate_keys.
NETWORKVI.setup_anndata(
    adata,
    batch_key="Site",
    protein_expression_obsm_key="protein_expression",
)
Model setup and training
We demonstrate the setup and training of NetworkVI, which uses the Gene Ontology as the encoder structure. Below we define a convenience function that retrieves the latent representation and computes a UMAP embedding for visualizing the latent space. Like MultiVI, NetworkVI requires the number of features per modality to be defined (n_genes, n_proteins, n_regions).
def get_plot_latent_representation(vae, adata):
    # Store the joint latent representation and embed it with UMAP.
    latent_representation = vae.get_latent_representation(modality="joint")
    adata.obsm["X_NetworkVI"] = latent_representation
    sc.pp.neighbors(adata, use_rep="X_NetworkVI")
    sc.tl.umap(adata, n_components=2)
    return adata
NetworkVI
We select NetworkVI by setting layers_encoder_type to “go”. The mapping of the features to the gene layer and the Gene Ontology requires ENSEMBL-IDs, which we provide in ensembl_ids_genes and ensembl_ids_proteins. We also provide the Gene Ontology and a gene-to-GO mapping file in obo_file and map_ensembl_go. Please don’t use the provided mapping file in production, since it is an artificially subsampled version of a real mapping file. Please refer to the scripts provided in https://github.com/Larnoldt/networkvi_reproducibility or https://geneontology.org/docs/go-annotation-file-gaf-format-2.1/ to download full mapping files.
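To illustrate what a gene-to-GO mapping encodes, the following toy sketch propagates direct gene annotations upward through a small GO-like hierarchy, so that every term lists all genes annotated to it or to one of its descendants (the GO true-path rule). The term IDs, gene IDs and helper names are hypothetical, and this is not how NetworkVI parses obo_file or map_ensembl_go internally; it only shows the kind of connectivity such files describe.

```python
# Toy GO hierarchy: child term -> list of parent terms (hypothetical IDs).
GO_PARENTS = {
    "GO:C": ["GO:B"],
    "GO:B": ["GO:A"],
    "GO:D": ["GO:A"],
    "GO:A": [],  # root term
}

# Toy direct annotations: gene ID -> directly annotated GO terms (hypothetical IDs).
GENE_TO_GO = {
    "ENSG_TOY_1": ["GO:C"],
    "ENSG_TOY_2": ["GO:D"],
    "ENSG_TOY_3": ["GO:C", "GO:D"],
}


def ancestors(term, parents):
    """Return the term together with all of its ancestors in the hierarchy."""
    seen = set()
    stack = [term]
    while stack:
        current = stack.pop()
        if current not in seen:
            seen.add(current)
            stack.extend(parents.get(current, []))
    return seen


def genes_per_term(gene_to_go, parents):
    """Propagate annotations upward so each term lists every gene at or below it."""
    term_to_genes = {}
    for gene, terms in gene_to_go.items():
        for term in terms:
            for node in ancestors(term, parents):
                term_to_genes.setdefault(node, set()).add(gene)
    return term_to_genes


term_to_genes = genes_per_term(GENE_TO_GO, GO_PARENTS)
for term in sorted(term_to_genes):
    print(term, sorted(term_to_genes[term]))
```

In this toy hierarchy the root term collects all three genes, while the more specific terms only see the genes annotated at or below them.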
vae = NETWORKVI(
    adata,
    n_genes=len(adata.var[adata.var["modality"] == "Gene Expression"]),
    n_proteins=adata.obsm["protein_expression"].shape[1],
    ensembl_ids_genes=np.array(adata.var[adata.var["modality"] == "Gene Expression"]["gene_stable_id"]),
    ensembl_ids_proteins=np.array(adata.uns["protein_expression"]["var"]["gene_stable_id"]),
    gene_layer_interaction_source="../../../gene_interactions.csv",
    expression_gene_layer_type="interaction",
    protein_gene_layer_type="interaction",
    obo_file="../../../resources/go-basic.obo",
    map_ensembl_go=["../../../resources/ensembl2go.gaf"],
    layers_encoder_type="go",
    encode_covariates=True,
    deeply_inject_covariates=True,
    fully_paired=True,
    standard_gene_size=5,
    standard_go_size=2,
)
vae.train(max_epochs=10, adversarial_mixing=False, save_best=False)
Generated block structure with len 4000 (0.00025 blocks).
Generated block structure with len 7342 (0.000458875 blocks).
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
_biological_process_go-basic_ensembl2go.v2_processed.pkl
Using attention in GO model at level -1.
Generated block structure with len 4565 (0.07608333333333334 blocks).
Generated block structure with len 4565 (0.07608333333333334 blocks).
Using attention in GO model at level 1.
Generated block structure with len 18 (0.13333333333333333 blocks).
Generated block structure with len 18 (0.13333333333333333 blocks).
Generated block structure with len 297 (0.00825 blocks).
Generated block structure with len 297 (0.00825 blocks).
Using attention in GO model at level 0.
Generated block structure with len 1 (0.06666666666666667 blocks).
Generated block structure with len 1 (0.06666666666666667 blocks).
Generated block structure with len 8 (0.8888888888888888 blocks).
Generated block structure with len 8 (0.8888888888888888 blocks).
Generated block structure with len 134 (0.0078125 blocks).
Generated block structure with len 154 (0.0093994140625 blocks).
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
Using attention in GO model at level -1.
Generated block structure with len 182 (0.109375 blocks).
Generated block structure with len 182 (0.109375 blocks).
Using attention in GO model at level 1.
Generated block structure with len 16 (0.15384615384615385 blocks).
Generated block structure with len 16 (0.15384615384615385 blocks).
Generated block structure with len 27 (0.0263671875 blocks).
Generated block structure with len 27 (0.0263671875 blocks).
Using attention in GO model at level 0.
Generated block structure with len 1 (0.07692307692307693 blocks).
Generated block structure with len 1 (0.07692307692307693 blocks).
Generated block structure with len 7 (0.875 blocks).
Generated block structure with len 7 (0.875 blocks).
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=10` reached.
Epoch 10/10: 100%|███████████████████████████████████████████████████| 10/10 [00:13<00:00, 1.38s/it, v_num=1, train_loss_step=4.56e+3, train_loss_epoch=4.51e+3]
We can visualize the generated latent space:
adata = get_plot_latent_representation(vae, adata)
sc.pl.umap(adata, color=["cell_type", "Site"], frameon=False, ncols=1)
Interpretability
By fitting a logistic regression on the activations of the nodes representing a GO term (calc_go_terms), a gene (calc_genes) or a gene group (calc_gene_groups), we can infer the importance of each node for the performance of the neural network, as described previously by [Ma et al., 2018]. Here, we demonstrate how to use the get_go_gene_importances and get_perturbation_go_gene_importances functions to infer the importances for the original dataset and for perturbed datasets in a sensitivity analysis. By default, the importance is calculated for each label against the rest. This behaviour can be modified with labels_selection, labels_selection_groups, restrict_by_column_key_activations_mask, restrict_by_column_values_activations_mask and comparison. Please refer to the API documentation for details.
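The idea of scoring a node by how well its activation separates one label from the rest can be sketched on simulated data. The example below fits a one-dimensional logistic regression per node and reports the ROC AUC on held-out cells as a stand-in importance score. All function names and the simulated activations are assumptions made for illustration; the actual procedure, data splits and metrics used by get_go_gene_importances may differ.

```python
import math
import random


def fit_logistic_1d(xs, ys, lr=0.1, epochs=500):
    """Fit p(y=1|x) = sigmoid(w*x + b) with batch gradient descent."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        grad_w = grad_b = 0.0
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-(w * x + b)))
            grad_w += (p - y) * x
            grad_b += p - y
        w -= lr * grad_w / n
        b -= lr * grad_b / n
    return w, b


def predict(xs, w, b):
    """Predicted probability of the positive label for each activation."""
    return [1.0 / (1.0 + math.exp(-(w * x + b))) for x in xs]


def roc_auc(scores, labels):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


def simulate(rng, n_cells, shift):
    """Simulated node activations; cells with label 1 are shifted by `shift`."""
    xs, ys = [], []
    for i in range(n_cells):
        y = i % 2  # label 1 = cell type of interest, label 0 = rest
        xs.append(rng.gauss(shift * y, 1.0))
        ys.append(y)
    return xs, ys


rng = random.Random(0)
importance = {}
for name, shift in [("informative_node", 2.0), ("noise_node", 0.0)]:
    x_train, y_train = simulate(rng, 400, shift)
    x_valid, y_valid = simulate(rng, 200, shift)
    w, b = fit_logistic_1d(x_train, y_train)
    importance[name] = roc_auc(predict(x_valid, w, b), y_valid)

print(importance)
```

A node whose activation tracks the label reaches a validation ROC AUC well above 0.5, while an uninformative node stays close to 0.5, which is the kind of contrast a heatmap of importances makes visible.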
obodag = GODag(os.path.join("../../../resources/", "go-basic.obo"))
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms
_ = vae.get_go_gene_importances(
    labels_column="cell_type",
    results_dir="tutorial/calc_go_terms",
    shuffle_set_split=False,
    batch_size=512,
    calc_go_terms=True,
    calc_genes=False,
    calc_gene_groups=False,
    save_fit=False,
)
get_go_gene_importances saves the activations, GO importances and further information in a pandas pickle file. We then load these files and visualize the GO importance for a selection of GO terms for all available cell types.
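The result files are addressed by a sanitized version of the cell type label, in which spaces and slashes are replaced by underscores. A small helper along these lines keeps the path construction in one place. The helper names and the example labels are hypothetical, and the file name pattern is simply the one observed in this tutorial's run; other arguments to get_go_gene_importances may well produce different file names.

```python
def sanitize_label(cell_type):
    """Replace characters that are awkward in file names with underscores."""
    return cell_type.replace(" ", "_").replace("/", "_")


def go_result_path(cell_type, results_dir="./tutorial/calc_go_terms"):
    """Build the path of the one-vs-rest GO term result file for a cell type.

    The pattern is copied from this tutorial's output and is an assumption
    for any other configuration.
    """
    return (
        f"{results_dir}/protein_cell_type_vs_rest_{sanitize_label(cell_type)}"
        "_res_unp_act_mask_False_flip_False_go_terms_evaluated_graph_csv.pkl"
    )


# Illustrative labels only; use the values found in adata.obs["cell_type"].
for label in ["CD14+ Mono", "G/M prog"]:
    print(go_result_path(label))
```

Building the path in a single function makes it easier to adapt the loading code if the output directory or the naming scheme changes.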
go_importances = []
cell_types = []
go_terms = None
for cell_type in sorted(np.unique(adata.obs["cell_type"])):
    cell_type_mod = cell_type.replace(" ", "_").replace("/", "_")
    try:
        result = pd.read_pickle(f"./tutorial/calc_go_terms/protein_cell_type_vs_rest_{cell_type_mod}_res_unp_act_mask_False_flip_False_go_terms_evaluated_graph_csv.pkl")
    except FileNotFoundError:
        # No result file was written for this cell type; skip it.
        continue
    go_importances.append(np.array(result["predictors_label_valid_roc"]).astype("float"))
    cell_types.append(cell_type)
    # The GO term IDs are the row index; take them from a file that was actually loaded.
    go_terms = list(result.index)
go_terms = [f"{obodag[go_term].name} ({go_term})" if go_term in obodag.keys() else go_term for go_term in go_terms]
plt.figure(figsize=(10, 6))
sns.heatmap(np.array(go_importances).T, xticklabels=cell_types, yticklabels=go_terms)
plt.show()
shutil.rmtree("tutorial/calc_go_terms")
We also show how to perform a sensitivity analysis by perturbing features. Here we perturb CD86 (ENSG00000114013) by setting all values to zero. The function get_perturbation_go_gene_importances internally calls get_go_gene_importances, but additionally we need to provide the modality of the features in modality_key and the ENSEMBL-IDs in gene_stable_id_key.
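The logic of such a sensitivity analysis can be sketched on toy data: score a node on the original features, set one feature to zero for all cells, score again, and inspect the difference. The cells, weights and function names below are hypothetical, and the node is a plain weighted sum rather than anything taken from NetworkVI; get_perturbation_go_gene_importances may compute its scores differently.

```python
def roc_auc(scores, labels):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


def node_activation(cell, weights):
    """Toy node: a weighted sum of the input features of one cell."""
    return sum(w * x for w, x in zip(weights, cell))


def zero_feature(cells, feature_index):
    """Return a copy of the data with one feature set to zero in every cell."""
    return [[0.0 if j == feature_index else v for j, v in enumerate(cell)] for cell in cells]


# Toy data: feature 0 separates the label from the rest, feature 1 does not.
cells = [[3.0, 1.0], [2.5, 0.2], [2.8, 0.9], [0.4, 0.8], [0.1, 0.3], [0.6, 1.1]]
labels = [1, 1, 1, 0, 0, 0]
weights = [1.0, 0.5]

baseline = roc_auc([node_activation(c, weights) for c in cells], labels)
perturbed_cells = zero_feature(cells, feature_index=0)
perturbed = roc_auc([node_activation(c, weights) for c in perturbed_cells], labels)

# Positive difference: the node became less predictive after the perturbation.
delta = baseline - perturbed
print(baseline, perturbed, delta)
```

The sign convention matches a baseline-minus-perturbed comparison: a large positive value marks a node whose predictive power depended on the perturbed feature.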
perturbation_gene_stable_ids = ["ENSG00000114013"]
go_importances_perturbation = vae.get_perturbation_go_gene_importances(
    labels_column="cell_type",
    modality_key="modality",
    gene_stable_id_key="gene_stable_id",
    perturbation_gene_stable_ids=perturbation_gene_stable_ids,
    batch_size=512,
    shuffle_set_split=False,
    calc_go_terms=True,
    calc_genes=False,
    calc_gene_groups=False,
)
We can visualize the change in importance relative to the unperturbed dataset in the same way:
go_importances_pert = []
for cell_type in cell_types:
    cell_type_mod = cell_type.replace(" ", "_").replace("/", "_")
    go_importances_pert.append(go_importances_perturbation["ENSG00000114013"]["protein"]["protein"][cell_type_mod].astype("float"))
plt.figure(figsize=(10, 6))
sns.heatmap((np.array(go_importances) - np.array(go_importances_pert)).T, xticklabels=cell_types, yticklabels=go_terms)
plt.show()