Coverage for src / autoencodix / utils / feature_importance.py: 0%
259 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import os
2import warnings
3from typing import Dict
5import anndata as ad
6import gseapy as gp
7import matplotlib.pyplot as plt
8import numpy as np
9import pandas as pd
10import scanpy as sc
11import scipy
12import seaborn as sns
13import torch
14import torch.nn as nn
15from captum.attr import (
16 LRP,
17 DeepLiftShap,
18 GradientShap,
19 IntegratedGradients,
20 Lime,
21 LimeBase,
22)
23from IPython.display import HTML, Image, clear_output, display
25from autoencodix.modeling._captum_forward import CaptumForward
26from autoencodix.utils.adata_converter import AnnDataConverter
# Silence the UserWarning whose message starts with the text below
# (presumably emitted by Captum when hooks are registered — TODO confirm).
warnings.filterwarnings(
    "ignore",
    message="Setting forward, backward hooks and attributes on non-linear",
    category=UserWarning,
)

# NOTE(review): this blanket filter suppresses every warning category for the
# whole process once this module is imported, and it makes the targeted filter
# above redundant. Confirm that this is intended.
warnings.filterwarnings("ignore")
class Vanillix_EncoderSingleDim(nn.Module):
    """Expose a single latent dimension of a trained Vanillix model.

    The wrapper reuses the wrapped model's ``_encoder`` and returns only the
    column ``dim`` of the encoder output, so that attribution methods see a
    model with exactly one output per sample.
    """

    def __init__(self, vae_model, dim):
        """Store the encoder, expected feature count and latent index.

        Parameters
        ----------
        vae_model : object
            Trained model providing ``_encoder`` and ``input_dim``.
        dim : int
            Index of the latent dimension to expose.
        """
        super().__init__()
        # Borrow the pieces of the trained model that the forward pass needs.
        self.encoder = vae_model._encoder
        self.input_dim = vae_model.input_dim
        self.dim = dim  # latent dim

    def forward(self, x):
        """Encode ``x`` and return latent column ``self.dim`` as ``(batch, 1)``.

        Raises
        ------
        ValueError
            If ``x.shape[1]`` differs from the expected ``input_dim``.
        """
        n_features = x.shape[1]
        if n_features != self.input_dim:
            raise ValueError(
                f"Expected input with {self.input_dim} features, but got {n_features} features. "
                f"This may indicate missing data modalities or incorrect data preparation."
            )

        total_elements = x.numel()
        assert (
            total_elements % self.input_dim == 0
        ), f"Total elements {total_elements} is not a multiple of input_dim {self.input_dim}"

        # Flatten everything after the batch axis before encoding.
        flattened = x.view(x.size(0), -1)
        latent = self.encoder(flattened)
        # Keep a trailing axis of size 1 so the output is (batch, 1).
        return latent[:, self.dim].unsqueeze(1)
def make_DeepLiftShap_Vanillix_dim(model, inputs, baselines, latent_dimension):
    """Run Captum ``DeepLiftShap`` for one latent dimension of a Vanillix model.

    Parameters
    ----------
    model : object
        Trained model handed to ``Vanillix_EncoderSingleDim``.
    inputs : torch.Tensor
        Samples to attribute.
    baselines : torch.Tensor
        Baseline samples passed to ``DeepLiftShap.attribute``.
    latent_dimension : int
        Latent dimension whose value is attributed to the input features.

    Returns
    -------
    tuple
        ``(avg_abs_attributions, convergence)``: the absolute attributions
        averaged over axis 0, and the convergence delta returned by Captum.
    """
    single_dim_encoder = Vanillix_EncoderSingleDim(model, dim=latent_dimension)
    explainer = DeepLiftShap(single_dim_encoder)
    attributions, convergence = explainer.attribute(
        inputs=inputs,
        baselines=baselines,
        return_convergence_delta=True,
    )
    return attributions.abs().mean(dim=0), convergence
