Coverage for src / autoencodix / utils / feature_importance.py: 0%

259 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import os 

2import warnings 

3from typing import Dict 

4 

5import anndata as ad 

6import gseapy as gp 

7import matplotlib.pyplot as plt 

8import numpy as np 

9import pandas as pd 

10import scanpy as sc 

11import scipy 

12import seaborn as sns 

13import torch 

14import torch.nn as nn 

15from captum.attr import ( 

16 LRP, 

17 DeepLiftShap, 

18 GradientShap, 

19 IntegratedGradients, 

20 Lime, 

21 LimeBase, 

22) 

23from IPython.display import HTML, Image, clear_output, display 

24 

25from autoencodix.modeling._captum_forward import CaptumForward 

26from autoencodix.utils.adata_converter import AnnDataConverter 

27 

# Silence the hook-installation UserWarning whose text matches the message
# below (presumably raised by Captum's DeepLift-based explainers when they
# attach hooks to non-linear modules).
warnings.filterwarnings(
    action="ignore",
    message="Setting forward, backward hooks and attributes on non-linear",
    category=UserWarning,
)

# NOTE(review): this installs a blanket "ignore" filter for every warning
# category, process-wide, as a side effect of importing this module. It makes
# the targeted filter above redundant and also hides warnings raised by
# unrelated code; consider scoping it with ``warnings.catch_warnings()``.
warnings.simplefilter("ignore")

35 

36 

class Vanillix_EncoderSingleDim(nn.Module):
    """Expose one latent coordinate of a Vanillix encoder as the module output.

    The wrapped encoder is run on the full input and only column ``dim`` of its
    output is returned, so attribution methods can explain one latent
    dimension at a time.

    Parameters
    ----------
    vae_model : nn.Module
        Model providing ``_encoder`` (a module) and ``input_dim`` (int).
    dim : int
        Index of the latent dimension to expose.
    """

    def __init__(self, vae_model, dim):
        super().__init__()
        # Reference (not a copy) of the trained model's encoder module.
        self.encoder = vae_model._encoder
        self.input_dim = vae_model.input_dim
        self.dim = dim

    def forward(self, x):
        """Encode ``x`` and return latent column ``self.dim`` with an extra axis 1.

        For a 2-D input of shape ``(batch, input_dim)`` the result has shape
        ``(batch, 1)``.

        Raises
        ------
        ValueError
            If ``x.shape[1]`` differs from ``self.input_dim``.
        """
        if x.shape[1] != self.input_dim:
            raise ValueError(
                f"Expected input with {self.input_dim} features, but got {x.shape[1]} features. "
                f"This may indicate missing data modalities or incorrect data preparation."
            )

        n_elements = x.numel()
        # Always true once x.shape[1] == input_dim; kept as a sanity check.
        assert (
            n_elements % self.input_dim == 0
        ), f"Total elements {n_elements} is not a multiple of input_dim {self.input_dim}"

        flat = x.view(x.size(0), -1)
        code = self.encoder(flat)
        return code[:, self.dim].unsqueeze(1)

63 

64 

def make_DeepLiftShap_Vanillix_dim(model, inputs, baselines, latent_dimension):
    """
    DeepLiftShap attribution of a single Vanillix latent dimension.

    Parameters
    ----------
    model : nn.Module
        Trained Vanillix model; wrapped by ``Vanillix_EncoderSingleDim``.
    inputs : torch.Tensor
        Cells to explain (rows are cells).
    baselines : torch.Tensor
        Reference inputs handed to Captum's ``DeepLiftShap.attribute``.
    latent_dimension : int
        Latent coordinate to explain.

    Returns
    -------
    avg_abs_attributions : torch.Tensor
        Absolute attributions averaged over axis 0 (the cells).
    convergence : torch.Tensor
        Convergence delta reported by Captum.
    """
    explainer = DeepLiftShap(
        Vanillix_EncoderSingleDim(model, dim=latent_dimension)
    )
    attributions, delta = explainer.attribute(
        inputs=inputs,
        baselines=baselines,
        return_convergence_delta=True,
    )
    return attributions.abs().mean(dim=0), delta

73 

74 

