Coverage for src/driada/intense/stats.py: 100.00%

105 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import numpy as np 

2import scipy 

3from scipy.stats import * 

4from ..utils.data import populate_nested_dict, add_names_to_nested_dict 

5from ..experiment.exp_base import DEFAULT_STATS 

6 

7 

def chebyshev_ineq(data, val):
    """
    Calculate an upper bound on a tail probability using Chebyshev's inequality.

    The mean and standard deviation are estimated from ``data``. With
    ``z = (val - mean) / std``, Chebyshev's inequality gives
    ``P(|X - mean| >= |z| * std) <= 1 / z**2``. This also bounds the
    one-sided tail ``P(X >= val)`` when ``val`` lies above the mean.

    Parameters
    ----------
    data : array-like
        Sample data to estimate mean and std from.
    val : float
        Value to compute tail probability for.

    Returns
    -------
    p_bound : float
        Upper bound on P(X >= val), never larger than 1.

        - 1.0 when ``val`` is not above the sample mean, where the
          inequality carries no information about this tail.
        - 0.0 when the sample has zero spread and ``val`` exceeds its
          constant value.
    """
    mean = np.mean(data)
    std = np.std(data)
    if val <= mean:
        # At or below the mean the raw 1/z**2 would be a spuriously small
        # "bound" on an event that may be almost certain.
        return 1.0
    if std == 0:
        # Degenerate sample: every observation equals mean < val.
        return 0.0
    z = (val - mean) / std
    # For 0 < z < 1 the raw bound exceeds 1, which is not a probability.
    return min(1.0, 1.0 / z**2)

26 

27 

def get_lognormal_p(data, val):
    """
    Return the upper-tail p-value of ``val`` under a log-normal fitted to ``data``.

    The log-normal location is pinned at zero (``floc=0``), so only the
    shape ``s`` and the ``scale`` are estimated from ``data``.

    Parameters
    ----------
    data : array-like
        Sample used to fit the log-normal distribution.
    val : float
        Observed value whose tail probability is requested.

    Returns
    -------
    p_value : float
        Survival function P(X >= val) of the fitted log-normal.
    """
    shape_s, loc, scale = lognorm.fit(data, floc=0)
    return lognorm.sf(val, shape_s, loc=loc, scale=scale)

47 

48 

def get_gamma_p(data, val):
    """
    Return the upper-tail p-value of ``val`` under a gamma fitted to ``data``.

    The gamma location is fixed at zero (``floc=0``), so only the shape
    ``a`` and the ``scale`` are estimated from ``data``.

    Parameters
    ----------
    data : array-like
        Sample used to fit the gamma distribution.
    val : float
        Observed value whose tail probability is requested.

    Returns
    -------
    p_value : float
        Survival function P(X >= val) of the fitted gamma distribution.
    """
    shape_a, loc, scale = gamma.fit(data, floc=0)
    return gamma.sf(val, shape_a, loc=loc, scale=scale)

68 

69 

def get_distribution_function(dist_name):
    """
    Get a probability distribution object from scipy.stats by name.

    Parameters
    ----------
    dist_name : str
        Name of distribution (e.g., 'gamma', 'lognorm', 'norm').

    Returns
    -------
    dist : scipy.stats.rv_continuous or scipy.stats.rv_discrete
        Distribution object (e.g. ``scipy.stats.gamma``).

    Raises
    ------
    ValueError
        If ``dist_name`` is not an attribute of scipy.stats, or names
        something that is not a univariate distribution object (e.g.
        'rankdata'). Such objects would otherwise fail later and more
        obscurely when ``.fit`` / ``.sf`` are called on them.
    """
    dist = getattr(scipy.stats, dist_name, None)
    if dist is None:
        raise ValueError(f"Distribution '{dist_name}' not found in scipy.stats")
    if not isinstance(dist, (scipy.stats.rv_continuous, scipy.stats.rv_discrete)):
        raise ValueError(
            f"'{dist_name}' in scipy.stats is not a univariate probability distribution"
        )
    return dist

93 

94 