def make_feature_importance_Vanillix(
    input_adata,
    van,
    method="DeepLiftShap",
    n_subset=100,
    seed_int=12,
    baseline_type="mean",
    baseline_group="all",
    obs_col=None,
):
    """
    Computes attributions for a trained Vanillix model.

    Parameters
    ----------
    input_adata : anndata.AnnData
        Data to explain. ``X`` (dense or scipy sparse) provides the samples,
        ``var_names`` the feature names and ``obs`` the grouping columns.

    van : object
        Object containing the trained Vanillix model (``van.result.model``)
        and its latent space (``van.result.adata_latent``).

    method : str, {'DeepLiftShap', 'IntegratedGradients'}, default='DeepLiftShap'
        post-hoc feature importance assessment method
        - 'DeepLiftShap': method from pytorch Captum library, approximates SHAP values using Deeplift
        - 'IntegratedGradients': method from pytorch Captum library, attribution via Integrated gradients

    n_subset : int, default=100
        Subset of randomly selected cells to compute attributions on.
        Capped at the number of available cells.

    seed_int : int, default=12
        Seed for reproducible random sampling (NumPy & PyTorch).

    baseline_type : str, {'mean', 'random_sample'}, default='mean'
        How to generate the baseline:
        - 'mean': average expression of the group
        - 'random_sample': a randomly picked cell from the group

    baseline_group : str, default='all'
        Which group of cells to use as the baseline.
        Use 'all' or a specific group label from `adata.obs[obs_col]`.

    obs_col : str or None, default=None
        Column in `adata.obs` that defines groups. Required if `baseline_group` is not 'all'.

    Returns
    -------
    df_attributions : pandas.DataFrame
        Gene-level attribution scores per latent dimension (rows are
        ``input_adata.var_names``, columns are ``latent_dimension_<i>``).

    Raises
    ------
    ValueError
        If ``obs_col`` is missing for a group-specific baseline, if ``method``
        or ``baseline_type`` is not a supported value, or if no cell belongs
        to ``baseline_group``.
    """
    if baseline_group != "all" and obs_col is None:
        raise ValueError(
            "obs_col must be provided when using a group-specific baseline."
        )
    # Validate the string options up front; otherwise an unknown value only
    # shows up later as a NameError on an unassigned local.
    valid_methods = ("DeepLiftShap", "IntegratedGradients")
    if method not in valid_methods:
        raise ValueError(
            f"Invalid method '{method}'. Must be one of: {', '.join(valid_methods)}."
        )
    valid_baseline_types = ("mean", "random_sample")
    if baseline_type not in valid_baseline_types:
        raise ValueError(
            f"Invalid baseline_type '{baseline_type}'. "
            f"Must be one of: {', '.join(valid_baseline_types)}."
        )

    def _to_tensor(adata):
        # Densify sparse matrices; promote integer data because torch cannot
        # take the mean of an integer tensor.
        matrix = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
        tensor = torch.tensor(matrix)
        return tensor if tensor.is_floating_point() else tensor.float()

    np.random.seed(seed_int)
    torch.manual_seed(seed_int)
    model = van.result.model

    inputs = _to_tensor(input_adata)
    n_cells = inputs.shape[0]

    # Cells from which the baseline is derived.
    if baseline_group == "all":
        reference = inputs
    else:
        group_mask = input_adata.obs[obs_col] == baseline_group
        reference = _to_tensor(input_adata[group_mask])
        if reference.shape[0] == 0:
            raise ValueError(
                f"No cells found with {obs_col} == '{baseline_group}'; "
                "cannot build a baseline."
            )

    if baseline_type == "mean":
        baseline_row = reference.mean(dim=0)  # per-feature mean
    else:
        picked = torch.randint(0, reference.size(0), (1,)).item()
        baseline_row = reference[picked]

    # Never request more cells than exist (np.random.choice would raise).
    n_used = min(n_subset, n_cells)
    subset_indices = np.random.choice(n_cells, size=n_used, replace=False)
    subset_inputs = inputs[subset_indices].float()
    # Every sampled cell is compared against the same baseline row.
    subset_baselines = baseline_row.float().unsqueeze(0).repeat(n_used, 1)

    if method == "DeepLiftShap":
        attribute_dim = make_DeepLiftShap_Vanillix_dim
    else:
        attribute_dim = make_IntegratedGradients_Vanillix_dim

    latent_dimensions = list(range(van.result.adata_latent.shape[1]))
    all_attr = []
    for latent_dim in latent_dimensions:
        avg_abs_attributions, _ = attribute_dim(
            model=model,
            inputs=subset_inputs,
            baselines=subset_baselines,
            latent_dimension=latent_dim,
        )
        all_attr.append(avg_abs_attributions.detach().cpu())

    # Stack to (latent_dims, features), then transpose to features x dims.
    attr_matrix = torch.stack(all_attr).T.numpy()

    df_attributions = pd.DataFrame(
        attr_matrix,
        index=list(input_adata.var_names),
        columns=[f"latent_dimension_{i}" for i in latent_dimensions],
    )
    return df_attributions