def make_feature_importance_Vanillix(
    input_adata,
    van,
    method="DeepLiftShap",
    n_subset=100,
    seed_int=12,
    baseline_type="mean",
    baseline_group="all",
    obs_col=None,
):
    """
    Computes attributions for a trained Vanillix model.

    For every latent dimension, the chosen Captum method attributes that single
    encoder output back to the input features. The mean absolute attribution
    over a random subset of cells is reported.

    Parameters
    ----------
    input_adata : anndata.AnnData
        Cells x features data fed to the encoder. Its number of features must
        equal the model's ``input_dim``.

    van : object
        Object holding the trained Vanillix model (``van.result.model``) and
        the latent AnnData (``van.result.adata_latent``), whose second axis
        gives the number of latent dimensions.

    method : str, {'DeepLiftShap', 'IntegratedGradients'}, default='DeepLiftShap'
        Post-hoc feature importance method:
        - 'DeepLiftShap': method from pytorch Captum library, approximates SHAP values using Deeplift
        - 'IntegratedGradients': method from pytorch Captum library, attribution via Integrated gradients

    n_subset : int, default=100
        Number of randomly selected cells to compute attributions on. It is
        clipped to the number of available cells.

    seed_int : int, default=12
        Seed for reproducible random sampling (NumPy & PyTorch).

    baseline_type : str, {'mean', 'random_sample'}, default='mean'
        How to generate the baseline:
        - 'mean': average expression of the group
        - 'random_sample': a randomly picked cell from the group

    baseline_group : str, default='all'
        Which group of cells to use as the baseline.
        Use 'all' or a specific group label from `adata.obs[obs_col]`.

    obs_col : str or None, default=None
        Column in `adata.obs` that defines groups. Required if `baseline_group` is not 'all'.

    Returns
    -------
    df_attributions : pandas.DataFrame
        Gene-level attribution scores: rows are ``input_adata.var_names``,
        columns are ``latent_dimension_<i>``.

    Raises
    ------
    ValueError
        If ``method`` or ``baseline_type`` is unknown, if ``obs_col`` is missing
        for a group-specific baseline, or if the baseline group has no cells.
    """
    if method not in ("DeepLiftShap", "IntegratedGradients"):
        raise ValueError(
            f"Invalid method '{method}'. Must be one of: DeepLiftShap, IntegratedGradients."
        )
    if baseline_type not in ("mean", "random_sample"):
        raise ValueError(
            f"Invalid baseline_type '{baseline_type}'. Must be one of: mean, random_sample."
        )
    if baseline_group != "all" and obs_col is None:
        raise ValueError(
            "obs_col must be provided when using a group-specific baseline."
        )
    np.random.seed(seed_int)
    torch.manual_seed(seed_int)
    model = van.result.model

    X = input_adata.X
    inputs = torch.tensor(X.toarray() if scipy.sparse.issparse(X) else X)
    # torch.mean() is undefined for integer tensors (e.g. raw count matrices).
    if not inputs.is_floating_point():
        inputs = inputs.float()

    # Cells the baseline is derived from: every cell, or only the chosen group.
    if baseline_group == "all":
        pool = inputs
    else:
        group_mask = (input_adata.obs[obs_col] == baseline_group).to_numpy()
        pool = inputs[torch.as_tensor(group_mask, dtype=torch.bool)]
        if pool.shape[0] == 0:
            raise ValueError(
                f"No cells with {obs_col} == {baseline_group!r}; cannot build a baseline."
            )

    if baseline_type == "mean":
        baseline = pool.mean(dim=0)  # per-gene mean
    else:  # "random_sample"
        baseline = pool[torch.randint(0, pool.size(0), (1,)).item()]

    n_cells = inputs.shape[0]
    subset_idx = torch.as_tensor(
        np.random.choice(n_cells, size=min(n_subset, n_cells), replace=False)
    )
    subset_inputs = inputs[subset_idx].float()
    # Every explained cell is compared against the same reference row.
    subset_baselines = baseline.unsqueeze(0).repeat(subset_inputs.shape[0], 1).float()

    attribution_fn = (
        make_DeepLiftShap_Vanillix_dim
        if method == "DeepLiftShap"
        else make_IntegratedGradients_Vanillix_dim
    )
    n_latent = van.result.adata_latent.shape[1]
    all_attr = []
    for latent_dim in range(n_latent):
        avg_abs_attributions, _convergence = attribution_fn(
            model=model,
            inputs=subset_inputs,
            baselines=subset_baselines,
            latent_dimension=latent_dim,
        )
        all_attr.append(avg_abs_attributions.detach().cpu())

    # Stack to (latent_dims, genes) and transpose to genes x latent_dims.
    attr_matrix = torch.stack(all_attr).T.numpy()
    return pd.DataFrame(
        attr_matrix,
        index=list(input_adata.var_names),
        columns=[f"latent_dimension_{i}" for i in range(n_latent)],
    )

190 

191 

