Coverage for src/driada/intense/stats.py: 100.00%
105 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import numpy as np
2import scipy
3from scipy.stats import *
4from ..utils.data import populate_nested_dict, add_names_to_nested_dict
5from ..experiment.exp_base import DEFAULT_STATS
def chebyshev_ineq(data, val):
    """
    Calculate upper bound on tail probability using Chebyshev's inequality.

    Uses P(X - mu >= k*sigma) <= 1/k**2 with k = (val - mu) / sigma, where
    mu and sigma are the sample mean and (population) standard deviation.

    Parameters
    ----------
    data : array-like
        Sample data to estimate mean and std from.
    val : float
        Value to compute tail probability for.

    Returns
    -------
    p_bound : float
        Upper bound on P(X >= val), always in [0, 1]. Returns 1.0 when
        ``val`` does not exceed the sample mean (the inequality gives no
        information there). Returns 0.0 when the data are constant and
        ``val`` lies above that constant.
    """
    data = np.asarray(data, dtype=float)
    mean = np.mean(data)
    std = np.std(data)

    # The one-sided Chebyshev bound only applies above the mean; below it,
    # 1/z**2 would wrongly claim a small tail probability.
    if val <= mean:
        return 1.0
    # Degenerate (constant) sample: no mass above the mean.
    if std == 0:
        return 0.0

    z = (val - mean) / std
    # For 0 < z < 1, 1/z**2 exceeds 1; clip to a valid probability.
    return float(min(1.0, 1.0 / z ** 2))
def get_lognormal_p(data, val):
    """
    Compute the upper-tail p-value of ``val`` under a log-normal fit.

    A log-normal distribution with location fixed at zero is fitted to
    ``data`` by maximum likelihood, and its survival function is evaluated
    at ``val``.

    Parameters
    ----------
    data : array-like
        Samples the log-normal distribution is fitted to.
    val : float
        Observed value whose p-value is requested.

    Returns
    -------
    p_value : float
        P(X >= val) for X following the fitted log-normal distribution.
    """
    shape_s, loc, scale = lognorm.fit(data, floc=0)
    return lognorm.sf(val, shape_s, loc=loc, scale=scale)
def get_gamma_p(data, val):
    """
    Compute the upper-tail p-value of ``val`` under a gamma fit.

    A gamma distribution with location fixed at zero is fitted to ``data``
    by maximum likelihood, and its survival function is evaluated at ``val``.

    Parameters
    ----------
    data : array-like
        Samples the gamma distribution is fitted to.
    val : float
        Observed value whose p-value is requested.

    Returns
    -------
    p_value : float
        P(X >= val) for X following the fitted gamma distribution.
    """
    shape_a, loc, scale = gamma.fit(data, floc=0)
    return gamma.sf(val, shape_a, loc=loc, scale=scale)
def get_distribution_function(dist_name):
    """
    Get distribution function from scipy.stats by name.

    Parameters
    ----------
    dist_name : str
        Name of distribution (e.g., 'gamma', 'lognorm', 'norm').

    Returns
    -------
    dist : scipy.stats distribution
        Distribution object (an ``rv_continuous`` or ``rv_discrete``
        instance).

    Raises
    ------
    ValueError
        If the name is not found in scipy.stats, or names something in
        scipy.stats that is not a probability distribution (e.g.
        'rankdata').
    """
    dist = getattr(scipy.stats, dist_name, None)
    if dist is None:
        raise ValueError(f"Distribution '{dist_name}' not found in scipy.stats")
    # getattr alone would also hand back plain functions/classes of
    # scipy.stats, which only fail later when .fit/.sf are called.
    if not isinstance(dist, (scipy.stats.rv_continuous, scipy.stats.rv_discrete)):
        raise ValueError(f"'{dist_name}' in scipy.stats is not a probability distribution")
    return dist