def get_top_kgenes_per_latent_dimension(df_attributions, latent_dim=0, topk=10):
    """
    Return the ``topk`` genes with the largest attribution for one latent dimension.

    Parameters
    ----------
    df_attributions : pandas.DataFrame
        Genes as rows, latent dimensions as columns named
        ``latent_dimension_<i>``.

    latent_dim : int, default=0
        Index of the latent dimension whose column is ranked.

    topk : int, default=10
        How many genes to return.

    Returns
    -------
    topk_genes : list of str
        Gene names ordered from highest to lowest attribution.
    """
    column_name = f"latent_dimension_{latent_dim}"
    ranked = df_attributions[column_name].nlargest(topk)
    return ranked.index.tolist()
def plot_union_top_genes_heatmap(df, top_n=50, cmap="viridis", save=None):
    """
    Plots a heatmap for all latent dimensions, showing the union of the top N genes per dimension.

    Parameters:
    - df (pd.DataFrame): DataFrame with feature attributions.
      Rows = genes/features, Columns = latent dimensions.
    - top_n (int): Number of top genes to select per latent dimension (default: 50).
    - cmap (str): Colormap for the heatmap (default: 'viridis').
    - save (str or None): path for saving the figure; nothing is written when None.

    Raises:
    - ValueError: if ``df`` has no columns.
    """
    if df.shape[1] == 0:
        # set.union(*[]) would otherwise fail with an opaque TypeError.
        raise ValueError("df must contain at least one latent dimension column.")

    # Union of the top N genes of every latent dimension, in sorted order.
    union_genes = sorted(
        set().union(*(df[col].nlargest(top_n).index for col in df.columns))
    )

    # Subset the dataframe to only these genes
    data = df.loc[union_genes]

    plt.figure(figsize=(len(df.columns) * 0.1 + 3, max(6, 0.14 * len(union_genes))))
    ax = sns.heatmap(
        data,
        cmap=cmap,
        annot=False,
        linewidths=0.5,
        linecolor="gray",
        cbar_kws={"label": "Attribution Score"},
    )

    # Pin one tick per cell (centered) before assigning labels. The heatmap
    # may have thinned the automatic ticks, and a label list whose length
    # differs from the tick count makes matplotlib raise.
    ax.set_xticks([i + 0.5 for i in range(len(data.columns))])
    ax.set_xticklabels(data.columns, rotation=90, ha="right", fontsize=9)

    ax.set_yticks([i + 0.5 for i in range(len(data.index))])
    ax.set_yticklabels(data.index, rotation=0, fontsize=8)

    plt.title(f"attribution of top {top_n} genes per latent dimension", fontsize=10)

    plt.tight_layout()
    if save is not None:
        plt.savefig(save, bbox_inches="tight")
    plt.show()
def get_top_genes_per_dimension(df, top_n=100):
    """
    Collect the ``top_n`` highest-scoring genes of every column in ``df``.

    Returns a dict: {dimension_name: [genes ordered from highest to lowest score]}
    """
    return {dim: df[dim].nlargest(top_n).index.tolist() for dim in df.columns}
def run_go_enrichment(
    top_genes_dict,
    n_top_pathways=10,
    gene_set_library="GO_Biological_Process_2021",
    organism="Human",
):
    """
    Run Enrichr (via gseapy) once per latent dimension.

    ``top_genes_dict`` maps a dimension name to its gene list. The result maps
    the same names to the ``n_top_pathways`` rows with the smallest
    "Adjusted P-value" of the corresponding enrichment table.
    """

    def _best_terms(genes):
        enrichment = gp.enrichr(
            gene_list=genes,
            gene_sets=gene_set_library,
            organism=organism,
            outdir=None,  # no file output
            cutoff=0.05,
        )
        ranked = enrichment.results.sort_values("Adjusted P-value")
        return ranked.head(n_top_pathways)

    return {dim: _best_terms(genes) for dim, genes in top_genes_dict.items()}