def get_top_kgenes_per_latent_dimension(df_attributions, latent_dim=0, topk=10):
    """
    Return the ``topk`` genes with the largest attribution in one latent dimension.

    Parameters
    ----------
    df_attributions : pandas.DataFrame
        Genes as rows; columns named ``latent_dimension_<i>``.

    latent_dim : int, default=0
        Index ``<i>`` of the latent dimension column to rank.

    topk : int, default=10
        Number of genes to return.

    Returns
    -------
    topk_genes : list of str
        Row labels ordered from highest to lowest attribution.
    """
    scores = df_attributions["latent_dimension_" + str(latent_dim)]
    best_scores = scores.nlargest(topk)
    return best_scores.index.tolist()

218 

219 

def plot_union_top_genes_heatmap(df, top_n=50, cmap="viridis", save=None):
    """
    Plot a heatmap over all latent dimensions, restricted to the union of the
    top ``top_n`` genes of each dimension.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature attributions; rows are genes/features, columns are latent dimensions.
    top_n : int, default=50
        Number of top genes selected per latent dimension.
    cmap : str, default='viridis'
        Colormap for the heatmap.
    save : str or None, default=None
        If given, the figure is also written to this path.

    Notes
    -----
    An empty ``df`` is reported and skipped instead of raising.
    """
    if df.empty:
        print("Skipping heatmap: attribution DataFrame is empty.")
        return

    # Union of every dimension's top genes, sorted for a stable row order.
    union_genes = sorted(
        set().union(*(df[col].nlargest(top_n).index for col in df.columns))
    )
    data = df.loc[union_genes]

    fig = plt.figure(
        figsize=(len(df.columns) * 0.1 + 3, max(6, 0.14 * len(union_genes)))
    )
    ax = sns.heatmap(
        data,
        cmap=cmap,
        annot=False,
        linewidths=0.5,
        linecolor="gray",
        cbar_kws={"label": "Attribution Score"},
    )

    # seaborn's default "auto" tick labelling may draw only a subset of ticks.
    # Place one tick at every cell centre so each label maps to its own
    # column/row before the labels are set.
    ax.set_xticks(np.arange(len(data.columns)) + 0.5)
    ax.set_xticklabels(data.columns, rotation=90, ha="center", fontsize=9)
    ax.set_yticks(np.arange(len(data.index)) + 0.5)
    ax.set_yticklabels(data.index, rotation=0, fontsize=8)

    plt.title(f"attribution of top {top_n} genes per latent dimension", fontsize=10)

    plt.tight_layout()
    if save is not None:
        plt.savefig(save, bbox_inches="tight")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

268 

269 

def get_top_genes_per_dimension(df, top_n=100):
    """
    Collect the ``top_n`` highest-scoring genes of every latent dimension.

    Parameters
    ----------
    df : pandas.DataFrame
        Genes as rows, latent dimensions as columns.
    top_n : int, default=100
        Number of genes kept per dimension.

    Returns
    -------
    dict
        ``{dimension_name: [gene, ...]}`` with genes ordered by decreasing score.
    """
    return {
        column: df[column].nlargest(top_n).index.tolist() for column in df.columns
    }

279 

280 

def run_go_enrichment(
    top_genes_dict,
    n_top_pathways=10,
    gene_set_library="GO_Biological_Process_2021",
    organism="Human",
):
    """
    Run Enrichr enrichment (through gseapy) for each latent dimension's gene list.

    Parameters
    ----------
    top_genes_dict : dict
        ``{dimension_name: [gene, ...]}``, e.g. from ``get_top_genes_per_dimension``.
    n_top_pathways : int, default=10
        Number of terms kept per dimension, lowest adjusted p-value first.
    gene_set_library : str, default='GO_Biological_Process_2021'
        Enrichr gene set library name.
    organism : str, default='Human'
        Organism passed to ``gseapy.enrichr``.

    Returns
    -------
    dict
        ``{dimension_name: pandas.DataFrame}`` of the top enriched terms.
    """

    def _top_terms(genes):
        # outdir=None: gseapy writes no files. Per gseapy's docs, ``cutoff``
        # only affects gseapy's own plots -- confirm for the installed version.
        enrichment = gp.enrichr(
            gene_list=genes,
            gene_sets=gene_set_library,
            organism=organism,
            outdir=None,
            cutoff=0.05,
        )
        ranked = enrichment.results.sort_values("Adjusted P-value")
        return ranked.head(n_top_pathways)

    return {dim: _top_terms(genes) for dim, genes in top_genes_dict.items()}

302 

303 