def get_mi_distr_pvalue(data, val, distr_type='gamma'):
    """
    Fit a scipy.stats distribution to ``data`` and return the tail p-value of ``val``.

    Parameters
    ----------
    data : array-like
        Sample to fit (typically the shuffled metric values).
    val : float
        Observed value whose upper-tail probability is requested.
    distr_type : str, optional
        Name of the scipy.stats distribution to fit. Default: 'gamma'.

    Returns
    -------
    p_value : float
        Survival function P(X >= val) of the fitted distribution.

    Notes
    -----
    - 'gamma' and 'lognorm' are fitted with the location fixed at 0
      (``floc=0``); every other distribution uses scipy's default fit.
    - Fitting errors are not caught here; they propagate to the caller.
    """
    distribution = get_distribution_function(distr_type)
    # Zero-bounded families: pin the location so only shape/scale are fitted.
    fit_kwargs = {'floc': 0} if distr_type in ('gamma', 'lognorm') else {}
    fitted_params = distribution.fit(data, **fit_kwargs)
    return distribution(*fitted_params).sf(val)

130 

131 

def get_mask(ptable, rtable, pval_thr, rank_thr):
    """
    Build a float 0/1 mask of entries passing both the p-value and rank cuts.

    An entry is rejected (0.0) if its p-value is strictly above
    ``pval_thr`` or its rank is strictly below ``rank_thr``. Otherwise it
    is kept (1.0). NaN comparisons are False, so NaN entries are never
    rejected by the corresponding test.

    Parameters
    ----------
    ptable : np.ndarray
        Array of p-values.
    rtable : np.ndarray
        Array of ranks (0 to 1), same shape as ``ptable``.
    pval_thr : float
        P-value threshold.
    rank_thr : float
        Rank threshold.

    Returns
    -------
    mask : np.ndarray of float
        1.0 where both thresholds are satisfied, 0.0 otherwise.
    """
    rejected = np.logical_or(np.asarray(ptable) > pval_thr,
                             np.asarray(rtable) < rank_thr)
    return np.where(rejected, 0.0, 1.0)

156 

157 

def stats_not_empty(pair_stats, current_data_hash, stage=1):
    """
    Tell whether a pair's stored statistics can be reused for ``stage``.

    Stored stats are usable only if both conditions hold:

    - They were produced from the same data: ``pair_stats['data_hash']``
      equals ``current_data_hash``.
    - Every statistic the stage needs is not None: 'pre_rval' and
      'pre_pval' for stage 1; 'rval', 'pval' and 'me' for stage 2.

    Parameters
    ----------
    pair_stats : dict
        Dictionary of computed statistics for one pair.
    current_data_hash : str
        Hash of the current data to validate against.
    stage : int, optional
        Stage to check (1 or 2). Default: 1.

    Returns
    -------
    is_valid : bool
        True if the stats match the data and are complete, False otherwise.

    Raises
    ------
    ValueError
        If ``stage`` is neither 1 nor 2.
    KeyError
        If ``pair_stats`` lacks 'data_hash' or one of the required keys.
    """
    if stage == 1:
        required_keys = ['pre_rval', 'pre_pval']
    elif stage == 2:
        required_keys = ['rval', 'pval', 'me']
    else:
        raise ValueError(f'Stage should be 1 or 2, but {stage} was passed')

    stored_hash = pair_stats['data_hash']
    same_data = (current_data_hash == stored_hash)
    # Evaluated eagerly so a missing key raises even when the hash differs.
    present_flags = np.array([pair_stats[key] is not None for key in required_keys])
    all_present = np.all(present_flags)
    return same_data and all_present

187 

188 

def criterion1(pair_stats, nsh1, topk=1):
    """
    Decide whether a neuron-feature pair looks significant after stage-1 shuffling.

    The pair passes if its stage-1 rank ``pre_rval`` (the true value's rank
    among itself and ``nsh1`` shuffles, as a fraction) places the true
    metric within the top ``topk`` values.

    Parameters
    ----------
    pair_stats : dict
        Dictionary of computed stats; 'pre_rval' is read if present.
    nsh1 : int
        Number of shuffles for the first stage.
    topk : int, optional
        The true MI should be among the top-k of true + shuffled values.
        Default: 1.

    Returns
    -------
    crit_passed : bool
        True if the pair passes, False if not or if 'pre_rval' is missing/None.
    """
    pre_rval = pair_stats.get('pre_rval')
    if pre_rval is None:
        return False
    rank_cutoff = 1 - 1.*topk/(nsh1+1)
    return pre_rval > rank_cutoff