def plot_GO_log_odds_all(PE_dict, top_n=10, base_save_identifier=None):
    """
    Plots log odds ratio of top enriched GO terms for each latent dimension.

    Parameters:
    - PE_dict (dict): Dict of DataFrames, keyed by latent dimension name.
      Each DataFrame needs the columns "Term" and "Odds Ratio"; entries that
      are empty or lack one of them are skipped with a printed message.
    - top_n (int): Number of top GO terms to plot per dimension.
    - base_save_identifier (str or None): If given, each figure is saved as
      ``<base_save_identifier>_<dim>.pdf``; nothing is written when None.

    The DataFrames in ``PE_dict`` are not modified.
    """
    required_columns = ("Term", "Odds Ratio")
    for dim, df in PE_dict.items():
        if df.empty or any(col not in df.columns for col in required_columns):
            print(f"Skipping {dim}: empty or missing required columns.")
            continue

        # Work on a copy so the helper column never leaks into the caller's data.
        plot_df = df.copy()
        plot_df["log_odds_ratio"] = np.log(plot_df["Odds Ratio"])

        # Ascending sort + tail keeps the top_n largest, biggest bar on top.
        plot_df = plot_df.sort_values(by="log_odds_ratio", ascending=True).tail(top_n)

        fig = plt.figure(figsize=(10, 0.5 * top_n))  # wider and taller
        bars = plt.barh(
            plot_df["Term"],
            plot_df["log_odds_ratio"],
            color="#D4AF3A",
            height=0.6,  # makes bars thicker
        )

        # Value labels, placed slightly inside the end of each bar.
        for bar in bars:
            plt.text(
                bar.get_width() - 1.05,
                bar.get_y() + bar.get_height() / 2,
                f"{bar.get_width():.2f}",
                va="center",
                ha="left",
                fontsize=8,
                color="#C40308",
            )

        plt.xlabel("log(odds Ratio)", fontsize=10)
        plt.title(f"{dim} — top {top_n} Enriched GO terms", fontsize=12)
        plt.xticks(fontsize=8)
        plt.yticks(fontsize=9)  # smaller tick labels
        plt.tight_layout()
        if base_save_identifier is not None:
            plt.savefig(
                base_save_identifier + "_" + str(dim) + ".pdf", bbox_inches="tight"
            )
        plt.show()
        # One figure is created per dimension; release it so they do not pile up.
        plt.close(fig)
def do_miraculix_visualization():
    """Display a heading and the Miraculix GIF, then call ``clear_output(wait=True)``.

    The GIF is loaded from ``miraculix_zaubertrank.gif`` given as a relative
    path.
    """
    heading = HTML('<p style="font-size:20px;">Calculate feature importance</p>')
    display(heading)
    animation = Image(filename="miraculix_zaubertrank.gif")
    display(animation)
    clear_output(wait=True)