def plot_GO_log_odds_all(PE_dict, top_n=10, base_save_identifier=None):
    """
    Plots log odds ratio of top enriched GO terms for each latent dimension.

    Parameters
    ----------
    PE_dict : dict
        ``{dimension_name: pandas.DataFrame}``; each table needs ``Term`` and
        ``Odds Ratio`` columns (as returned by ``run_go_enrichment``).
    top_n : int, default=10
        Number of terms with the largest log odds ratio plotted per dimension.
    base_save_identifier : str or None, default=None
        If given, each figure is saved to ``<base_save_identifier>_<dim>.pdf``.

    Notes
    -----
    The input tables are not modified. Terms whose odds ratio has a
    non-finite logarithm (zero, negative or infinite odds ratio) are left out,
    because infinite bar widths cannot be drawn.
    """
    required_columns = {"Term", "Odds Ratio"}
    for dim, df in PE_dict.items():
        if df.empty or not required_columns.issubset(df.columns):
            print(f"Skipping {dim}: empty or missing required columns.")
            continue

        # Work on a copy so the caller's enrichment table gains no new column.
        scored = df.assign(log_odds_ratio=np.log(df["Odds Ratio"].astype(float)))
        scored = scored[np.isfinite(scored["log_odds_ratio"])]
        if scored.empty:
            print(f"Skipping {dim}: no finite odds ratios.")
            continue

        # Select top_n terms
        plot_df = scored.sort_values(by="log_odds_ratio", ascending=True).tail(top_n)

        fig = plt.figure(figsize=(10, 0.5 * top_n))
        bars = plt.barh(
            plot_df["Term"],
            plot_df["log_odds_ratio"],
            color="#D4AF3A",
            height=0.6,
        )

        # Value labels, drawn just inside the end of each bar.
        for bar in bars:
            plt.text(
                bar.get_width() - 1.05,
                bar.get_y() + bar.get_height() / 2,
                f"{bar.get_width():.2f}",
                va="center",
                ha="left",
                fontsize=8,
                color="#C40308",
            )

        plt.xlabel("log(odds Ratio)", fontsize=10)
        plt.title(f"{dim} — top {top_n} Enriched GO terms", fontsize=12)
        plt.xticks(fontsize=8)
        plt.yticks(fontsize=9)
        plt.tight_layout()
        if base_save_identifier is not None:
            plt.savefig(
                base_save_identifier + "_" + str(dim) + ".pdf", bbox_inches="tight"
            )
        plt.show()
        # Release the figure so one figure per dimension does not pile up.
        plt.close(fig)

354 

355 

def do_miraculix_visualization(gif_path="miraculix_zaubertrank.gif"):
    """
    Display a "Calculate feature importance" banner and the Miraculix GIF in IPython.

    Parameters
    ----------
    gif_path : str, default='miraculix_zaubertrank.gif'
        Path of the GIF to show; relative paths resolve against the current
        working directory. The GIF is purely decorative, so a missing file is
        skipped instead of raising (``IPython.display.Image`` reads the file
        on construction and would otherwise raise).
    """
    display(HTML('<p style="font-size:20px;">Calculate feature importance</p>'))
    if os.path.isfile(gif_path):
        display(Image(filename=gif_path))
    # wait=True defers the clear until new output arrives, so the banner
    # presumably stays visible until the caller prints/plots something.
    clear_output(wait=True)

360 

361 