216 

217 

def criterion2(pair_stats, nsh2, pval_thr, topk=5):
    """
    Decide whether a neuron-feature pair is significant after full-scale shuffling.

    The pair passes only if both conditions hold:

    - Stage-2 stats exist ('rval' and 'pval' are both present and not None).
    - The true value is among the top ``topk`` of itself and ``nsh2``
      shuffles, and its p-value is strictly below ``pval_thr``.

    Parameters
    ----------
    pair_stats : dict
        Dictionary of computed stats; 'rval' and 'pval' are read.
    nsh2 : int
        Number of shuffles for the second stage.
    pval_thr : float
        P-value threshold for a single pair. It depends on the FWER
        significance level and the multiple-hypothesis correction algorithm.
    topk : int, optional
        The true MI should be among the top-k of true + shuffled values.
        Default: 5.

    Returns
    -------
    crit_passed : bool
        True if significance is confirmed, False otherwise.
    """
    rval = pair_stats.get('rval')
    pval = pair_stats.get('pval')
    # Pairs that did not pass stage 1 have no stage-2 statistics.
    if rval is None or pval is None:
        return False
    # In practice the true MI is almost always top-1 when significant.
    if not rval > (1 - 1.*topk/(nsh2+1)):
        return False
    return pval < pval_thr

253 

254 

def get_all_nonempty_pvals(all_stats, ids1, ids2):
    """
    Collect every non-None 'pval' from a nested statistics dictionary.

    Pairs are visited in row-major order: for each ``id1`` in ``ids1``, then
    each ``id2`` in ``ids2``. Pairs without a 'pval' key, or with
    ``pval is None``, are skipped.

    Parameters
    ----------
    all_stats : dict of dict
        Nested dictionary with statistics, indexed ``all_stats[id1][id2]``.
    ids1 : list
        First dimension indices.
    ids2 : list
        Second dimension indices.

    Returns
    -------
    all_pvals : list
        All non-None p-values found, in visiting order.
    """
    candidate_pvals = (all_stats[first][second].get('pval')
                       for first in ids1
                       for second in ids2)
    return [pv for pv in candidate_pvals if pv is not None]

281 

282 

def get_table_of_stats(metable,
                       optimal_delays,
                       precomputed_mask=None,
                       metric_distr_type='gamma',
                       nsh=0,
                       stage=1):
    """
    Convert metric table to statistics dictionary.

    Parameters
    ----------
    metable : np.ndarray of shape (n1, n2, nsh+1)
        Metric values where [:,:,0] is true values, [:,:,1:] are shuffles.
    optimal_delays : np.ndarray of shape (n1, n2)
        Optimal delays for each pair.
    precomputed_mask : np.ndarray of shape (n1, n2), optional
        Binary mask: nonzero = compute stats for the pair, 0 = skip it.
        Default: all ones.
    metric_distr_type : str, optional
        Distribution fitted to the shuffles for p-value calculation.
        Default: 'gamma'.
    nsh : int, optional
        Number of shuffles. Kept for backward compatibility. Ranks are
        normalized by the actual size of ``metable`` along axis 2, which
        equals ``nsh + 1`` for consistent inputs. Default: 0.
    stage : int, optional
        Stage (1 or 2) determines which stats to compute. Default: 1.

    Returns
    -------
    stage_stats : dict of dict
        Nested dictionary with computed statistics for each pair.
        ``stage_stats[i][j]`` is updated as follows:

        - Stage 1: 'pre_rval', 'pre_pval', 'opt_delay' and 'me'.
        - Stage 2: 'rval', 'pval', 'me' and 'opt_delay'.

        Masked-out pairs keep whatever ``populate_nested_dict`` put there.

    Raises
    ------
    ValueError
        If ``stage`` is not 1 or 2 (checked before any fitting is done).
    """
    if stage not in (1, 2):
        raise ValueError(f'Stage should be 1 or 2, but {stage} was passed')

    a, b, n_values = metable.shape

    # 0 in mask values means that stats for this pair will not be calculated.
    # 1 in mask values means that stats for this pair will be calculated from new results.
    if precomputed_mask is None:
        precomputed_mask = np.ones((a, b))

    stage_stats = populate_nested_dict(dict(), range(a), range(b))

    # Rank of the true value (index 0) among itself and all shuffles, as a
    # fraction in (0, 1]; 1 means it exceeds every shuffle. Normalize by the
    # real number of values so a default or mismatched ``nsh`` cannot push
    # ranks above 1 and break the rank thresholds used downstream.
    ranked_total_mi = rankdata(metable, axis=2, nan_policy='omit')
    ranks = ranked_total_mi[:, :, 0] / n_values

    for i in range(a):
        for j in range(b):
            if not precomputed_mask[i, j]:
                continue

            me = metable[i, j, 0]
            random_mi_samples = metable[i, j, 1:]
            pval = get_mi_distr_pvalue(random_mi_samples, me, distr_type=metric_distr_type)
            opt_delay = optimal_delays[i, j]

            if stage == 1:
                new_stats = {
                    'pre_rval': ranks[i, j],
                    'pre_pval': pval,
                    'opt_delay': opt_delay,
                    'me': me,  # MI value is recorded for stage 1 too
                }
            else:
                new_stats = {
                    'rval': ranks[i, j],
                    'pval': pval,
                    'me': me,
                    'opt_delay': opt_delay,
                }

            stage_stats[i][j].update(new_stats)

    return stage_stats

