Coverage for src / autoencodix / utils / _explainer.py: 15%

96 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2import numpy as np 

3import scipy 

4import pandas as pd 

5from autoencodix.modeling._captum_forward import CaptumForward 

6from typing import Optional, Union 

7from captum.attr import ( 

8 DeepLiftShap, 

9 IntegratedGradients, 

10) 

11import warnings 

12 

# Targeted filter: ignore UserWarnings whose message begins with the text below
# (presumably raised by the Captum explainers imported above -- TODO confirm).
warnings.filterwarnings(
    action="ignore",
    message="Setting forward, backward hooks and attributes on non-linear",
    category=UserWarning,
)

# NOTE(review): this blanket filter ignores every warning category for the whole
# process once this module is imported, which also makes the targeted filter
# above redundant. Kept because removing it would change import-time behavior.
warnings.filterwarnings(action="ignore")

20 

21 

class FeatureImportanceExplainer:
    """Attribute latent dimensions of a model to its input features.

    For every selected latent dimension the model is wrapped in a
    ``CaptumForward`` and explained with the chosen Captum method. The
    absolute attributions are averaged over the sampled inputs, giving one
    score per input feature and latent dimension.

    Args:
        adata: Annotated data object. Its ``var_names`` label the rows of the
            result; the object itself is handed to ``return_inputs_baseline``
            to build the input and baseline tensors.
        model: Model to explain. Must expose ``config.latent_dim``. If it has
            a non-``None`` ``ontologies`` attribute, the keys of
            ``ontologies[0]`` are used as column labels.
        method: ``"DeepLiftShap"`` or ``"IntegratedGradients"``.
        sel_latent_dim: Latent dimension(s) to explain: an int, a list of
            ints, or ``None`` to explain all ``config.latent_dim`` dimensions.
        n_subset: Upper bound on the number of randomly drawn input and
            baseline samples. ``None`` uses all samples.
        seed_int: Seed applied to the global torch and NumPy generators at
            construction time.
        input_type: ``"random"`` or ``"grouped"``; how inputs are selected.
        input_group: Value of ``.obs[anno_col]`` selecting the inputs; needed
            when ``input_type`` is ``"grouped"``.
        baseline_type: ``"random"``, ``"mean"`` or ``"grouped"``.
        baseline_group: Value of ``.obs[anno_col]`` selecting the baseline
            samples; needed when ``baseline_type`` is ``"grouped"``, optional
            for ``"mean"``, ignored for ``"random"``.
        anno_col: Column in ``.obs`` used for grouping.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """

    _SUPPORTED_METHODS = ("DeepLiftShap", "IntegratedGradients")

    def __init__(
        self,
        adata,
        model,
        method: str = "DeepLiftShap",
        sel_latent_dim: Union[list, int, None] = None,
        n_subset: int = 100,
        seed_int: int = 12,
        input_type: str = "random",
        input_group: Optional[str] = None,
        baseline_type: str = "random",
        baseline_group: Optional[str] = None,
        anno_col: Optional[str] = None,
    ):
        super().__init__()
        self.adata_ACX = adata
        self.model = model
        self.latent_dim = model.config.latent_dim

        self.method = method
        if self.method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Invalid method {method}.")
        self.n_subset = n_subset
        self.seed_int = seed_int
        self.baseline_type = baseline_type
        self.baseline_group = baseline_group
        self.anno_col = anno_col
        self.input_type = input_type
        self.input_group = input_group
        self.sel_latent_dim = sel_latent_dim

        # Seeds the *global* generators; sampling in explain() and in
        # return_inputs_baseline draws from them.
        torch.manual_seed(seed_int)
        np.random.seed(seed_int)

    def _resolve_latent_dims(self) -> list:
        """Return the latent dimensions to explain as a validated list."""
        selection = self.sel_latent_dim
        if selection is None:
            dims = list(range(self.latent_dim))
        elif isinstance(selection, list):
            dims = selection
        elif isinstance(selection, int):
            dims = [selection]
        else:
            raise ValueError(
                f"Invalid sel_latent_dim {self.sel_latent_dim}. Must be int, list of ints, or None."
            )

        if not dims:
            raise ValueError("sel_latent_dim must select at least one latent dimension.")
        # Negative values stay allowed as Python-style indices from the end.
        out_of_range = [
            d for d in dims if not -self.latent_dim <= d < self.latent_dim
        ]
        if out_of_range:
            raise ValueError(
                f"sel_latent_dim contains {out_of_range}, outside the model's "
                f"{self.latent_dim} latent dimensions."
            )
        return dims

    def _select_indices(self, n_inputs: int, n_baselines: int):
        """Pick the row indices of inputs and baselines passed to Captum.

        With ``n_subset`` set, both pools are subsampled without replacement to
        ``min(n_inputs, n_baselines, n_subset)`` rows. Without it all rows are
        kept, except for IntegratedGradients with pools of different size:
        that method pairs inputs and baselines row by row, so both are cut to
        the smaller pool size.
        """
        n_common = min(n_inputs, n_baselines)
        if n_common == 0:
            raise ValueError("No input or baseline samples available to explain.")

        if self.n_subset is not None:
            if self.n_subset < 1:
                raise ValueError(
                    f"n_subset must be a positive integer or None, got {self.n_subset}."
                )
            n_common = min(n_common, self.n_subset)
            subsample = True
        else:
            subsample = (
                self.method == "IntegratedGradients" and n_inputs != n_baselines
            )

        if subsample:
            # Inputs are drawn before baselines to keep the RNG stream stable.
            input_idx = np.random.choice(n_inputs, size=n_common, replace=False)
            baseline_idx = np.random.choice(n_baselines, size=n_common, replace=False)
        else:
            input_idx = np.arange(n_inputs)
            baseline_idx = np.arange(n_baselines)
        return input_idx, baseline_idx

    def explain(self):
        """Compute mean absolute attributions per feature and latent dimension.

        Returns:
            pandas.DataFrame indexed by ``adata.var_names`` with one column per
            explained latent dimension (ontology keys when the model provides
            ``ontologies``, otherwise ``Latent_Dim_<i>``).

        Raises:
            ValueError: If ``sel_latent_dim``, ``n_subset`` or ``method`` is
                invalid, or if no samples are available.
        """
        adata = self.adata_ACX
        feature_names = list(adata.var_names)
        inputs, baselines = return_inputs_baseline(
            adata,
            self.baseline_group,
            self.baseline_type,
            self.anno_col,
            self.input_type,
            self.input_group,
        )

        input_idx, baseline_idx = self._select_indices(
            inputs.shape[0], baselines.shape[0]
        )
        latent_dims = self._resolve_latent_dims()

        explainer_classes = {
            "DeepLiftShap": DeepLiftShap,
            "IntegratedGradients": IntegratedGradients,
        }
        explainer_cls = explainer_classes.get(self.method)
        if explainer_cls is None:
            raise ValueError(f"Invalid method {self.method}.")

        # The sampled tensors do not depend on the latent dimension.
        sampled_inputs = inputs[input_idx].float()
        sampled_baselines = baselines[baseline_idx].float()

        per_dim_scores = []
        for latent_dim in latent_dims:
            print("Calculating attributions for latent dimension:", latent_dim)
            forward = CaptumForward(model=self.model, dim=latent_dim)
            attributions = explainer_cls(forward).attribute(
                inputs=sampled_inputs,
                baselines=sampled_baselines,
                return_convergence_delta=False,
            )
            per_dim_scores.append(attributions.abs().mean(dim=0).detach().cpu())

        # Stacking gives (latent dims, features); transpose to features x dims.
        attr_matrix = torch.stack(per_dim_scores).T.numpy()

        ontologies = getattr(self.model, "ontologies", None)
        if ontologies is not None:
            ontology_names = list(ontologies[0].keys())
            cols = [ontology_names[i] for i in latent_dims]
        else:
            cols = [f"Latent_Dim_{i}" for i in latent_dims]

        return pd.DataFrame(attr_matrix, index=feature_names, columns=cols)