def do_feature_importance_Vanillix(
    van,
    method="DeepLiftShap",
    baseline_type="mean",
    baseline_group="all",
    obs_col=None,
    n_subset=100,
    seed_int=12,
    do_visualizations=True,
    top_n_genes_heatmap=50,
    top_n_foreground_pathways=30,
    gene_set_library=None,
    organism="Human",
    n_top_pathways=10,
    save_out_path=None,
    do_miraculix_vis=False,
):
    """
    Computes feature attributions for a trained Vanillix model and visualizes them.

    Parameters
    ----------
    van : object
        Object holding the trained Vanillix model. Its datasets
        (``van.result.datasets``) are converted to AnnData and attributed.

    method : str, {'DeepLiftShap', 'IntegratedGradients'}, default='DeepLiftShap'
        post-hoc feature importance assessment method
        - 'DeepLiftShap': method from pytorch Captum library, approximates SHAP values using Deeplift
        - 'IntegratedGradients': method from pytorch Captum library, attribution via Integrated gradients

    n_subset : int, default=100
        Subset of randomly selected cells to compute attributions on.

    seed_int : int, default=12
        Seed for reproducible random sampling (NumPy & PyTorch).

    baseline_type : str, {'mean', 'random_sample'}, default='mean'
        How to generate the baseline:
        - 'mean': average expression of the group
        - 'random_sample': a randomly picked cell from the group

    baseline_group : str, default='all'
        Which group of cells to use as the baseline.
        Use 'all' or a specific group label from `adata.obs[obs_col]`.

    obs_col : str or None, default=None
        Column in `adata.obs` that defines groups. Required if `baseline_group` is not 'all'.

    do_visualizations : bool, default=True
        Plot the top-gene heatmap. If `gene_set_library` is also given, run the
        enrichment analysis and plot the enriched terms.

    top_n_genes_heatmap : int, default=50
        Number of top genes (with highest attribution scores) per latent dimension to visualize in the heatmap.

    top_n_foreground_pathways : int, default=30
        Number of top genes to include as "foreground" in enrichment analysis.

    gene_set_library : str or None, default=None
        The gene set library to use for enrichment analysis (from Enrichr),
        e.g. "GO_Biological_Process_2021". If None, enrichment is skipped.

    organism : str, default='Human'
        The organism relevant to the gene sets used in enrichment.

    n_top_pathways : int, default=10
        Number of top enriched pathways to visualize in the results per latent dimension.

    save_out_path : str or None, default=None
        Directory that receives ``df_attributions.csv``, ``top_attributions.pdf``
        and ``GO_pathways_<dim>.pdf``. If None, nothing is saved to disk.

    do_miraculix_vis : bool, default=False
        Show Miraculix-style GIF before computing.

    Returns
    -------
    df_attributions : pandas.DataFrame
        Gene-level attribution scores per latent dimension.

    Raises
    ------
    ValueError
        If `method` is unknown or no dataset is found in ``van.result.datasets``.
    """
    # Tuple (not set) keeps the error message order deterministic.
    feature_importance_methods = ("DeepLiftShap", "IntegratedGradients")
    if method not in feature_importance_methods:
        raise ValueError(
            f"Invalid method '{method}'. Must be one of: {', '.join(feature_importance_methods)}."
        )

    if do_miraculix_vis:
        do_miraculix_visualization()

    input_data: Dict[str, ad.AnnData] = AnnDataConverter.dataset_to_adata(
        datasetcontainer=van.result.datasets
    )
    if not input_data:
        raise ValueError(
            "No datasets found in van.result.datasets; cannot compute feature importance."
        )

    # NOTE(review): every dataset is attributed, but each iteration overwrites
    # df_attributions, so only the last one is saved, plotted and returned.
    # Confirm which dataset(s) are intended.
    df_attributions = None
    for adata in input_data.values():
        df_attributions = make_feature_importance_Vanillix(
            input_adata=adata,
            van=van,
            method=method,
            n_subset=n_subset,
            seed_int=seed_int,
            baseline_type=baseline_type,
            baseline_group=baseline_group,
            obs_col=obs_col,
        )

    if save_out_path is not None:
        df_attributions.to_csv(os.path.join(save_out_path, "df_attributions.csv"))

    if do_visualizations:
        heatmap_path = (
            None
            if save_out_path is None
            else os.path.join(save_out_path, "top_attributions.pdf")
        )
        plot_union_top_genes_heatmap(
            df=df_attributions,
            top_n=top_n_genes_heatmap,
            cmap="plasma",
            save=heatmap_path,
        )

        # Enrichment needs a gene set library; skip it when none is given
        # (same guard as do_feature_importance_Varix).
        if gene_set_library is not None:
            dict_top_genes = get_top_genes_per_dimension(
                df=df_attributions, top_n=top_n_foreground_pathways
            )
            PE_dict = run_go_enrichment(
                top_genes_dict=dict_top_genes,
                n_top_pathways=n_top_pathways,
                gene_set_library=gene_set_library,
                organism=organism,
            )
            go_base = (
                None
                if save_out_path is None
                else os.path.join(save_out_path, "GO_pathways")
            )
            plot_GO_log_odds_all(
                PE_dict=PE_dict,
                top_n=n_top_pathways,
                base_save_identifier=go_base,
            )
    return df_attributions

507 

508 