def do_feature_importance_Vanillix(
    van,
    method="DeepLiftShap",
    baseline_type="mean",
    baseline_group="all",
    obs_col=None,
    n_subset=100,
    seed_int=12,
    do_visualizations=True,
    top_n_genes_heatmap=50,
    top_n_foreground_pathways=30,
    gene_set_library=None,
    organism="Human",
    n_top_pathways=10,
    save_out_path=None,
    do_miraculix_vis=False,
):
    """
    Computes feature attributions for a trained Vanillix model and visualizes them.

    Parameters
    ----------
    van : object
        Object containing the trained Vanillix model and its datasets
        (``van.result.datasets``).

    method : str, {'DeepLiftShap', 'IntegratedGradients'}, default='DeepLiftShap'
        post-hoc feature importance assessment method
        - 'DeepLiftShap': method from pytorch Captum library, approximates SHAP values using Deeplift
        - 'IntegratedGradients': method from pytorch Captum library, attribution via Integrated gradients

    baseline_type : str, {'mean', 'random_sample'}, default='mean'
        How to generate the baseline:
        - 'mean': average expression of the group
        - 'random_sample': a randomly picked cell from the group

    baseline_group : str, default='all'
        Which group of cells to use as the baseline.
        Use 'all' or a specific group label from `adata.obs[obs_col]`.

    obs_col : str or None, default=None
        Column in `adata.obs` that defines groups. Required if `baseline_group` is not 'all'.

    n_subset : int, default=100
        Subset of randomly selected cells to compute attributions on.

    seed_int : int, default=12
        Seed for reproducible random sampling (NumPy & PyTorch).

    do_visualizations : bool, default=True
        Whether to draw the heatmap and, if a gene set library is given,
        the enrichment plots.

    top_n_genes_heatmap : int, default=50
        Number of top genes (with highest attribution scores) per latent dimension to visualize in the heatmap.

    top_n_foreground_pathways : int, default=30
        Number of top genes to include as "foreground" in enrichment analysis.

    gene_set_library : str or None, default=None
        The gene set library to use for enrichment analysis (from Enrichr), e.g. "GO_Biological_Process_2021".
        Enrichment is skipped when None.

    organism : str, default='Human'
        The organism relevant to the gene sets used in enrichment.

    n_top_pathways : int, default=10
        Number of top enriched pathways to visualize in the results per latent dimension.

    save_out_path : str or None, default=None
        Directory in which the attribution table and figures are saved.
        It is created if missing. If None, results will not be saved to disk.

    do_miraculix_vis : bool, default=False
        Show Miraculix-style GIF.

    Returns
    -------
    df_attributions : pandas.DataFrame
        Gene-level attribution scores per latent dimension.

    Raises
    ------
    ValueError
        If ``method`` is not supported or no dataset is available.
    """
    feature_importance_methods = {"DeepLiftShap", "IntegratedGradients"}
    if method not in feature_importance_methods:
        # Sorted so the message does not depend on set iteration order.
        raise ValueError(
            f"Invalid method '{method}'. Must be one of: {', '.join(sorted(feature_importance_methods))}."
        )

    if do_miraculix_vis:
        do_miraculix_visualization()

    input_data: Dict[str, ad.AnnData] = AnnDataConverter.dataset_to_adata(
        datasetcontainer=van.result.datasets
    )

    if save_out_path is not None:
        # Create the output directory before the expensive attribution step.
        os.makedirs(save_out_path, exist_ok=True)

    # NOTE(review): attributions are computed for every dataset, but only the
    # result of the last one is saved, plotted and returned. This matches the
    # previous behavior; confirm that it is intended.
    df_attributions = None
    for adata in input_data.values():
        df_attributions = make_feature_importance_Vanillix(
            input_adata=adata,
            van=van,
            method=method,
            n_subset=n_subset,
            seed_int=seed_int,
            baseline_type=baseline_type,
            baseline_group=baseline_group,
            obs_col=obs_col,
        )
    if df_attributions is None:
        raise ValueError(
            "No datasets available in van.result.datasets; "
            "cannot compute feature importance."
        )

    if save_out_path is not None:
        df_attributions.to_csv(os.path.join(save_out_path, "df_attributions.csv"))

    if do_visualizations:
        heatmap_path = (
            None
            if save_out_path is None
            else os.path.join(save_out_path, "top_attributions.pdf")
        )
        plot_union_top_genes_heatmap(
            df=df_attributions,
            top_n=top_n_genes_heatmap,
            cmap="plasma",
            save=heatmap_path,
        )

        # Enrichment needs a gene set library; with the default (None) it is
        # skipped, as in do_feature_importance_Varix.
        if gene_set_library is not None:
            dict_top_genes = get_top_genes_per_dimension(
                df=df_attributions, top_n=top_n_foreground_pathways
            )
            PE_dict = run_go_enrichment(
                top_genes_dict=dict_top_genes,
                n_top_pathways=n_top_pathways,
                gene_set_library=gene_set_library,
                organism=organism,
            )
            base_save_identifier = (
                None
                if save_out_path is None
                else os.path.join(save_out_path, "GO_pathways")
            )
            plot_GO_log_odds_all(
                PE_dict=PE_dict,
                top_n=n_top_pathways,
                base_save_identifier=base_save_identifier,
            )
    return df_attributions
class Varix_EncoderSingleDim(nn.Module):
    """Expose a single latent dimension of a trained Varix model.

    The wrapper chains the wrapped model's ``_encoder``, ``_mu`` / ``_logvar``
    heads and ``reparameterize`` function, and returns only column ``dim`` of
    the resulting latent sample.
    """

    def __init__(self, vae_model, dim):
        """Store the model components and the latent index to expose.

        Parameters
        ----------
        vae_model : object
            Trained model providing ``_encoder``, ``_mu``, ``_logvar``,
            ``reparameterize`` and ``input_dim``.
        dim : int
            Index of the latent dimension to expose.
        """
        super().__init__()
        # Borrow the pieces of the trained model that the forward pass needs.
        self.encoder = vae_model._encoder
        self.mu = vae_model._mu
        self.logvar = vae_model._logvar
        self.reparameterize = vae_model.reparameterize
        self.input_dim = vae_model.input_dim
        self.dim = dim  # latent dim

    def forward(self, x):
        """Encode ``x`` and return latent column ``self.dim`` as ``(batch, 1)``.

        Raises
        ------
        ValueError
            If ``x.shape[1]`` differs from the expected ``input_dim``.
        """
        n_features = x.shape[1]
        if n_features != self.input_dim:
            raise ValueError(
                f"Expected input with {self.input_dim} features, but got {n_features} features. "
                f"This may indicate missing data modalities or incorrect data preparation."
            )

        total_elements = x.numel()
        assert (
            total_elements % self.input_dim == 0
        ), f"Total elements {total_elements} is not a multiple of input_dim {self.input_dim}"

        # Flatten everything after the batch axis before encoding.
        hidden = self.encoder(x.view(x.size(0), -1))
        mean = self.mu(hidden)
        log_variance = self.logvar(hidden)
        z = self.reparameterize(mean, log_variance)
        # Keep a trailing axis of size 1 so the output is (batch, 1).
        return z[:, self.dim].unsqueeze(1)