347 

348 

def merge_stage_stats(stage1_stats, stage2_stats):
    """
    Merge statistics from stage 1 and stage 2 without mutating either input.

    For every pair in ``stage2_stats``, 'pre_rval' and 'pre_pval' are
    copied from the matching ``stage1_stats`` pair. This happens only when
    that pair exists, is non-empty, and carries the key. All other values
    come from ``stage2_stats``.

    Parameters
    ----------
    stage1_stats : dict of dict
        Statistics from stage 1 (preliminary).
    stage2_stats : dict of dict
        Statistics from stage 2 (full).

    Returns
    -------
    merged_stats : dict of dict
        Combined statistics with both stage 1 and 2 results. Rows and any
        pair dict that receives stage-1 values are new objects, so
        ``stage2_stats`` is left unchanged. (A plain shallow ``.copy()``
        would share the inner dicts and write stage-1 values into the
        caller's ``stage2_stats``.)
    """
    merged_stats = stage2_stats.copy()
    for i, row in stage2_stats.items():
        merged_row = row.copy()
        for j, pair in row.items():
            # Only merge if the entry exists in stage1_stats
            if i in stage1_stats and j in stage1_stats[i] and stage1_stats[i][j]:
                source = stage1_stats[i][j]
                updates = {key: source[key]
                           for key in ('pre_rval', 'pre_pval')
                           if key in source}
                if updates:
                    combined = dict(pair)
                    combined.update(updates)
                    merged_row[j] = combined
        merged_stats[i] = merged_row

    return merged_stats

376 

377 

def merge_stage_significance(stage_1_significance, stage_2_significance):
    """
    Merge significance results from stage 1 and stage 2 without mutating inputs.

    For every pair in ``stage_2_significance``, the matching stage-1 entry
    is merged on top of the stage-2 entry. On key collisions the stage-1
    value wins, the same as ``dict.update``.

    Parameters
    ----------
    stage_1_significance : dict of dict
        Significance results from stage 1. Must contain every (i, j) pair
        present in ``stage_2_significance`` (a KeyError is raised otherwise).
    stage_2_significance : dict of dict
        Significance results from stage 2.

    Returns
    -------
    merged_significance : dict of dict
        Combined significance results. Rows and pair dicts are new objects,
        so ``stage_2_significance`` is left unchanged. (A shallow ``.copy()``
        followed by ``.update`` would modify the caller's inner dicts.)
    """
    merged_significance = stage_2_significance.copy()
    for i, row in stage_2_significance.items():
        merged_row = row.copy()
        for j in row:
            combined = dict(row[j])
            combined.update(stage_1_significance[i][j])
            merged_row[j] = combined
        merged_significance[i] = merged_row

    return merged_significance

398 

399 return merged_significance