class Varix_EncoderSingleDim(nn.Module):
    """Expose one coordinate of a Varix (VAE) latent sample as the module output.

    Runs encoder -> ``_mu`` / ``_logvar`` heads -> ``reparameterize`` and keeps
    only column ``dim``, so attribution methods can explain one latent
    dimension at a time.

    Parameters
    ----------
    vae_model : nn.Module
        Model providing ``_encoder``, ``_mu``, ``_logvar``, ``reparameterize``
        and ``input_dim``.
    dim : int
        Index of the latent dimension to expose.
    """

    def __init__(self, vae_model, dim):
        super().__init__()
        # References (not copies) to the trained model's components.
        self.encoder = vae_model._encoder
        self.mu = vae_model._mu
        self.logvar = vae_model._logvar
        self.reparameterize = vae_model.reparameterize
        self.input_dim = vae_model.input_dim
        self.dim = dim

    def forward(self, x):
        """Return column ``self.dim`` of the reparameterized latent with an extra axis 1.

        Raises
        ------
        ValueError
            If ``x.shape[1]`` differs from ``self.input_dim``.
        """
        if x.shape[1] != self.input_dim:
            raise ValueError(
                f"Expected input with {self.input_dim} features, but got {x.shape[1]} features. "
                f"This may indicate missing data modalities or incorrect data preparation."
            )

        n_elements = x.numel()
        # Always true once x.shape[1] == input_dim; kept as a sanity check.
        assert (
            n_elements % self.input_dim == 0
        ), f"Total elements {n_elements} is not a multiple of input_dim {self.input_dim}"

        hidden = self.encoder(x.view(x.size(0), -1))
        # NOTE(review): if reparameterize() draws random noise (the usual VAE
        # trick), attributions become stochastic; attributing ``mu`` directly
        # may be what is intended -- confirm.
        z = self.reparameterize(self.mu(hidden), self.logvar(hidden))
        return z[:, self.dim].unsqueeze(1)

541 

542 

def make_DeepLiftShap_Varix_dim(model, inputs, baselines, latent_dimension):
    """
    DeepLiftShap attribution of a single Varix latent dimension.

    Parameters
    ----------
    model : nn.Module
        Trained Varix model; wrapped by ``Varix_EncoderSingleDim``.
    inputs : torch.Tensor
        Cells to explain (rows are cells).
    baselines : torch.Tensor
        Reference inputs handed to Captum's ``DeepLiftShap.attribute``.
    latent_dimension : int
        Latent coordinate to explain.

    Returns
    -------
    avg_abs_attributions : torch.Tensor
        Absolute attributions averaged over axis 0 (the cells).
    convergence : torch.Tensor
        Convergence delta reported by Captum.
    """
    explainer = DeepLiftShap(Varix_EncoderSingleDim(model, dim=latent_dimension))
    attributions, delta = explainer.attribute(
        inputs=inputs,
        baselines=baselines,
        return_convergence_delta=True,
    )
    return attributions.abs().mean(dim=0), delta

551 

552 

def make_feature_importance_Varix(
    van,
    method="DeepLiftShap",
    n_subset=100,
    seed_int=12,
    baseline_type="mean",
    baseline_group="all",
    obs_col=None,
    input_adata=None,
):
    """
    Computes feature attributions for a trained Varix model.

    For every latent dimension, the chosen Captum method attributes that single
    (reparameterized) latent output back to the input features. The mean
    absolute attribution over a random subset of cells is reported.

    Parameters
    ----------
    van : object
        Object holding the trained Varix model (``van.result.model``) and the
        latent AnnData (``van.result.adata_latent``), whose second axis gives
        the number of latent dimensions.

    method : str, {'DeepLiftShap', 'IntegratedGradients'}, default='DeepLiftShap'
        post-hoc feature importance assessment method
        - 'DeepLiftShap': method from pytorch Captum library, approximates SHAP values using Deeplift
        - 'IntegratedGradients': method from pytorch Captum library, attribution via Integrated gradients

    n_subset : int, default=100
        Number of randomly selected cells to compute attributions on. It is
        clipped to the number of available cells.

    seed_int : int, default=12
        Seed for reproducible random sampling (NumPy & PyTorch).

    baseline_type : str, {'mean', 'random_sample'}, default='mean'
        How to generate the baseline:
        - 'mean': average expression of the group
        - 'random_sample': a randomly picked cell from the group

    baseline_group : str, default='all'
        Which group of cells to use as the baseline.
        Use 'all' or a specific group label from `adata.obs[obs_col]`.

    obs_col : str or None, default=None
        Column in `adata.obs` that defines groups. Required if `baseline_group` is not 'all'.

    input_adata : anndata.AnnData or None, default=None
        Cells x features data to attribute. If None, the ``"user-data"``
        modality of ``van.raw_user_data["multi_sc"]["multi_sc"]`` is used.

    Returns
    -------
    df_attributions : pandas.DataFrame
        Gene-level attribution scores: rows are the input ``var_names``,
        columns are ``latent_dimension_<i>``.

    Raises
    ------
    ValueError
        If ``method`` or ``baseline_type`` is unknown, if ``obs_col`` is missing
        for a group-specific baseline, or if the baseline group has no cells.
    """
    if method not in ("DeepLiftShap", "IntegratedGradients"):
        raise ValueError(
            f"Invalid method '{method}'. Must be one of: DeepLiftShap, IntegratedGradients."
        )
    if baseline_type not in ("mean", "random_sample"):
        raise ValueError(
            f"Invalid baseline_type '{baseline_type}'. Must be one of: mean, random_sample."
        )
    if baseline_group != "all" and obs_col is None:
        raise ValueError(
            "obs_col must be provided when using a group-specific baseline."
        )
    np.random.seed(seed_int)
    torch.manual_seed(seed_int)
    model = van.result.model

    if input_adata is None:
        input_adata = van.raw_user_data["multi_sc"]["multi_sc"].mod["user-data"]

    X = input_adata.X
    inputs = torch.tensor(X.toarray() if scipy.sparse.issparse(X) else X)
    # torch.mean() is undefined for integer tensors (e.g. raw count matrices).
    if not inputs.is_floating_point():
        inputs = inputs.float()

    # Cells the baseline is derived from: every cell, or only the chosen group.
    if baseline_group == "all":
        pool = inputs
    else:
        group_mask = (input_adata.obs[obs_col] == baseline_group).to_numpy()
        pool = inputs[torch.as_tensor(group_mask, dtype=torch.bool)]
        if pool.shape[0] == 0:
            raise ValueError(
                f"No cells with {obs_col} == {baseline_group!r}; cannot build a baseline."
            )

    if baseline_type == "mean":
        baseline = pool.mean(dim=0)  # per-gene mean
    else:  # "random_sample"
        baseline = pool[torch.randint(0, pool.size(0), (1,)).item()]

    n_cells = inputs.shape[0]
    subset_idx = torch.as_tensor(
        np.random.choice(n_cells, size=min(n_subset, n_cells), replace=False)
    )
    subset_inputs = inputs[subset_idx].float()
    # Every explained cell is compared against the same reference row.
    subset_baselines = baseline.unsqueeze(0).repeat(subset_inputs.shape[0], 1).float()

    attribution_fn = (
        make_DeepLiftShap_Varix_dim
        if method == "DeepLiftShap"
        else make_IntegratedGradients_Varix_dim
    )
    n_latent = van.result.adata_latent.shape[1]
    all_attr = []
    for latent_dim in range(n_latent):
        avg_abs_attributions, _convergence = attribution_fn(
            model=model,
            inputs=subset_inputs,
            baselines=subset_baselines,
            latent_dimension=latent_dim,
        )
        all_attr.append(avg_abs_attributions.detach().cpu())

    # Stack to (latent_dims, genes) and transpose to genes x latent_dims.
    attr_matrix = torch.stack(all_attr).T.numpy()
    return pd.DataFrame(
        attr_matrix,
        index=list(input_adata.var_names),
        columns=[f"latent_dimension_{i}" for i in range(n_latent)],
    )