def get_mi_distr_pvalue(data, val, distr_type='gamma'):
    """
    Calculate p-value by fitting a distribution to data.

    Parameters
    ----------
    data : array-like
        Sample data (typically shuffled metric values).
    val : float
        Observed value to compute p-value for.
    distr_type : str, optional
        Name of a scipy.stats distribution to fit. Default: 'gamma'.

    Returns
    -------
    p_value : float
        P(X >= val) under fitted distribution.

    Raises
    ------
    ValueError
        Raised by ``get_distribution_function`` if ``distr_type`` cannot be
        resolved. Exceptions raised by ``distr.fit`` propagate unchanged;
        no fallback p-value is substituted.

    Notes
    -----
    - For 'gamma' and 'lognorm', fits with floc=0 (zero lower bound).
    - For other distributions, uses default fitting.
    - NOTE(review): scipy's fixed-``floc`` gamma/lognorm fits are expected
      to reject samples <= floc or non-finite samples. Confirm against the
      installed scipy version before feeding shuffles that can be exactly 0.
    """
    distr = get_distribution_function(distr_type)

    # gamma/lognorm are fitted with their lower bound pinned at 0.
    fit_kwargs = {'floc': 0} if distr_type in ('gamma', 'lognorm') else {}
    params = distr.fit(data, **fit_kwargs)

    return distr(*params).sf(val)
def get_mask(ptable, rtable, pval_thr, rank_thr):
    """
    Create binary mask based on p-value and rank thresholds.

    Parameters
    ----------
    ptable : array-like
        Array of p-values.
    rtable : array-like
        Array of ranks (0 to 1), same shape as ``ptable``.
    pval_thr : float
        P-value threshold. Entries with p-value <= ``pval_thr`` pass.
    rank_thr : float
        Rank threshold. Entries with rank >= ``rank_thr`` pass.

    Returns
    -------
    mask : np.ndarray of float
        Binary mask: 1.0 where both thresholds are satisfied, 0.0 otherwise.
        NaN p-values or ranks never pass.
    """
    ptable = np.asarray(ptable)
    rtable = np.asarray(rtable)
    # State the pass condition positively: every comparison with NaN is
    # False, so NaN entries end up 0 instead of silently passing.
    passed = (ptable <= pval_thr) & (rtable >= rank_thr)
    return passed.astype(float)
def stats_not_empty(pair_stats, current_data_hash, stage=1):
    """
    Check if statistics are valid and complete for given stage.

    Parameters
    ----------
    pair_stats : dict
        Dictionary of computed statistics. Must contain 'data_hash'.
    current_data_hash : str
        Hash of current data to validate against.
    stage : int, optional
        Stage to check (1 or 2). Default: 1.

    Returns
    -------
    is_valid : bool
        True if the stored hash matches ``current_data_hash`` and every stat
        required for ``stage`` is present and not None, False otherwise.

    Raises
    ------
    ValueError
        If ``stage`` is not 1 or 2.
    KeyError
        If ``pair_stats`` has no 'data_hash' entry.
    """
    if stage == 1:
        stats_to_check = ('pre_rval', 'pre_pval')
    elif stage == 2:
        stats_to_check = ('rval', 'pval', 'me')
    else:
        raise ValueError(f'Stage should be 1 or 2, but {stage} was passed')

    if not current_data_hash == pair_stats['data_hash']:
        return False
    # A missing key counts as "empty", mirroring the .get(...) checks in
    # criterion1/criterion2.
    return all(pair_stats.get(st) is not None for st in stats_to_check)
def criterion1(pair_stats, nsh1, topk=1):
    """
    Decide whether a neuron-feature pair survives the preliminary shuffle stage.

    The pair passes when its stage-1 normalized rank ('pre_rval') is
    strictly above ``1 - topk / (nsh1 + 1)``. In other words, the true
    metric value must be among the ``topk`` largest values out of the true
    value plus ``nsh1`` shuffles.

    Parameters
    ----------
    pair_stats: dict
        Computed statistics for the pair; 'pre_rval' may be absent or None.
    nsh1: int
        Number of shuffles used in the first stage.
    topk: int
        How many top positions the true value may occupy. Default: 1.

    Returns
    -------
    crit_passed: bool
        True if the pair passes, False if it fails or has no 'pre_rval'.
    """
    pre_rval = pair_stats.get('pre_rval')
    if pre_rval is None:
        return False
    rank_cutoff = 1 - 1.*topk/(nsh1+1)
    return pre_rval > rank_cutoff