142 

143 

def return_inputs_baseline(
    adata, baseline_group, baseline_type, anno_col, input_type, input_group
):
    """Build the input and baseline tensors used for attribution.

    Args:
        adata: Object exposing ``.X`` (dense array or SciPy sparse matrix) and,
            for grouped selections, ``.obs`` plus boolean-mask indexing.
        baseline_group: Value of ``.obs[anno_col]`` selecting baseline samples.
            Required for ``baseline_type="grouped"``, optional for ``"mean"``.
        baseline_type: ``"random"`` (one randomly drawn sample), ``"grouped"``
            (one randomly drawn sample of the group) or ``"mean"`` (feature
            means over all samples, or over the group if one is given).
        anno_col: Column of ``.obs`` holding the group labels.
        input_type: ``"random"`` (all samples) or ``"grouped"`` (samples of
            ``input_group`` only).
        input_group: Value of ``.obs[anno_col]`` selecting the inputs.

    Returns:
        Tuple ``(inputs, baselines)``. ``baselines`` repeats a single baseline
        row: once per group member for ``"grouped"``, otherwise once per
        sample in ``adata``.

    Raises:
        ValueError: For an unknown ``input_type`` or ``baseline_type``, missing
            grouping arguments, or a selection that contains no samples.
    """
    all_samples = _to_dense_tensor(adata.X)
    if all_samples.shape[0] == 0:
        raise ValueError("adata contains no samples.")

    # Inputs
    if input_type == "grouped":
        if anno_col is None or input_group is None:
            raise ValueError(
                "If input_type is 'grouped', anno_col and input_group must be provided to specify the grouping column in .obs."
            )
        inputs = _group_tensor(adata, anno_col, input_group)
    elif input_type == "random":
        # Every baseline below is a fresh tensor, so sharing this one is safe.
        inputs = all_samples
    else:
        raise ValueError(f"Invalid input_type {input_type}.")

    # Baselines: pick one row, then repeat it n_rows times.
    if baseline_type == "random":
        row = all_samples[torch.randint(0, all_samples.size(0), (1,)).item()]
        n_rows = all_samples.shape[0]
    elif baseline_type == "grouped":
        if anno_col is None or baseline_group is None:
            raise ValueError(
                "If baseline_type is 'grouped', anno_col and baseline_group must be provided to specify the grouping column in .obs."
            )
        group_samples = _group_tensor(adata, anno_col, baseline_group)
        row = group_samples[torch.randint(0, group_samples.size(0), (1,)).item()]
        n_rows = group_samples.shape[0]
    elif baseline_type == "mean":
        if baseline_group is None:
            source = all_samples
        else:
            if anno_col is None:
                raise ValueError(
                    "If baseline_group is provided for baseline_type 'mean', anno_col must be provided to specify the grouping column in .obs."
                )
            source = _group_tensor(adata, anno_col, baseline_group)
        # torch cannot average integer tensors; float inputs keep their dtype.
        if not source.is_floating_point():
            source = source.float()
        row = source.mean(dim=0)
        n_rows = all_samples.shape[0]
    else:
        raise ValueError(f"Invalid baseline_type {baseline_type}.")

    baselines = row.repeat(n_rows, 1)
    return inputs, baselines


def _to_dense_tensor(matrix):
    """Return ``matrix`` as a torch tensor, densifying SciPy sparse input."""
    # Explicit submodule import: a bare ``import scipy`` does not guarantee
    # that ``scipy.sparse`` is loaded on every SciPy version.
    import scipy.sparse

    if scipy.sparse.issparse(matrix):
        matrix = matrix.toarray()
    return torch.tensor(matrix)


def _group_tensor(adata, anno_col, group):
    """Return the samples whose ``.obs[anno_col]`` equals ``group`` as a tensor."""
    subset = adata[adata.obs[anno_col] == group]
    values = _to_dense_tensor(subset.X)
    if values.shape[0] == 0:
        # An empty group would otherwise give NaN means or a randint error.
        raise ValueError(f"No samples found with .obs['{anno_col}'] == {group!r}.")
    return values