Coverage for src / autoencodix / utils / _explainer.py: 15%
96 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import numpy as np
3import scipy
4import pandas as pd
5from autoencodix.modeling._captum_forward import CaptumForward
6from typing import Optional, Union
7from captum.attr import (
8 DeepLiftShap,
9 IntegratedGradients,
10)
11import warnings
# Targeted filter: drop the UserWarning about hooks on non-linear modules
# (presumably emitted by Captum's DeepLift-based explainers).
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    message="Setting forward, backward hooks and attributes on non-linear",
)
# NOTE(review): filterwarnings() inserts at the FRONT of the filter list, so
# this blanket "ignore" is matched first and silences *every* warning in the
# process as a side effect of importing this module (it also makes the
# targeted filter above redundant). Kept to preserve behaviour; consider
# scoping suppression with warnings.catch_warnings() around attribution calls.
warnings.filterwarnings(action="ignore")
class FeatureImportanceExplainer:
    """Explain latent dimensions of a trained model in terms of input features.

    For each selected latent dimension, a Captum attribution method is run on a
    (sub)sample of the observations in ``adata`` against baselines built by
    ``return_inputs_baseline``. Attributions are reduced to their mean absolute
    value over the explained samples, giving one score per input feature and
    latent dimension.

    Supports:
        - method selection (``"DeepLiftShap"`` or ``"IntegratedGradients"``)
        - baseline construction (``"random"`` / ``"mean"`` / ``"grouped"``)
        - input selection (``"random"`` / ``"grouped"``)
        - subset sampling for computational efficiency
        - seeded randomness: NumPy's and torch's *global* RNGs are seeded with
          ``seed_int`` on construction and again at the start of every
          ``explain()`` call, so repeated calls give the same result.
    """

    _SUPPORTED_METHODS = frozenset({"DeepLiftShap", "IntegratedGradients"})

    def __init__(
        self,
        adata,
        model,
        method: str = "DeepLiftShap",
        sel_latent_dim: Union[list, int, None] = None,
        n_subset: Optional[int] = 100,
        seed_int: int = 12,
        input_type: str = "random",
        input_group: Optional[str] = None,
        baseline_type: str = "random",
        baseline_group: Optional[str] = None,
        anno_col: Optional[str] = None,
    ):
        """Store the explanation settings and seed the global RNGs.

        Args:
            adata: Annotated data object; ``.X``, ``.obs`` and ``.var_names``
                are read (presumably an ``AnnData`` -- confirm with callers).
            model: Trained model; ``model.config.latent_dim`` gives the number
                of latent dimensions, and an optional ``model.ontologies``
                supplies column names for them.
            method: ``"DeepLiftShap"`` or ``"IntegratedGradients"``.
            sel_latent_dim: Latent dimension(s) to explain: an int, a list of
                ints, or ``None`` to explain all of them.
            n_subset: Maximum number of input and baseline rows randomly
                sampled for the attribution; ``None`` uses all rows.
            seed_int: Seed for NumPy's and torch's global RNGs.
            input_type: ``"random"`` (all observations) or ``"grouped"``
                (only observations with ``.obs[anno_col] == input_group``).
            input_group: Group value selecting the inputs; required when
                ``input_type`` is ``"grouped"``.
            baseline_type: ``"random"``, ``"mean"`` or ``"grouped"``.
            baseline_group: Group value for the baseline; required for
                ``"grouped"``, optional for ``"mean"``, ignored for ``"random"``.
            anno_col: Column in ``.obs`` holding the group labels.

        Raises:
            ValueError: if ``method`` is not supported.
        """
        super().__init__()
        self.adata_ACX = adata
        self.model = model
        self.latent_dim = model.config.latent_dim

        self.method = method
        if self.method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Invalid method {method}.")
        self.n_subset = n_subset
        self.seed_int = seed_int
        self.baseline_type = baseline_type
        self.baseline_group = baseline_group
        self.anno_col = anno_col
        self.input_type = input_type
        self.input_group = input_group
        self.sel_latent_dim = sel_latent_dim

        self._seed_rngs()

    def _seed_rngs(self) -> None:
        """Seed torch's and NumPy's global RNGs with ``self.seed_int``."""
        torch.manual_seed(self.seed_int)
        np.random.seed(self.seed_int)

    def _resolve_latent_dims(self) -> list:
        """Turn ``sel_latent_dim`` into a non-empty list of dimension indices.

        Raises:
            ValueError: if ``sel_latent_dim`` has an unsupported type or
                selects no dimension at all.
        """
        sel = self.sel_latent_dim
        if isinstance(sel, (list, tuple)):
            latent_dims = list(sel)
        elif isinstance(sel, (int, np.integer)):
            latent_dims = [sel]
        elif sel is None:
            latent_dims = list(range(self.latent_dim))
        else:
            raise ValueError(
                f"Invalid sel_latent_dim {self.sel_latent_dim}. Must be int, list of ints, or None."
            )
        if not latent_dims:
            # Fail early instead of with an opaque torch.stack error later.
            raise ValueError("sel_latent_dim must select at least one latent dimension.")
        return latent_dims

    def _subset_indices(self, n_inputs: int, n_baselines: int):
        """Choose which input rows and baseline rows to attribute.

        With ``n_subset`` set, ``min(n_inputs, n_baselines, n_subset)`` rows
        are drawn without replacement from each (inputs first, then baselines,
        using NumPy's global RNG). With ``n_subset=None`` all input rows are
        kept. For ``IntegratedGradients``, which needs exactly one baseline row
        per input row, baseline rows are cycled to match the input count.
        ``DeepLiftShap`` accepts a baseline distribution of any size, so all
        baseline rows are kept.

        Returns:
            Tuple ``(input_indices, baseline_indices)`` of integer arrays.
        """
        if self.n_subset is not None:
            n_samples = min(n_inputs, n_baselines, self.n_subset)
            input_idx = np.random.choice(n_inputs, size=n_samples, replace=False)
            baseline_idx = np.random.choice(n_baselines, size=n_samples, replace=False)
            return input_idx, baseline_idx

        input_idx = np.arange(n_inputs)
        if self.method == "IntegratedGradients" and n_baselines != n_inputs:
            baseline_idx = np.resize(np.arange(n_baselines), n_inputs)
        else:
            baseline_idx = np.arange(n_baselines)
        return input_idx, baseline_idx

    def _column_names(self, latent_dims) -> list:
        """Name the output columns after ontology keys when available."""
        ontologies = getattr(self.model, "ontologies", None)
        if ontologies is not None:
            ontology_keys = list(ontologies[0].keys())
            return [ontology_keys[i] for i in latent_dims]
        return [f"Latent_Dim_{i}" for i in latent_dims]

    def explain(self) -> pd.DataFrame:
        """Compute per-feature attributions for the selected latent dimensions.

        Returns:
            DataFrame indexed by ``adata.var_names`` with one column per
            explained latent dimension (named by ``model.ontologies[0]`` keys
            when the model has ontologies, else ``Latent_Dim_<i>``). Each value
            is the mean absolute attribution over the sampled inputs.

        Raises:
            ValueError: for an invalid ``sel_latent_dim`` or invalid
                input/baseline settings.
        """
        # Re-seed so every call samples the same inputs and baselines, not
        # only the first call after construction.
        self._seed_rngs()
        latent_dims = self._resolve_latent_dims()

        adata_ACX = self.adata_ACX
        gene_names = adata_ACX.var_names
        inputs, baselines = return_inputs_baseline(
            adata_ACX,
            self.baseline_group,
            self.baseline_type,
            self.anno_col,
            self.input_type,
            self.input_group,
        )
        indices_keep, indices_keep_baseline = self._subset_indices(
            inputs.shape[0], baselines.shape[0]
        )

        explainer_cls = (
            DeepLiftShap if self.method == "DeepLiftShap" else IntegratedGradients
        )
        all_attr = []
        for latent_dim in latent_dims:
            print("Calculating attributions for latent dimension:", latent_dim)
            # NOTE(review): CaptumForward presumably exposes latent dimension
            # `latent_dim` as the output to attribute -- confirm.
            cp_forward_dim = CaptumForward(model=self.model, dim=latent_dim)
            cp_explainer = explainer_cls(cp_forward_dim)
            attributions = cp_explainer.attribute(
                inputs=inputs[indices_keep].float(),
                baselines=baselines[indices_keep_baseline].float(),
                return_convergence_delta=False,
            )
            # Collapse the sample axis: one importance score per feature.
            all_attr.append(attributions.abs().mean(dim=0).detach().cpu())

        # Stacked as (n_dims, n_features); transpose to features x dims.
        attr_matrix = torch.stack(all_attr).T.numpy()
        return pd.DataFrame(
            attr_matrix,
            index=list(gene_names),
            columns=self._column_names(latent_dims),
        )