def criterion2(pair_stats, nsh2, pval_thr, topk=5):
    """
    Decide whether a neuron-feature pair is significant after full-scale shuffling.

    Two conditions must hold:
    - the stage-2 normalized rank ('rval') is strictly above
      ``1 - topk / (nsh2 + 1)``;
    - the stage-2 p-value ('pval') is strictly below ``pval_thr``.

    Parameters
    ----------
    pair_stats: dict
        Computed statistics for the pair; 'rval'/'pval' may be absent or None.
    nsh2: int
        Number of shuffles used in the second stage.
    pval_thr: float
        Per-pair p-value threshold. It depends on the FWER significance level
        and the multiple-hypothesis correction algorithm.
    topk: int
        How many top positions the true value may occupy. Default: 5.

    Returns
    -------
    crit_passed: bool
        True if both conditions hold. False otherwise, including when
        stage-2 statistics are missing.
    """
    rval = pair_stats.get('rval')
    pval = pair_stats.get('pval')

    # The pair must have passed stage 1 and received stage-2 statistics.
    if rval is None or pval is None:
        return False

    # The true value must rank among the topk shuffles.
    if not rval > (1 - 1.*topk/(nsh2+1)):
        return False

    return pval < pval_thr
def get_all_nonempty_pvals(all_stats, ids1, ids2):
    """
    Extract all non-empty p-values from nested statistics dictionary.

    Parameters
    ----------
    all_stats : dict of dict
        Nested dictionary with statistics; ``all_stats[id1][id2]`` must
        exist for every requested pair and support ``.get('pval')``.
    ids1 : list
        First dimension indices.
    ids2 : list
        Second dimension indices.

    Returns
    -------
    all_pvals : list
        All non-None p-values, in ids1-major, ids2-minor order.
    """
    pvals = (all_stats[id1][id2].get('pval') for id1 in ids1 for id2 in ids2)
    return [pval for pval in pvals if pval is not None]
def get_table_of_stats(metable,
                       optimal_delays,
                       precomputed_mask=None,
                       metric_distr_type='gamma',
                       nsh=0,
                       stage=1):
    """
    Convert metric table to statistics dictionary.

    Parameters
    ----------
    metable : np.ndarray of shape (n1, n2, nsh+1)
        Metric values where [:,:,0] is true values, [:,:,1:] are shuffles.
    optimal_delays : np.ndarray of shape (n1, n2)
        Optimal delays for each pair.
    precomputed_mask : np.ndarray, optional
        Binary mask: 1 = compute stats, 0 = skip. Default: all ones.
    metric_distr_type : str, optional
        Distribution for p-value calculation. Default: 'gamma'.
    nsh : int, optional
        Number of shuffles, used to normalize ranks. When left at the
        default 0 while ``metable`` has shuffle columns, it is inferred as
        ``metable.shape[2] - 1``.
    stage : int, optional
        Stage (1 or 2) determines which stats to compute. Default: 1.

    Returns
    -------
    stage_stats : dict of dict
        Nested dictionary with computed statistics for each pair.
        - Stage 1 sets 'pre_rval', 'pre_pval', 'opt_delay', 'me'.
        - Stage 2 sets 'rval', 'pval', 'me', 'opt_delay'.

    Raises
    ------
    ValueError
        If ``stage`` is not 1 or 2.
    """
    # Fail fast: previously an invalid stage ran every fit and then
    # silently stored nothing.
    if stage not in (1, 2):
        raise ValueError(f'Stage should be 1 or 2, but {stage} was passed')

    a, b, sh = metable.shape

    # 0 in mask values means that stats for this pair will not be calculated.
    # 1 in mask values means that stats for this pair will be calculated from new results.
    if precomputed_mask is None:
        precomputed_mask = np.ones((a, b))

    # With the default nsh=0 and a table that does hold shuffles, dividing
    # by (nsh + 1) would give ranks > 1; derive the count from the table.
    if nsh == 0 and sh > 1:
        nsh = sh - 1

    stage_stats = populate_nested_dict(dict(), range(a), range(b))

    ranked_total_mi = rankdata(metable, axis=2, nan_policy='omit')
    # Normalized rank of the true value among true + shuffled values.
    ranks = ranked_total_mi[:, :, 0] / (nsh + 1)

    for i in range(a):
        for j in range(b):
            if not precomputed_mask[i, j]:
                continue

            me = metable[i, j, 0]
            random_mi_samples = metable[i, j, 1:]
            pval = get_mi_distr_pvalue(random_mi_samples, me, distr_type=metric_distr_type)
            opt_delay = optimal_delays[i, j]

            if stage == 1:
                new_stats = {
                    'pre_rval': ranks[i, j],
                    'pre_pval': pval,
                    'opt_delay': opt_delay,
                    'me': me,  # MI value is stored for stage 1 too
                }
            else:
                new_stats = {
                    'rval': ranks[i, j],
                    'pval': pval,
                    'me': me,
                    'opt_delay': opt_delay,
                }

            stage_stats[i][j].update(new_stats)

    return stage_stats