667 

668 

def do_feature_importance_Varix(
    van,
    method="DeepLiftShap",
    baseline_type="mean",
    baseline_group="all",
    obs_col=None,
    n_subset=100,
    seed_int=12,
    do_visualizations=True,
    top_n_genes_heatmap=50,
    top_n_foreground_pathways=30,
    gene_set_library=None,
    organism="Human",
    n_top_pathways=10,
    save_out_path=None,
    do_miraculix_vis=False,
):
    """
    Computes feature attributions for a trained Varix model and visualizes them.

    Parameters
    ----------
    van : object
        Object containing the trained Varix model and its input data; passed
        on to ``make_feature_importance_Varix``.

    method : str, {'DeepLiftShap', 'IntegratedGradients'}, default='DeepLiftShap'
        post-hoc feature importance assessment method
        - 'DeepLiftShap': method from pytorch Captum library, approximates SHAP values using Deeplift
        - 'IntegratedGradients': method from pytorch Captum library, attribution via Integrated gradients

    n_subset : int, default=100
        Subset of randomly selected cells to compute attributions on.

    seed_int : int, default=12
        Seed for reproducible random sampling (NumPy & PyTorch).

    baseline_type : str, {'mean', 'random_sample'}, default='mean'
        How to generate the baseline:
        - 'mean': average expression of the group
        - 'random_sample': a randomly picked cell from the group

    baseline_group : str, default='all'
        Which group of cells to use as the baseline.
        Use 'all' or a specific group label from `adata.obs[obs_col]`.

    obs_col : str or None, default=None
        Column in `adata.obs` that defines groups. Required if `baseline_group` is not 'all'.

    do_visualizations : bool, default=True
        Plot the top-gene heatmap. If `gene_set_library` is also given, run the
        enrichment analysis and plot the enriched terms.

    top_n_genes_heatmap : int, default=50
        Number of top genes (with highest attribution scores) per latent dimension to visualize in the heatmap.

    top_n_foreground_pathways : int, default=30
        Number of top genes to include as "foreground" in enrichment analysis.

    gene_set_library : str or None, default=None
        The gene set library to use for enrichment analysis (from Enrichr),
        e.g. "GO_Biological_Process_2021". If None, enrichment is skipped.

    organism : str, default='Human'
        The organism relevant to the gene sets used in enrichment.

    n_top_pathways : int, default=10
        Number of top enriched pathways to visualize in the results per latent dimension.

    save_out_path : str or None, default=None
        Directory that receives ``df_attributions.csv``, ``top_attributions.pdf``
        and ``GO_pathways_<dim>.pdf``. If None, nothing is saved to disk.

    do_miraculix_vis : bool, default=False
        Show Miraculix-style GIF before computing.

    Returns
    -------
    df_attributions : pandas.DataFrame
        Gene-level attribution scores per latent dimension.

    Raises
    ------
    ValueError
        If `method` is unknown.
    """
    # Tuple (not set) keeps the error message order deterministic.
    feature_importance_methods = ("DeepLiftShap", "IntegratedGradients")
    if method not in feature_importance_methods:
        raise ValueError(
            f"Invalid method '{method}'. Must be one of: {', '.join(feature_importance_methods)}."
        )

    if do_miraculix_vis:
        do_miraculix_visualization()

    df_attributions = make_feature_importance_Varix(
        van,
        method=method,
        n_subset=n_subset,
        seed_int=seed_int,
        baseline_type=baseline_type,
        baseline_group=baseline_group,
        obs_col=obs_col,
    )
    if save_out_path is not None:
        df_attributions.to_csv(os.path.join(save_out_path, "df_attributions.csv"))

    if do_visualizations:
        heatmap_path = (
            None
            if save_out_path is None
            else os.path.join(save_out_path, "top_attributions.pdf")
        )
        plot_union_top_genes_heatmap(
            df=df_attributions,
            top_n=top_n_genes_heatmap,
            cmap="plasma",
            save=heatmap_path,
        )

        # Enrichment needs a gene set library; skip it when none is given.
        if gene_set_library is not None:
            dict_top_genes = get_top_genes_per_dimension(
                df=df_attributions, top_n=top_n_foreground_pathways
            )
            PE_dict = run_go_enrichment(
                top_genes_dict=dict_top_genes,
                n_top_pathways=n_top_pathways,
                gene_set_library=gene_set_library,
                organism=organism,
            )
            go_base = (
                None
                if save_out_path is None
                else os.path.join(save_out_path, "GO_pathways")
            )
            plot_GO_log_odds_all(
                PE_dict=PE_dict,
                top_n=n_top_pathways,
                base_save_identifier=go_base,
            )
    return df_attributions