def make_DeepLiftShap_Varix_dim(model, inputs, baselines, latent_dimension):
    """Run Captum ``DeepLiftShap`` for one latent dimension of a Varix model.

    Parameters
    ----------
    model : object
        Trained model handed to ``Varix_EncoderSingleDim``.
    inputs : torch.Tensor
        Samples to attribute.
    baselines : torch.Tensor
        Baseline samples passed to ``DeepLiftShap.attribute``.
    latent_dimension : int
        Latent dimension whose value is attributed to the input features.

    Returns
    -------
    tuple
        ``(avg_abs_attributions, convergence)``: the absolute attributions
        averaged over axis 0, and the convergence delta returned by Captum.
    """
    single_dim_encoder = Varix_EncoderSingleDim(model, dim=latent_dimension)
    explainer = DeepLiftShap(single_dim_encoder)
    attributions, convergence = explainer.attribute(
        inputs=inputs,
        baselines=baselines,
        return_convergence_delta=True,
    )
    return attributions.abs().mean(dim=0), convergence
def make_feature_importance_Varix(
    van,
    method="DeepLiftShap",
    n_subset=100,
    seed_int=12,
    baseline_type="mean",
    baseline_group="all",
    obs_col=None,
):
    """
    Computes attributions for a trained Varix model.

    The input data is read from
    ``van.raw_user_data["multi_sc"]["multi_sc"].mod["user-data"]``.

    Parameters
    ----------
    van : object
        Object containing the trained Varix model (``van.result.model``), its
        latent space (``van.result.adata_latent``) and the input data as an
        AnnData object.

    method : str, {'DeepLiftShap', 'IntegratedGradients'}, default='DeepLiftShap'
        post-hoc feature importance assessment method
        - 'DeepLiftShap': method from pytorch Captum library, approximates SHAP values using Deeplift
        - 'IntegratedGradients': method from pytorch Captum library, attribution via Integrated gradients

    n_subset : int, default=100
        Subset of randomly selected cells to compute attributions on.
        Capped at the number of available cells.

    seed_int : int, default=12
        Seed for reproducible random sampling (NumPy & PyTorch).

    baseline_type : str, {'mean', 'random_sample'}, default='mean'
        How to generate the baseline:
        - 'mean': average expression of the group
        - 'random_sample': a randomly picked cell from the group

    baseline_group : str, default='all'
        Which group of cells to use as the baseline.
        Use 'all' or a specific group label from `adata.obs[obs_col]`.

    obs_col : str or None, default=None
        Column in `adata.obs` that defines groups. Required if `baseline_group` is not 'all'.

    Returns
    -------
    df_attributions : pandas.DataFrame
        Gene-level attribution scores per latent dimension (rows are the
        AnnData ``var_names``, columns are ``latent_dimension_<i>``).

    Raises
    ------
    ValueError
        If ``obs_col`` is missing for a group-specific baseline, if ``method``
        or ``baseline_type`` is not a supported value, or if no cell belongs
        to ``baseline_group``.
    """
    if baseline_group != "all" and obs_col is None:
        raise ValueError(
            "obs_col must be provided when using a group-specific baseline."
        )
    # Validate the string options up front; otherwise an unknown value only
    # shows up later as a NameError on an unassigned local.
    valid_methods = ("DeepLiftShap", "IntegratedGradients")
    if method not in valid_methods:
        raise ValueError(
            f"Invalid method '{method}'. Must be one of: {', '.join(valid_methods)}."
        )
    valid_baseline_types = ("mean", "random_sample")
    if baseline_type not in valid_baseline_types:
        raise ValueError(
            f"Invalid baseline_type '{baseline_type}'. "
            f"Must be one of: {', '.join(valid_baseline_types)}."
        )

    def _to_tensor(adata):
        # Densify sparse matrices; promote integer data because torch cannot
        # take the mean of an integer tensor.
        matrix = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
        tensor = torch.tensor(matrix)
        return tensor if tensor.is_floating_point() else tensor.float()

    np.random.seed(seed_int)
    torch.manual_seed(seed_int)
    model = van.result.model

    input_adata = van.raw_user_data["multi_sc"]["multi_sc"].mod["user-data"]
    inputs = _to_tensor(input_adata)
    n_cells = inputs.shape[0]

    # Cells from which the baseline is derived.
    if baseline_group == "all":
        reference = inputs
    else:
        group_mask = input_adata.obs[obs_col] == baseline_group
        reference = _to_tensor(input_adata[group_mask])
        if reference.shape[0] == 0:
            raise ValueError(
                f"No cells found with {obs_col} == '{baseline_group}'; "
                "cannot build a baseline."
            )

    if baseline_type == "mean":
        baseline_row = reference.mean(dim=0)  # per-feature mean
    else:
        picked = torch.randint(0, reference.size(0), (1,)).item()
        baseline_row = reference[picked]

    # Never request more cells than exist (np.random.choice would raise).
    n_used = min(n_subset, n_cells)
    subset_indices = np.random.choice(n_cells, size=n_used, replace=False)
    subset_inputs = inputs[subset_indices].float()
    # Every sampled cell is compared against the same baseline row.
    subset_baselines = baseline_row.float().unsqueeze(0).repeat(n_used, 1)

    if method == "DeepLiftShap":
        attribute_dim = make_DeepLiftShap_Varix_dim
    else:
        attribute_dim = make_IntegratedGradients_Varix_dim

    latent_dimensions = list(range(van.result.adata_latent.shape[1]))
    all_attr = []
    for latent_dim in latent_dimensions:
        avg_abs_attributions, _ = attribute_dim(
            model=model,
            inputs=subset_inputs,
            baselines=subset_baselines,
            latent_dimension=latent_dim,
        )
        all_attr.append(avg_abs_attributions.detach().cpu())

    # Stack to (latent_dims, features), then transpose to features x dims.
    attr_matrix = torch.stack(all_attr).T.numpy()

    df_attributions = pd.DataFrame(
        attr_matrix,
        index=list(input_adata.var_names),
        columns=[f"latent_dimension_{i}" for i in latent_dimensions],
    )
    return df_attributions