def merge_stage_stats(stage1_stats, stage2_stats):
    """
    Merge statistics from stage 1 and stage 2.

    The inputs are not modified. The result copies the two outer mapping
    levels and each per-pair entry of ``stage2_stats``. Where a non-empty
    stage-1 entry exists for the same pair, its 'pre_rval' and 'pre_pval'
    are copied into the result.

    Parameters
    ----------
    stage1_stats : dict of dict
        Statistics from stage 1 (preliminary).
    stage2_stats : dict of dict
        Statistics from stage 2 (full).

    Returns
    -------
    merged_stats : dict of dict
        Combined statistics with both stage 1 and 2 results.
    """
    import copy

    # A single shallow .copy() would share the per-pair dicts with
    # stage2_stats, so writes below would mutate the caller's input.
    merged_stats = copy.copy(stage2_stats)
    for i, row in stage2_stats.items():
        merged_row = copy.copy(row)
        merged_stats[i] = merged_row
        for j, entry in row.items():
            merged_entry = copy.copy(entry)
            merged_row[j] = merged_entry

            # Only merge if the entry exists (and is non-empty) in stage1_stats.
            if not (i in stage1_stats and j in stage1_stats[i] and stage1_stats[i][j]):
                continue
            stage1_entry = stage1_stats[i][j]
            for key in ('pre_rval', 'pre_pval'):
                if key in stage1_entry:
                    merged_entry[key] = stage1_entry[key]

    return merged_stats
def merge_stage_significance(stage_1_significance, stage_2_significance):
    """
    Merge significance results from stage 1 and stage 2.

    The inputs are not modified. For every pair in
    ``stage_2_significance``, the result holds a copy of its stage-2 entry
    updated with the matching stage-1 entry. On shared keys, stage-1 values
    win. Pairs with no stage-1 entry keep their stage-2 entry unchanged.

    Parameters
    ----------
    stage_1_significance : dict of dict
        Significance results from stage 1.
    stage_2_significance : dict of dict
        Significance results from stage 2.

    Returns
    -------
    merged_significance : dict of dict
        Combined significance results.
    """
    import copy

    # Copy every level that is written to; a single shallow .copy() would
    # let .update() below mutate the caller's stage-2 dicts.
    merged_significance = copy.copy(stage_2_significance)
    for i, row in stage_2_significance.items():
        merged_row = copy.copy(row)
        merged_significance[i] = merged_row
        for j, entry in row.items():
            merged_entry = copy.copy(entry)
            merged_row[j] = merged_entry
            # Skip pairs absent from stage 1 (consistent with merge_stage_stats).
            if i in stage_1_significance and j in stage_1_significance[i]:
                merged_entry.update(stage_1_significance[i][j])

    return merged_significance