def _to_dense_tensor(matrix) -> torch.Tensor:
    """Copy a dense or scipy-sparse matrix into a new torch tensor."""
    if scipy.sparse.issparse(matrix):
        matrix = matrix.toarray()
    return torch.tensor(matrix)


def _select_group(adata, anno_col, group, role: str):
    """Return the observations of ``adata`` with ``.obs[anno_col] == group``.

    Raises:
        ValueError: if no observation belongs to ``group``; otherwise later
            steps would fail with an opaque error or produce NaN baselines.
    """
    subset = adata[adata.obs[anno_col] == group]
    if subset.shape[0] == 0:
        raise ValueError(
            f"No observations with .obs[{anno_col!r}] == {group!r}; cannot build {role}."
        )
    return subset


def _tile_row(row: torch.Tensor, n_rows: int) -> torch.Tensor:
    """Stack the 1-D ``row`` ``n_rows`` times into a new 2-D tensor."""
    return row.unsqueeze(0).repeat(n_rows, 1)


def return_inputs_baseline(
    adata, baseline_group, baseline_type, anno_col, input_type, input_group
):
    """Build the input and baseline tensors used for feature attribution.

    Args:
        adata: Annotated data object; ``.X`` (dense or scipy-sparse),
            ``.obs`` and boolean-mask row indexing are used.
        baseline_group: Group value in ``.obs[anno_col]`` for the baseline.
            Required for ``"grouped"``, optional for ``"mean"``, ignored for
            ``"random"``.
        baseline_type: How the baseline row is chosen; it is then repeated to
            form a 2-D tensor:
            - ``"random"``: one observation drawn from all of ``adata`` (torch
              RNG), repeated once per observation in ``adata``.
            - ``"grouped"``: one observation drawn from ``baseline_group``,
              repeated once per observation of that group.
            - ``"mean"``: per-feature mean of all observations (or of
              ``baseline_group`` when given), repeated once per observation in
              ``adata``. Integer data is averaged in float32.
        anno_col: Column in ``.obs`` holding the group labels.
        input_type: ``"random"`` returns all observations (random subsampling
            happens later, in the explainer); ``"grouped"`` returns only the
            observations of ``input_group``.
        input_group: Group value selecting the inputs for ``"grouped"``.

    Returns:
        Tuple ``(inputs, baselines)`` of 2-D torch tensors.

    Raises:
        ValueError: for unknown ``input_type``/``baseline_type``, missing
            grouping arguments, an empty ``adata``, or an empty group.
    """
    all_X = _to_dense_tensor(adata.X)
    if all_X.shape[0] == 0:
        raise ValueError("adata contains no observations.")

    # Inputs
    if input_type == "grouped":
        if anno_col is None or input_group is None:
            raise ValueError(
                "If input_type is 'grouped', anno_col and input_group must be provided to specify the grouping column in .obs."
            )
        inputs = _to_dense_tensor(_select_group(adata, anno_col, input_group, "inputs").X)
    elif input_type == "random":
        # all_X is a fresh copy of adata.X and is never mutated, so it can be
        # returned directly.
        inputs = all_X
    else:
        raise ValueError(f"Invalid input_type {input_type}.")

    # Baselines
    if baseline_type == "random":
        row = all_X[torch.randint(0, all_X.size(0), (1,)).item()]
        baselines = _tile_row(row, all_X.shape[0])
    elif baseline_type == "grouped":
        if anno_col is None or baseline_group is None:
            raise ValueError(
                "If baseline_type is 'grouped', anno_col and baseline_group must be provided to specify the grouping column in .obs."
            )
        group_X = _to_dense_tensor(
            _select_group(adata, anno_col, baseline_group, "baselines").X
        )
        row = group_X[torch.randint(0, group_X.size(0), (1,)).item()]
        baselines = _tile_row(row, group_X.shape[0])
    elif baseline_type == "mean":
        if baseline_group is None:
            source = all_X
        else:
            if anno_col is None:
                raise ValueError(
                    "If baseline_group is provided for baseline_type 'mean', anno_col must be provided to specify the grouping column in .obs."
                )
            source = _to_dense_tensor(
                _select_group(adata, anno_col, baseline_group, "baselines").X
            )
        if not source.is_floating_point():
            # torch.mean rejects integer tensors (e.g. raw counts).
            source = source.float()
        baselines = _tile_row(source.mean(dim=0), all_X.shape[0])
    else:
        raise ValueError(f"Invalid baseline_type {baseline_type}.")

    return inputs, baselines
246 return inputs, baselines