def do_feature_importance_Varix(
    van,
    method="DeepLiftShap",
    baseline_type="mean",
    baseline_group="all",
    obs_col=None,
    n_subset=100,
    seed_int=12,
    do_visualizations=True,
    top_n_genes_heatmap=50,
    top_n_foreground_pathways=30,
    gene_set_library=None,
    organism="Human",
    n_top_pathways=10,
    save_out_path=None,
    do_miraculix_vis=False,
):
    """
    Computes feature attributions for a trained Varix model and visualizes them.

    Parameters
    ----------
    van : object
        Object containing the trained Varix model and input data as an AnnData object.

    method : str, {'DeepLiftShap', 'IntegratedGradients'}, default='DeepLiftShap'
        post-hoc feature importance assessment method
        - 'DeepLiftShap': method from pytorch Captum library, approximates SHAP values using Deeplift
        - 'IntegratedGradients': method from pytorch Captum library, attribution via Integrated gradients

    baseline_type : str, {'mean', 'random_sample'}, default='mean'
        How to generate the baseline:
        - 'mean': average expression of the group
        - 'random_sample': a randomly picked cell from the group

    baseline_group : str, default='all'
        Which group of cells to use as the baseline.
        Use 'all' or a specific group label from `adata.obs[obs_col]`.

    obs_col : str or None, default=None
        Column in `adata.obs` that defines groups. Required if `baseline_group` is not 'all'.

    n_subset : int, default=100
        Subset of randomly selected cells to compute attributions on.

    seed_int : int, default=12
        Seed for reproducible random sampling (NumPy & PyTorch).

    do_visualizations : bool, default=True
        Whether to draw the heatmap and, if a gene set library is given,
        the enrichment plots.

    top_n_genes_heatmap : int, default=50
        Number of top genes (with highest attribution scores) per latent dimension to visualize in the heatmap.

    top_n_foreground_pathways : int, default=30
        Number of top genes to include as "foreground" in enrichment analysis.

    gene_set_library : str or None, default=None
        The gene set library to use for enrichment analysis (from Enrichr), e.g. "GO_Biological_Process_2021".
        Enrichment is skipped when None.

    organism : str, default='Human'
        The organism relevant to the gene sets used in enrichment.

    n_top_pathways : int, default=10
        Number of top enriched pathways to visualize in the results per latent dimension.

    save_out_path : str or None, default=None
        Directory in which the attribution table and figures are saved.
        It is created if missing. If None, results will not be saved to disk.

    do_miraculix_vis : bool, default=False
        Show Miraculix-style GIF.

    Returns
    -------
    df_attributions : pandas.DataFrame
        Gene-level attribution scores per latent dimension.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported methods.
    """
    feature_importance_methods = {"DeepLiftShap", "IntegratedGradients"}
    if method not in feature_importance_methods:
        # Sorted so the message does not depend on set iteration order.
        raise ValueError(
            f"Invalid method '{method}'. Must be one of: {', '.join(sorted(feature_importance_methods))}."
        )

    if do_miraculix_vis:
        do_miraculix_visualization()

    if save_out_path is not None:
        # Create the output directory before the expensive attribution step.
        os.makedirs(save_out_path, exist_ok=True)

    df_attributions = make_feature_importance_Varix(
        van,
        method=method,
        n_subset=n_subset,
        seed_int=seed_int,
        baseline_type=baseline_type,
        baseline_group=baseline_group,
        obs_col=obs_col,
    )
    if save_out_path is not None:
        df_attributions.to_csv(os.path.join(save_out_path, "df_attributions.csv"))

    if do_visualizations:
        heatmap_path = (
            None
            if save_out_path is None
            else os.path.join(save_out_path, "top_attributions.pdf")
        )
        plot_union_top_genes_heatmap(
            df=df_attributions,
            top_n=top_n_genes_heatmap,
            cmap="plasma",
            save=heatmap_path,
        )

        # Enrichment needs a gene set library; skipped with the default (None).
        if gene_set_library is not None:
            dict_top_genes = get_top_genes_per_dimension(
                df=df_attributions, top_n=top_n_foreground_pathways
            )
            PE_dict = run_go_enrichment(
                top_genes_dict=dict_top_genes,
                n_top_pathways=n_top_pathways,
                gene_set_library=gene_set_library,
                organism=organism,
            )
            base_save_identifier = (
                None
                if save_out_path is None
                else os.path.join(save_out_path, "GO_pathways")
            )
            plot_GO_log_odds_all(
                PE_dict=PE_dict,
                top_n=n_top_pathways,
                base_save_identifier=base_save_identifier,
            )
    return df_attributions