810 

811 

def make_IntegratedGradients_Varix_dim(model, inputs, baselines, latent_dimension):
    """
    Integrated Gradients attribution of a single Varix latent dimension.

    Parameters
    ----------
    model : nn.Module
        Trained Varix model; wrapped by ``Varix_EncoderSingleDim``.
    inputs : torch.Tensor
        Cells to explain (rows are cells).
    baselines : torch.Tensor
        Reference inputs handed to Captum's ``IntegratedGradients.attribute``.
    latent_dimension : int
        Latent coordinate to explain.

    Returns
    -------
    avg_abs_attributions : torch.Tensor
        Absolute attributions averaged over axis 0 (the cells).
    convergence : torch.Tensor
        Convergence delta reported by Captum.
    """
    explainer = IntegratedGradients(
        Varix_EncoderSingleDim(model, dim=latent_dimension)
    )
    attributions, delta = explainer.attribute(
        inputs=inputs,
        baselines=baselines,
        return_convergence_delta=True,
    )
    return attributions.abs().mean(dim=0), delta

820 

821 

def make_IntegratedGradients_Vanillix_dim(model, inputs, baselines, latent_dimension):
    """
    Integrated Gradients attribution of a single Vanillix latent dimension.

    Parameters
    ----------
    model : nn.Module
        Trained Vanillix model; wrapped by ``Vanillix_EncoderSingleDim``.
    inputs : torch.Tensor
        Cells to explain (rows are cells).
    baselines : torch.Tensor
        Reference inputs handed to Captum's ``IntegratedGradients.attribute``.
    latent_dimension : int
        Latent coordinate to explain.

    Returns
    -------
    avg_abs_attributions : torch.Tensor
        Absolute attributions averaged over axis 0 (the cells).
    convergence : torch.Tensor
        Convergence delta reported by Captum.
    """
    explainer = IntegratedGradients(
        Vanillix_EncoderSingleDim(model, dim=latent_dimension)
    )
    attributions, delta = explainer.attribute(
        inputs=inputs,
        baselines=baselines,
        return_convergence_delta=True,
    )
    return attributions.abs().mean(dim=0), delta