Interpretability: Inferring GO-specific covariate attention values
NetworkVI is a Variational Autoencoder for the paired and mosaic integration of multimodal single-cell data. Each single-cell modality is learned by its own encoder, and the latent spaces are then aligned into a joint representation. In this notebook we demonstrate how to use the inherent interpretability of NetworkVI to infer GO term-specific covariate attention values. We use a publicly available CITE-seq dataset created by [Luecken et al., 2021].
If you use NetworkVI, please consider citing:
Arnoldt, L., Upmeier zu Belzen, J., Herrmann, L., Nguyen, K., Theis, F.J., Wild, B., Eils, R., “Biologically Guided Variational Inference for Interpretable Multimodal Single-Cell Integration”, bioRxiv, June 2025.
import os
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip install networkvi
import numpy as np
import requests
sys.path.append("../../../src")
from networkvi.model import NETWORKVI
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
import sklearn
from sklearn.preprocessing import normalize
import seaborn as sns
import pandas as pd
from goatools.obo_parser import GODag
import shutil
%%capture
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
os.makedirs("../../../resources/", exist_ok=True)
r = requests.get("http://purl.obolibrary.org/obo/go/go-basic.obo", allow_redirects=True)
r.raise_for_status()
with open(os.path.join("../../../resources/", "go-basic.obo"), "wb") as f:
    f.write(r.content)
if not os.path.isfile("../../../gene_interactions.csv"):
    import gdown
    gdown.download("https://drive.google.com/uc?export=download&id=1MwAuqw2JVl6L9xfsWGfUH7rB02LsQ_Vv", "../../../gene_interactions.csv")
Data loading
We provide subsetted and preprocessed versions of the CITE-seq and Multiome datasets for this tutorial; this notebook uses the CITE-seq dataset. We filtered low-quality cells and lowly expressed features, subsampled the data to 5000 cells, and selected 4000 highly variable genes and, for Multiome, 20000 highly variable ATAC peaks. NetworkVI accepts raw counts for all modalities. For details on the preprocessing, please refer to the manuscript and the code in the NetworkVI reproducibility repository, which also provides scripts for downloading and processing the datasets used in the manuscript. The original CITE-seq dataset contains 90261 bone marrow mononuclear cells measured at 4 different sites, with 13953 genes and 143 proteins.
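The filtering and subsetting steps described above can be sketched in plain NumPy. This is not the preprocessing used for the manuscript (see the reproducibility repository for that): the helper name subset_counts, the thresholds min_counts and min_cells, and the variance-based gene ranking are illustrative assumptions, a simple stand-in for a proper highly variable gene selection.

```python
import numpy as np


def subset_counts(counts, n_cells, n_top_genes, min_counts=500, min_cells=10, seed=0):
    """Return indices of retained cells and genes for a raw count matrix (cells x genes).

    Hypothetical helper: thresholds and the variance-based gene ranking are illustrative.
    """
    # 1) Filter low-quality cells: too few total counts
    cell_ok = np.flatnonzero(counts.sum(axis=1) >= min_counts)
    # 2) Filter lowly expressed genes: detected in too few of the retained cells
    gene_ok = np.flatnonzero((counts[cell_ok] > 0).sum(axis=0) >= min_cells)
    # 3) Subsample cells at random, without replacement
    rng = np.random.default_rng(seed)
    cells = np.sort(rng.choice(cell_ok, size=min(n_cells, cell_ok.size), replace=False))
    # 4) Rank the retained genes by the variance of library-size-normalized log counts
    sub = counts[np.ix_(cells, gene_ok)].astype(float)
    lognorm = np.log1p(sub / np.maximum(sub.sum(axis=1, keepdims=True), 1) * 1e4)
    top = np.argsort(lognorm.var(axis=0))[::-1][:n_top_genes]
    genes = np.sort(gene_ok[top])
    return cells, genes


# Toy example: 200 cells x 50 genes of Poisson counts
rng = np.random.default_rng(1)
toy = rng.poisson(lam=rng.uniform(0.1, 30, size=50), size=(200, 50))
cells, genes = subset_counts(toy, n_cells=100, n_top_genes=20, min_counts=100)
print(len(cells), len(genes))
```

The same selection by variance could be applied to a peak matrix; for real data, a dedicated highly variable feature method is the more robust choice.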
rna_cite_path = "neurips2021_cite_bmmc_luecken2021.h5ad"
try:
    rna_cite = sc.read_h5ad(rna_cite_path)
except OSError:
    # Download the tutorial subset if it is not available locally
    import gdown
    gdown.download("https://drive.google.com/uc?export=download&id=1A9o8wZgWS6udFMXJ-ybXre1dCcCdsKt5", rna_cite_path)
    rna_cite = sc.read_h5ad(rna_cite_path)
rna_cite
AnnData object with n_obs × n_vars = 5000 × 4000
obs: 'cell_type', 'Site', 'DonorID', 'DonorSmoker', 'DonorBMI', 'DonorAge'
var: 'gene_stable_id', 'modality'
uns: 'protein_expression'
obsm: 'protein_expression'
Data setup
To perform query-to-reference mapping, we hold out the cells measured at one of the four sites (site1) as the query dataset and train NetworkVI on the cells from the remaining sites.
adata = rna_cite
query = adata[adata.obs["Site"] == "site1"].copy()
adata = adata[adata.obs["Site"] != "site1"].copy()
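We register the data fields and set up the AnnData object with the convenience function shown below. As with MultiVI, NetworkVI requires the modality to be registered via the batch_key; here, the site (Site) is passed as the batch_key. Further batch information can be registered as categorical_covariate_keys, and continuous covariates such as donor BMI and age as continuous_covariate_keys. We z-score the continuous covariates before registration.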
categorical_covariate_keys = ['DonorSmoker']
continuous_covariate_keys = ['DonorBMI', 'DonorAge']
adata.obs["DonorBMI"] = (adata.obs["DonorBMI"] - adata.obs["DonorBMI"].mean()) / adata.obs["DonorBMI"].std()
adata.obs["DonorAge"] = (adata.obs["DonorAge"] - adata.obs["DonorAge"].mean()) / adata.obs["DonorAge"].std()
NETWORKVI.setup_anndata(
adata,
batch_key="Site",
protein_expression_obsm_key="protein_expression",
categorical_covariate_keys=categorical_covariate_keys,
continuous_covariate_keys=continuous_covariate_keys,
)
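The continuous covariates above are z-scored with statistics computed on the reference cells only. If the site1 query is later mapped onto the trained model, one reasonable option (an assumption, not a step shown in this tutorial) is to standardize the raw query values with the same reference mean and standard deviation, so both sets share one scale. A minimal pandas sketch; zscore_with_reference is a hypothetical helper that would replace the two in-place assignments above.

```python
import pandas as pd


def zscore_with_reference(reference, query, columns):
    """Z-score columns of both frames using the reference mean and std (hypothetical helper)."""
    reference, query = reference.copy(), query.copy()
    for col in columns:
        # pandas .std() uses ddof=1, matching the standardization above
        mean, std = reference[col].mean(), reference[col].std()
        reference[col] = (reference[col] - mean) / std
        query[col] = (query[col] - mean) / std
    return reference, query


# Toy obs tables standing in for the raw reference and query covariates
ref_obs = pd.DataFrame({"DonorBMI": [20.0, 25.0, 30.0], "DonorAge": [30.0, 40.0, 50.0]})
qry_obs = pd.DataFrame({"DonorBMI": [25.0, 35.0], "DonorAge": [40.0, 60.0]})
ref_z, qry_z = zscore_with_reference(ref_obs, qry_obs, ["DonorBMI", "DonorAge"])
print(qry_z)
```

The inputs are copied, so the raw values stay available for other uses.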
Model setup and training
We now set up and train NetworkVI, which uses the Gene Ontology (GO) to structure its encoders. We first define a convenience function that computes the joint latent representation and its UMAP embedding. As with MultiVI, NetworkVI requires the number of features per modality (n_genes, n_proteins, n_regions).
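In this tutorial, the gene count is taken from the modality column of adata.var and the protein count from the shape of obsm["protein_expression"], as in the model setup cell. A minimal sketch of that bookkeeping on a toy var table; the "Peaks" label for ATAC features is an assumption about how a Multiome object might be annotated, and n_regions only matters for data with chromatin accessibility.

```python
import numpy as np
import pandas as pd

# Toy stand-ins for adata.var and adata.obsm["protein_expression"]
var = pd.DataFrame(
    {
        "modality": ["Gene Expression"] * 3 + ["Peaks"] * 2,  # "Peaks" label is an assumption
        "gene_stable_id": ["ENSG00000141510", "ENSG00000012048", "ENSG00000146648", None, None],
    },
    index=["TP53", "BRCA1", "EGFR", "chr1:100-600", "chr1:900-1400"],
)
protein_expression = np.zeros((10, 4))  # 10 cells x 4 proteins

is_gene = (var["modality"] == "Gene Expression").to_numpy()
n_genes = int(is_gene.sum())
n_regions = int((var["modality"] == "Peaks").sum())
n_proteins = protein_expression.shape[1]
# Ensembl IDs of the gene features, in the same order as the features
ensembl_ids_genes = var.loc[is_gene, "gene_stable_id"].to_numpy()
print(n_genes, n_regions, n_proteins)
```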
def get_plot_latent_representation(vae, adata):
    latent_representation = vae.get_latent_representation(modality="joint")
    adata.obsm["X_NetworkVI"] = latent_representation
    sc.pp.neighbors(adata, use_rep="X_NetworkVI")
    sc.tl.umap(adata, n_components=2)
    return adata
NetworkVI
We select the GO-structured encoder by setting layers_encoder_type to “go”. Mapping the features to the gene layer and to the Gene Ontology requires Ensembl IDs, which we pass in ensembl_ids_genes and ensembl_ids_proteins. We pass the Gene Ontology and a gene-to-GO mapping file in obo_file and map_ensembl_go. Please do not use the provided mapping file in production: it is an artificially subsampled version of a real mapping file. To download full mapping files, please refer to the scripts in https://github.com/Larnoldt/networkvi_reproducibility or to https://geneontology.org/docs/go-annotation-file-gaf-format-2.1/.
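To illustrate what a gene-to-GO mapping provides, the sketch below reads annotation lines laid out like GAF 2.1 (tab-separated, "!" comment lines, object ID in column 2, qualifier in column 4, GO ID in column 5) and turns them into a binary gene × GO term mask, the kind of connectivity a GO-structured sparse layer is restricted to. It assumes Ensembl IDs in column 2; the layout of the provided ensembl2go.gaf and the way NetworkVI builds its layers internally may differ, and read_gaf_mapping and build_mask are hypothetical helpers.

```python
import io
from collections import defaultdict

import numpy as np


def read_gaf_mapping(handle):
    """Map object IDs (GAF column 2) to sets of GO IDs (GAF column 5)."""
    mapping = defaultdict(set)
    for line in handle:
        if not line.strip() or line.startswith("!"):
            continue  # skip blank lines and header/comment lines
        cols = line.rstrip("\n").split("\t")
        if "NOT" in cols[3].split("|"):
            continue  # negated annotations state that the gene is NOT linked to the term
        mapping[cols[1]].add(cols[4])
    return mapping


def build_mask(gene_ids, mapping):
    """Binary mask of shape (n_genes, n_go_terms) and the ordered GO term list."""
    go_terms = sorted({go for g in gene_ids for go in mapping.get(g, ())})
    col = {go: j for j, go in enumerate(go_terms)}
    mask = np.zeros((len(gene_ids), len(go_terms)), dtype=np.float32)
    for i, g in enumerate(gene_ids):
        for go in mapping.get(g, ()):
            mask[i, col[go]] = 1.0
    return mask, go_terms


# Toy annotation lines (truncated to the first 7 GAF columns; not real curation)
toy_gaf = "\n".join([
    "!gaf-version: 2.1",
    "\t".join(["Ensembl", "ENSG00000141510", "TP53", "", "GO:0006915", "REF:toy", "IDA"]),
    "\t".join(["Ensembl", "ENSG00000141510", "TP53", "", "GO:0006281", "REF:toy", "IDA"]),
    "\t".join(["Ensembl", "ENSG00000012048", "BRCA1", "", "GO:0006281", "REF:toy", "IDA"]),
    "\t".join(["Ensembl", "ENSG00000012048", "BRCA1", "NOT", "GO:0007165", "REF:toy", "IDA"]),
    "\t".join(["Ensembl", "ENSG00000146648", "EGFR", "", "GO:0007165", "REF:toy", "IDA"]),
])
mapping = read_gaf_mapping(io.StringIO(toy_gaf))
genes = ["ENSG00000141510", "ENSG00000012048", "ENSG00000146648"]
mask, go_terms = build_mask(genes, mapping)

# A GO-structured layer only keeps weights where the mask is 1 (random toy weights)
rng = np.random.default_rng(0)
weights = rng.normal(size=mask.shape) * mask
go_activations = np.ones((2, len(genes))) @ weights  # 2 cells -> (2, n_go_terms)
print(go_terms)
print(mask)
```

Multiplying the weights by the mask keeps the layer sparse: a GO term node only receives input from genes annotated to it.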
vae = NETWORKVI(
    adata,
    n_genes=len(adata.var[adata.var["modality"] == "Gene Expression"]),
    n_proteins=adata.obsm["protein_expression"].shape[1],
    ensembl_ids_genes=np.array(adata.var[adata.var["modality"] == "Gene Expression"]["gene_stable_id"]),
    ensembl_ids_proteins=np.array(adata.uns["protein_expression"]["var"]["gene_stable_id"]),
    gene_layer_interaction_source="../../../gene_interactions.csv",
    expression_gene_layer_type="interaction",
    protein_gene_layer_type="interaction",
    obo_file="../../../resources/go-basic.obo",
    map_ensembl_go=["../../../resources/ensembl2go.gaf"],
    layers_encoder_type="go",
    encode_covariates=True,
    deeply_inject_covariates=True,
    fully_paired=True,
    standard_gene_size=5,
    standard_go_size=2,
    n_layers_encoder=4,
)
vae.train(max_epochs=10, adversarial_mixing=False, save_best=False)
Generated block structure with len 4000 (0.00025 blocks).
Generated block structure with len 7342 (0.000458875 blocks).
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
_biological_process_go-basic_ensembl2go.v2_processed.pkl
Using attention in GO model at level -1.
Generated block structure with len 3904 (0.03253333333333333 blocks).
Generated block structure with len 3904 (0.03253333333333333 blocks).
Using attention in GO model at level 3.
Generated block structure with len 41 (0.059420289855072465 blocks).
Generated block structure with len 41 (0.059420289855072465 blocks).
Generated block structure with len 283 (0.003076086956521739 blocks).
Generated block structure with len 283 (0.003076086956521739 blocks).
Using attention in GO model at level 2.
Generated block structure with len 26 (0.07536231884057971 blocks).
Generated block structure with len 26 (0.07536231884057971 blocks).
Generated block structure with len 306 (0.0051 blocks).
Generated block structure with len 306 (0.0051 blocks).
Using attention in GO model at level 1.
Generated block structure with len 5 (0.018518518518518517 blocks).
Generated block structure with len 5 (0.018518518518518517 blocks).
Generated block structure with len 1 (0.004830917874396135 blocks).
Generated block structure with len 1 (0.004830917874396135 blocks).
Generated block structure with len 18 (0.13333333333333333 blocks).
Generated block structure with len 18 (0.13333333333333333 blocks).
Generated block structure with len 18 (0.0005 blocks).
Generated block structure with len 18 (0.0005 blocks).
Using attention in GO model at level 0.
Generated block structure with len 1 (0.03333333333333333 blocks).
Generated block structure with len 1 (0.03333333333333333 blocks).
Generated block structure with len 8 (0.8888888888888888 blocks).
Generated block structure with len 8 (0.8888888888888888 blocks).
Generated block structure with len 134 (0.0078125 blocks).
Generated block structure with len 154 (0.0093994140625 blocks).
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms; optional_attrs(relationship)
Using attention in GO model at level -1.
Generated block structure with len 108 (0.10546875 blocks).
Generated block structure with len 108 (0.10546875 blocks).
Using attention in GO model at level 3.
Generated block structure with len 10 (0.11363636363636363 blocks).
Generated block structure with len 10 (0.11363636363636363 blocks).
Generated block structure with len 17 (0.012073863636363636 blocks).
Generated block structure with len 17 (0.012073863636363636 blocks).
Using attention in GO model at level 2.
Generated block structure with len 14 (0.0979020979020979 blocks).
Generated block structure with len 14 (0.0979020979020979 blocks).
Generated block structure with len 42 (0.025240384615384616 blocks).
Generated block structure with len 42 (0.025240384615384616 blocks).
Using attention in GO model at level 1.
Generated block structure with len 2 (0.03125 blocks).
Generated block structure with len 2 (0.03125 blocks).
Generated block structure with len 1 (0.011363636363636364 blocks).
Generated block structure with len 1 (0.011363636363636364 blocks).
Generated block structure with len 16 (0.15384615384615385 blocks).
Generated block structure with len 16 (0.15384615384615385 blocks).
Using attention in GO model at level 0.
Generated block structure with len 1 (0.125 blocks).
Generated block structure with len 1 (0.125 blocks).
Generated block structure with len 7 (0.875 blocks).
Generated block structure with len 7 (0.875 blocks).
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=10` reached.
Epoch 10/10: 100%|███████████████████████████████████████████████████| 10/10 [00:19<00:00, 1.95s/it, v_num=1, train_loss_step=6.98e+3, train_loss_epoch=4.83e+3]
We can visualize the generated latent space:
adata = get_plot_latent_representation(vae, adata)
sc.pl.umap(adata, color=["cell_type", "Site"], frameon=False, ncols=1)
Interpretability
In NetworkVI, a GO term-specific multi-head attention mechanism [Vaswani et al., 2017] incorporates covariates such as batch effects and modality. The covariates serve as keys (K) and values (V), while the GO term representations act as queries (Q). A residual connection retains the original GO embeddings, so they continue to contribute to the learned representation. The resulting representation is then propagated through the subsequent layers of the architecture. Please refer to the API documentation for details. We infer the attention values for all cells and average them per phenotype (here, per cell type).
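The mechanism described above can be sketched in NumPy for a single GO term: its per-cell representation forms the query, the covariate embeddings form the keys and values, and the attention output is added back to the GO representation through a residual connection. This is a minimal illustration of standard scaled dot-product multi-head attention [Vaswani et al., 2017], not NetworkVI's implementation; it assumes the covariates are already embedded in the same dimension as the GO representation, and the projection matrices, dimensions and the name go_covariate_attention are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def go_covariate_attention(h, cov, w_q, w_k, w_v, w_o, n_heads):
    """Multi-head cross-attention of one GO term over covariate embeddings.

    h:   (n_cells, d)        GO term representation per cell -> queries
    cov: (n_cells, n_cov, d) covariate embeddings per cell   -> keys and values
    Returns the residual output (n_cells, d) and weights (n_cells, n_heads, n_cov).
    """
    n_cells, d = h.shape
    d_head = d // n_heads
    q = (h @ w_q).reshape(n_cells, n_heads, d_head)
    k = (cov @ w_k).reshape(n_cells, -1, n_heads, d_head)
    v = (cov @ w_v).reshape(n_cells, -1, n_heads, d_head)
    # Scaled dot-product scores between the GO query and each covariate, per head
    scores = np.einsum("chd,cnhd->chn", q, k) / np.sqrt(d_head)
    attn = softmax(scores, axis=-1)  # weights over covariates sum to 1 per head
    context = np.einsum("chn,cnhd->chd", attn, v).reshape(n_cells, d)
    # Residual connection keeps the original GO representation
    return h + context @ w_o, attn


rng = np.random.default_rng(0)
n_cells, n_cov, d, n_heads = 5, 4, 8, 2
h = rng.normal(size=(n_cells, d))
cov = rng.normal(size=(n_cells, n_cov, d))
w_q, w_k, w_v, w_o = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(4))
out, attn = go_covariate_attention(h, cov, w_q, w_k, w_v, w_o, n_heads)
per_covariate = attn.mean(axis=1)  # average over heads -> (n_cells, n_cov)
print(out.shape, per_covariate.shape)
```

Averaging the weights over heads gives one attention value per cell and covariate, which can then be summarized per phenotype.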
obodag = GODag(os.path.join("../../../resources/", "go-basic.obo"))
../../../resources/go-basic.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms
_, _, _, _, covariate_attention_registry_mean_phenotypes, _, _, _ = vae.calculate_covariate_attention(
    labels_column="cell_type",
    results_dir="tutorial/covariate_attention",
    save_results=False,
    shuffle_set_split=False,
    batch_size=512,
    modality_categorical_covariate_keys=["Site"] + categorical_covariate_keys,
    continuous_covariate_keys=continuous_covariate_keys,
)
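Conceptually, averaging per-cell attention values per phenotype is a group-by mean over the cell_type labels. The sketch below shows that reduction on a toy array; the layout (cells × covariates), the values and the cell type names are assumptions for illustration and do not describe the structure returned by calculate_covariate_attention.

```python
import numpy as np
import pandas as pd

# Toy per-cell attention weights for one GO term: rows are cells, columns are covariates
attention = np.array([
    [0.7, 0.1, 0.1, 0.1],
    [0.5, 0.3, 0.1, 0.1],
    [0.1, 0.1, 0.4, 0.4],
])
labels = np.array(["B cell", "B cell", "NK"])  # hypothetical cell type labels
per_phenotype = (
    pd.DataFrame(attention, columns=["Site", "DonorSmoker", "DonorBMI", "DonorAge"])
    .groupby(labels)
    .mean()
)
print(per_phenotype)
```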
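We then plot the GO term-specific attention values per phenotype as a clustered heatmap, keeping only the GO terms of the gene expression modality that are present in the GO DAG: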
attention_values = []
for phenotype in covariate_attention_registry_mean_phenotypes.keys():
    # Keep only the GO terms of the gene expression modality
    covariate_attention_registry_mean_phenotype_expression_filt = {
        key: value
        for key, value in covariate_attention_registry_mean_phenotypes[phenotype]["expression"].items()
        if "GO:" in key
    }
    # Take the last entry of each GO term's attention values, for terms present in the GO DAG
    attention_values_phenotype = [val[-1] for go_term, val in covariate_attention_registry_mean_phenotype_expression_filt.items() if go_term in obodag.keys()]
    attention_values.append(np.array(attention_values_phenotype))
# GO term names (assumes the same GO terms, in the same order, for every phenotype)
go_terms = [obodag[go_term].name for go_term in covariate_attention_registry_mean_phenotype_expression_filt.keys() if go_term in obodag.keys()]
cell_types = list(covariate_attention_registry_mean_phenotypes.keys())
# Use a DataFrame so the row and column labels are reordered together with the clustered data
attention_df = pd.DataFrame(np.array(attention_values).T, index=go_terms, columns=cell_types)
# clustermap creates its own figure, so the size is passed here instead of via plt.figure()
sns.clustermap(attention_df, figsize=(12, 8))
plt.show()
# Clean up the tutorial results directory (no error if it was not created)
shutil.rmtree("tutorial/covariate_attention", ignore_errors=True)