def make_IntegratedGradients_Varix_dim(model, inputs, baselines, latent_dimension):
    """Run Captum ``IntegratedGradients`` for one latent dimension of a Varix model.

    Parameters
    ----------
    model : object
        Trained model handed to ``Varix_EncoderSingleDim``.
    inputs : torch.Tensor
        Samples to attribute.
    baselines : torch.Tensor
        Baseline samples passed to ``IntegratedGradients.attribute``.
    latent_dimension : int
        Latent dimension whose value is attributed to the input features.

    Returns
    -------
    tuple
        ``(avg_abs_attributions, convergence)``: the absolute attributions
        averaged over axis 0, and the convergence delta returned by Captum.
    """
    single_dim_encoder = Varix_EncoderSingleDim(model, dim=latent_dimension)
    explainer = IntegratedGradients(single_dim_encoder)
    attributions, convergence = explainer.attribute(
        inputs=inputs,
        baselines=baselines,
        return_convergence_delta=True,
    )
    return attributions.abs().mean(dim=0), convergence
def make_IntegratedGradients_Vanillix_dim(model, inputs, baselines, latent_dimension):
    """Run Captum ``IntegratedGradients`` for one latent dimension of a Vanillix model.

    Parameters
    ----------
    model : object
        Trained model handed to ``Vanillix_EncoderSingleDim``.
    inputs : torch.Tensor
        Samples to attribute.
    baselines : torch.Tensor
        Baseline samples passed to ``IntegratedGradients.attribute``.
    latent_dimension : int
        Latent dimension whose value is attributed to the input features.

    Returns
    -------
    tuple
        ``(avg_abs_attributions, convergence)``: the absolute attributions
        averaged over axis 0, and the convergence delta returned by Captum.
    """
    single_dim_encoder = Vanillix_EncoderSingleDim(model, dim=latent_dimension)
    explainer = IntegratedGradients(single_dim_encoder)
    attributions, convergence = explainer.attribute(
        inputs=inputs,
        baselines=baselines,
        return_convergence_delta=True,
    )
    return attributions.abs().mean(dim=0), convergence