Coverage for src/driada/intense/intense_base.py: 87.19%
406 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import numpy as np
2import tqdm
3from joblib import Parallel, delayed
4import multiprocessing
5import scipy.stats
7from .stats import *
8from ..information.info_base import TimeSeries, MultiTimeSeries, get_1d_mi, get_multi_mi, get_mi, get_sim
9from ..utils.data import write_dict_to_hdf5, nested_dict_to_seq_of_tables
def validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=False):
    """
    Check that two sets of time series can be compared by INTENSE routines.

    Parameters
    ----------
    ts_bunch1 : list
        First set of time series.
    ts_bunch2 : list
        Second set of time series.
    allow_mixed_dimensions : bool, optional
        If False (default), every element of both sets must be exactly a
        TimeSeries (checked with ``type(...)``, so subclasses are rejected).
        If True, the element-type check is skipped entirely.

    Raises
    ------
    ValueError
        If a set is empty, contains disallowed element types, contains series
        of different lengths, or if the two sets have different lengths.
    """
    labelled = (("ts_bunch1", ts_bunch1), ("ts_bunch2", ts_bunch2))

    # Both emptiness checks run before any type inspection.
    for label, bunch in labelled:
        if not len(bunch):
            raise ValueError(f"{label} cannot be empty")

    if not allow_mixed_dimensions:
        for label, bunch in labelled:
            kinds = [type(ts) for ts in bunch]
            if any(kind != TimeSeries for kind in kinds):
                # Give a more specific message when the culprit is a MultiTimeSeries.
                if any(kind == MultiTimeSeries for kind in kinds):
                    raise ValueError(f"MultiTimeSeries found in {label} but allow_mixed_dimensions=False")
                raise ValueError(f"{label} must contain TimeSeries objects")

    def _n_samples(ts):
        # TimeSeries data is 1-D; anything else is read as (n_dims, n_samples).
        return len(ts.data) if isinstance(ts, TimeSeries) else ts.data.shape[1]

    sample_counts = {label: [_n_samples(ts) for ts in bunch] for label, bunch in labelled}

    for label, counts in sample_counts.items():
        distinct = set(counts)
        if len(distinct) > 1:
            raise ValueError(f"All time series in {label} must have same length, got {distinct}")

    n_first1 = sample_counts["ts_bunch1"][0]
    n_first2 = sample_counts["ts_bunch2"][0]
    if n_first1 != n_first2:
        raise ValueError(f"Time series lengths don't match: {n_first1} vs {n_first2}")
def validate_metric(metric, allow_scipy=True):
    """
    Validate metric name and check if it's supported.

    Parameters
    ----------
    metric : str
        Metric name to validate.
    allow_scipy : bool, optional
        Whether to accept other public, callable ``scipy.stats`` functions
        beyond the explicitly listed names. Default: True.

    Returns
    -------
    metric_type : str
        Type of metric: 'mi', 'correlation', 'special', or 'scipy'.

    Raises
    ------
    ValueError
        If metric is not supported. This includes non-string metrics, and
        ``scipy.stats`` attributes that are private (leading underscore) or
        not callable (e.g. submodules).
    """
    # Built-in metrics
    if metric == 'mi':
        return 'mi'

    # Special metrics
    if metric in ('av', 'fast_pearsonr'):
        return 'special'

    # Common correlation metrics (shorthand names)
    if metric in ('spearman', 'pearson', 'kendall'):
        return 'correlation'

    # Full scipy names
    if metric in ('spearmanr', 'pearsonr', 'kendalltau'):
        return 'scipy'

    # Any other public, callable scipy.stats object. A bare hasattr() would
    # raise TypeError for non-string metrics, and would accept names like
    # '__doc__' or submodules that are not functions at all.
    if (allow_scipy
            and isinstance(metric, str)
            and not metric.startswith('_')
            and callable(getattr(scipy.stats, metric, None))):
        return 'scipy'

    # If we get here, metric is not supported
    raise ValueError(f"Unsupported metric: {metric}. Supported metrics include: "
                     f"'mi', 'av', 'fast_pearsonr', 'spearman', 'pearson', 'kendall', "
                     f"'spearmanr', 'pearsonr', 'kendalltau', and other scipy.stats functions.")
def validate_common_parameters(shift_window=None, ds=None, nsh=None, noise_const=None):
    """
    Check the numeric parameters shared by INTENSE routines.

    A parameter left as None is not checked. The checks run in the order
    shift_window, ds, nsh, noise_const, and the first failure is raised.

    Parameters
    ----------
    shift_window : int, optional
        Maximum shift window in frames; must be >= 0.
    ds : int, optional
        Downsampling factor; must be > 0.
    nsh : int, optional
        Number of shuffles; must be > 0.
    noise_const : float, optional
        Noise constant for numerical stability; must be >= 0.

    Raises
    ------
    ValueError
        If any provided parameter violates its constraint.
    """
    rules = (
        ("shift_window", shift_window, lambda v: v < 0, "non-negative"),
        ("ds", ds, lambda v: v <= 0, "positive"),
        ("nsh", nsh, lambda v: v <= 0, "positive"),
        ("noise_const", noise_const, lambda v: v < 0, "non-negative"),
    )
    for name, value, violates, requirement in rules:
        if value is None:
            continue
        if violates(value):
            raise ValueError(f"{name} must be {requirement}, got {value}")
def calculate_optimal_delays(ts_bunch1, ts_bunch2, metric,
                             shift_window, ds, verbose=True, enable_progressbar=True):
    """
    Calculate optimal temporal delays between pairs of time series.

    Finds the delay that maximizes the similarity metric between each pair of time series
    from ts_bunch1 and ts_bunch2. This accounts for temporal offsets in neural responses
    relative to behavioral variables.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series (typically neural signals).
    ts_bunch2 : list of TimeSeries
        Second set of time series (typically behavioral variables).
    metric : str
        Similarity metric to maximize (validated by validate_metric). Options include:
        - 'mi': Mutual information
        - 'spearman': Spearman correlation
        - Other metrics supported by get_sim function
    shift_window : int
        Maximum shift magnitude (frames), must be >= 0. Candidate shifts are
        ``np.arange(-shift_window, shift_window, ds)``, i.e. from -shift_window
        up to but NOT including +shift_window, in steps of ds frames.
        With shift_window=0 only the zero shift is evaluated.
    ds : int
        Downsampling factor passed to get_sim (every ds-th point is used);
        also the spacing between candidate shifts. Must be > 0.
    verbose : bool, optional
        Whether to print progress information. Default: True.
    enable_progressbar : bool, optional
        Whether to show progress bar. Default: True.

    Returns
    -------
    optimal_delays : np.ndarray of int, shape (len(ts_bunch1), len(ts_bunch2))
        Optimal delay (in frames, a multiple of ds) for each pair. Positive values indicate
        that ts2 leads ts1, negative values indicate ts1 leads ts2.

    Raises
    ------
    ValueError
        If the inputs fail validate_time_series_bunches, validate_metric or
        validate_common_parameters.

    Notes
    -----
    - Number of get_sim calls: len(ts_bunch1) * len(ts_bunch2) * n_shifts,
      where n_shifts is about 2 * shift_window / ds.
    - The optimal delay is found by exhaustive search over all candidate shifts;
      ties are resolved in favour of the most negative shift (np.argmax).
    - Memory efficient: only stores final optimal delays, not all tested values.

    Examples
    --------
    >>> neurons = [neuron1.ca, neuron2.ca] # calcium signals
    >>> behaviors = [speed_ts, direction_ts] # behavioral variables
    >>> delays = calculate_optimal_delays(neurons, behaviors, 'mi',
    ...                                   shift_window=100, ds=1)
    >>> print(f"Neuron 1 optimal delay with speed: {delays[0, 0]} frames")
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=False)
    validate_metric(metric)
    validate_common_parameters(shift_window=shift_window, ds=ds)

    if verbose:
        print('Calculating optimal delays:')

    optimal_delays = np.zeros((len(ts_bunch1), len(ts_bunch2)), dtype=int)

    # Candidate shifts expressed in downsampled units (get_sim's shift is
    # applied to the downsampled series).
    shifts = np.arange(-shift_window, shift_window, ds) // ds
    if shifts.size == 0:
        # shift_window=0 is allowed by validation but gives an empty range,
        # and np.argmax on an empty sequence raises; zero lag is the only option.
        shifts = np.zeros(1, dtype=int)

    for i, ts1 in tqdm.tqdm(enumerate(ts_bunch1), total=len(ts_bunch1), disable=not enable_progressbar):
        for j, ts2 in enumerate(ts_bunch2):
            shifted_me = [get_sim(ts1, ts2, metric, ds=ds, shift=int(shift)) for shift in shifts]
            best_shift = shifts[np.argmax(shifted_me)]
            # Convert back from downsampled units to frames.
            optimal_delays[i, j] = int(best_shift * ds)

    return optimal_delays
def calculate_optimal_delays_parallel(ts_bunch1, ts_bunch2, metric,
                                      shift_window, ds, verbose=True, n_jobs=-1):
    """
    Calculate optimal temporal delays between pairs of time series using parallel processing.

    Parallel version of calculate_optimal_delays that distributes computation across
    multiple CPU cores for improved performance with large datasets.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series (typically neural signals).
    ts_bunch2 : list of TimeSeries
        Second set of time series (typically behavioral variables).
    metric : str
        Similarity metric to maximize. Options include:
        - 'mi': Mutual information
        - 'spearman': Spearman correlation
        - Other metrics supported by get_sim function
    shift_window : int
        Maximum shift magnitude (frames); candidate shifts run from
        -shift_window up to (but not including) +shift_window in steps of ds.
    ds : int
        Downsampling factor. Every ds-th point is used from the time series.
    verbose : bool, optional
        Whether to print progress information. Default: True.
    n_jobs : int, optional
        Number of parallel jobs to run. Default: -1 (use all available cores).
        Other negative values follow the joblib convention (cpu_count + 1 + n_jobs,
        e.g. -2 = all cores but one). The effective value is capped at
        len(ts_bunch1) so that no worker receives an empty chunk.

    Returns
    -------
    optimal_delays : np.ndarray of int, shape (len(ts_bunch1), len(ts_bunch2))
        Optimal delay (in frames) for each pair. Positive values indicate
        that ts2 leads ts1, negative values indicate ts1 leads ts2.

    Raises
    ------
    ValueError
        If inputs fail validation, or if n_jobs == 0.

    Notes
    -----
    - Parallelization is done by splitting ts_bunch1 across workers
    - Each worker processes a subset of ts_bunch1 against all of ts_bunch2
    - Memory usage scales with number of workers
    - Speedup is typically sublinear due to overhead and memory bandwidth

    See Also
    --------
    calculate_optimal_delays : Sequential version of this function

    Examples
    --------
    >>> neurons = [neuron.ca for neuron in exp.neurons[:100]]
    >>> behaviors = [exp.speed, exp.direction]
    >>> # Use 8 cores for faster computation
    >>> delays = calculate_optimal_delays_parallel(neurons, behaviors, 'mi',
    ...                                            shift_window=100, ds=1, n_jobs=8)
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=False)
    validate_metric(metric)
    validate_common_parameters(shift_window=shift_window, ds=ds)

    if verbose:
        print('Calculating optimal delays in parallel mode:')

    n1 = len(ts_bunch1)
    optimal_delays = np.zeros((n1, len(ts_bunch2)), dtype=int)

    if n_jobs == 0:
        raise ValueError("n_jobs must be non-zero")
    if n_jobs < 0:
        # joblib convention: -1 -> all cores, -2 -> all but one, ...
        n_jobs = max(1, multiprocessing.cpu_count() + 1 + n_jobs)
    # More chunks than series would produce empty chunks, which
    # calculate_optimal_delays rejects ("ts_bunch1 cannot be empty").
    n_jobs = min(n_jobs, n1)

    chunk_inds = np.array_split(np.arange(n1), n_jobs)
    # Index the sequence directly instead of np.array(ts_bunch1): NumPy may try
    # to unpack sequence-like elements into a multi-dimensional array.
    chunks = [[ts_bunch1[k] for k in idxs.tolist()] for idxs in chunk_inds]

    parallel_delays = Parallel(n_jobs=n_jobs, verbose=True)(
        delayed(calculate_optimal_delays)(chunk,
                                          ts_bunch2,
                                          metric,
                                          shift_window,
                                          ds,
                                          verbose=False,
                                          enable_progressbar=False)
        for chunk in chunks)

    # Scatter each worker's rows back into their original positions.
    for idxs, chunk_delays in zip(chunk_inds, parallel_delays):
        optimal_delays[idxs, :] = chunk_delays

    return optimal_delays
def get_calcium_feature_me_profile(exp, cell_id=None, feat_id=None, cbunch=None, fbunch=None,
                                   window=1000, ds=1, metric='mi', data_type='calcium'):
    """
    Compute metric profile between neurons and behavioral features across time shifts.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neurons and behavioral features.
    cell_id : int, optional
        Index of a single neuron in exp.neurons. Deprecated - use cbunch instead.
    feat_id : str or tuple of str, optional
        Single feature name(s) to analyze. Deprecated - use fbunch instead.
    cbunch : int, iterable or None, optional
        Neuron indices. If None (default), all neurons will be analyzed.
        Takes precedence over cell_id if both provided.
    fbunch : str, iterable or None, optional
        Feature names. If None (default), all single features will be analyzed.
        Takes precedence over feat_id if both provided.
    window : int, optional
        Maximum shift magnitude (frames), must be > 0. Default: 1000.
    ds : int, optional
        Downsampling factor. Default: 1 (no downsampling).
    metric : str, optional
        Similarity metric to compute. Default: 'mi'.
        - 'mi': Mutual information
        - 'spearman': Spearman correlation
        - Other metrics supported by get_sim function
    data_type : str, optional
        Type of neural data to use. Default: 'calcium'.
        - 'calcium': Use calcium imaging data (cell.ca)
        - any other value: Use spike data (cell.spikes)

    Returns
    -------
    tuple or dict
        If only cell_id and feat_id are provided (backward compatibility):
            (me0, shifted_me) for that single pair.
        Otherwise (cbunch or fbunch used):
            Nested dictionary {cell_id: {feat_id: {'me0': float, 'shifted_me': list}}}.
        In both cases shifted_me holds one value per shift in
        ``np.arange(-window, window, ds) // ds`` (downsampled units), i.e. from
        -window up to but not including +window frames.

    Raises
    ------
    ValueError
        If ds, metric or window are invalid, a cell index is out of range, or
        a multi-feature (tuple) is requested with metric != 'mi'.

    Notes
    -----
    - Total number of shifts tested: about 2 * window / ds
    - Multi-feature analysis (tuple feat_id) only supported for metric='mi'
    - Progress bar shows computation progress

    Examples
    --------
    >>> # Backward compatibility - single cell and feature
    >>> mi_zero, mi_profile = get_calcium_feature_me_profile(exp, 0, 'speed')
    >>>
    >>> # New usage - analyze multiple cells and features
    >>> results = get_calcium_feature_me_profile(exp, cbunch=[0, 1, 2], fbunch=['speed', 'head_direction'])
    >>> # Access specific result: results[cell_id][feat_id]['me0'] and ['shifted_me']
    >>>
    >>> # Analyze all cells with all features
    >>> results = get_calcium_feature_me_profile(exp, cbunch=None, fbunch=None)
    >>>
    >>> # Multi-feature joint mutual information
    >>> results = get_calcium_feature_me_profile(exp, cbunch=[0], fbunch=[('x', 'y')])
    """
    # Validate inputs
    validate_common_parameters(ds=ds)
    validate_metric(metric)

    if window <= 0:
        raise ValueError(f"window must be positive, got {window}")

    # Check if single cell/feature mode (backward compatibility)
    single_mode = (cell_id is not None and feat_id is not None and
                   cbunch is None and fbunch is None)

    # Handle backward compatibility - if old-style single cell_id/feat_id provided
    if cbunch is None and cell_id is not None:
        cbunch = cell_id
    if fbunch is None and feat_id is not None:
        fbunch = feat_id

    # Process cbunch and fbunch using experiment's methods
    cell_ids = exp._process_cbunch(cbunch)
    feat_ids = exp._process_fbunch(fbunch, allow_multifeatures=True, mode=data_type)

    # Validate cell indices
    for cid in cell_ids:
        if not (0 <= cid < len(exp.neurons)):
            raise ValueError(f"cell_id {cid} out of range [0, {len(exp.neurons)-1}]")

    # Reject unsupported multi-feature requests before any computation starts.
    if metric != 'mi' and any(not isinstance(fid, str) for fid in feat_ids):
        raise ValueError(f"Multi-feature analysis only supported for metric='mi', got '{metric}'")

    # Shifts are identical for every pair, so compute them once.
    lag_shifts = np.arange(-window, window, ds) // ds

    results = {}
    total_combinations = len(cell_ids) * len(feat_ids)
    # Context manager guarantees the progress bar is closed even on error.
    with tqdm.tqdm(total=total_combinations, desc="Computing ME profiles") as pbar:
        for cid in cell_ids:
            cell = exp.neurons[cid]
            ts1 = cell.ca if data_type == 'calcium' else cell.spikes
            results[cid] = {}

            for fid in feat_ids:
                if isinstance(fid, str):
                    # Single feature
                    ts2 = exp.dynamic_features[fid]
                    me0 = get_sim(ts1, ts2, metric, ds=ds)
                    shifted_me = [get_sim(ts1, ts2, metric, ds=ds, shift=shift)
                                  for shift in lag_shifts]
                else:
                    # Multi-feature (tuple): joint MI of all components with the cell
                    feats = [exp.dynamic_features[f] for f in fid]
                    me0 = get_multi_mi(feats, ts1, ds=ds)
                    shifted_me = [get_multi_mi(feats, ts1, ds=ds, shift=shift)
                                  for shift in lag_shifts]

                results[cid][fid] = {'me0': me0, 'shifted_me': shifted_me}
                pbar.update(1)

    # Return format based on usage mode
    if single_mode:
        # Backward compatibility - return simple (me0, shifted_me) tuple
        single = results[cell_ids[0]][feat_ids[0]]
        return single['me0'], single['shifted_me']
    # New format - return full results dictionary
    return results
def scan_pairs(ts_bunch1,
               ts_bunch2,
               metric,
               nsh,
               optimal_delays,
               joint_distr=False,
               ds=1,
               mask=None,
               noise_const=1e-3,
               seed=None,
               allow_mixed_dimensions=False,
               enable_progressbar=True):

    """
    Calculate true and shuffled similarity-metric values for 2 given sets of TimeSeries.

    This function is generally assumed to be used internally,
    but can be also called manually to "look inside" high-level computation routines.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries objects
        First set of time series.
    ts_bunch2 : list of TimeSeries objects
        Second set of time series.
    metric : str
        Similarity metric between TimeSeries (checked by validate_metric).
        joint_distr=True supports only 'mi'.
    nsh : int
        Number of shuffles.
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2)), or (len(ts_bunch1), 1) if joint_distr=True
        Best shifts (in frames) from original time series alignment in terms of the metric.
        A ValueError is raised if the shape does not match.
    joint_distr : bool
        If joint_distr=True, ALL (sic!) TimeSeries in ts_bunch2 will be treated
        as components of a single multifeature.
        default: False
    ds : int
        Downsampling constant. Every "ds" point will be taken from the data time series.
        default: 1
    mask : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2)), or (len(ts_bunch1), 1) if joint_distr=True
        Precomputed mask for skipping some of possible pairs.
        Only entries equal to 1 are computed; any other value skips the pair.
        default: None (all pairs computed)
    noise_const : float
        Small noise amplitude, which is added to true and shuffled metric values
        to improve numerical fit.
        default: 1e-3
    seed : int or None
        Random seed for reproducibility; None is treated as 0.
        NOTE: this reseeds NumPy's global random state (np.random.seed),
        which also affects any subsequent np.random use by the caller.
    allow_mixed_dimensions : bool
        Passed to validate_time_series_bunches; if True, TimeSeries and
        MultiTimeSeries may be mixed.
        default: False
    enable_progressbar : bool
        Whether to show a tqdm progress bar over ts_bunch1.
        default: True

    Returns
    -------
    random_shifts : np.ndarray of int, shape (len(ts_bunch1), len(ts_bunch2), nsh), or (len(ts_bunch1), 1, nsh) if joint_distr=True
        Signal shifts used for the shuffled distribution, in downsampled
        units (already divided by ds).
    me_total : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh+1), or (len(ts_bunch1), 1, nsh+1) if joint_distr=True
        Aggregated array of true and shuffled metric values.
        True values can be obtained by me_total[:,:,0];
        shuffled values of shape (..., nsh) by me_total[:,:,1:].
        Pairs skipped by the mask hold NaN (None assigned into a float array).

    Raises
    ------
    ValueError
        If inputs fail validation, optimal_delays has the wrong shape, series
        lengths differ, or joint_distr=True is used with metric != 'mi'.
    """

    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=allow_mixed_dimensions)
    validate_metric(metric)
    validate_common_parameters(ds=ds, nsh=nsh, noise_const=noise_const)

    # Validate optimal_delays shape
    n1 = len(ts_bunch1)
    n2 = 1 if joint_distr else len(ts_bunch2)

    if optimal_delays.shape != (n1, n2):
        raise ValueError(f"optimal_delays shape {optimal_delays.shape} doesn't match expected ({n1}, {n2})")

    if seed is None:
        seed = 0

    np.random.seed(seed)

    # Sample count per series: 1-D TimeSeries use len(data), others use data.shape[1]
    lengths1 = [len(ts.data) if isinstance(ts, TimeSeries) else ts.data.shape[1] for ts in ts_bunch1]
    lengths2 = [len(ts.data) if isinstance(ts, TimeSeries) else ts.data.shape[1] for ts in ts_bunch2]
    if len(set(lengths1)) == 1 and len(set(lengths2)) == 1 and set(lengths1) == set(lengths2):
        t = lengths1[0]  # full length is the same for all time series
    else:
        raise ValueError('Lenghts of TimeSeries do not match!')

    if mask is None:
        mask = np.ones((n1, n2))

    me_table = np.zeros((n1, n2))
    me_table_shuffles = np.zeros((n1, n2, nsh))
    random_shifts = np.zeros((n1, n2, nsh), dtype=int)

    # fill random shifts according to the allowed shuffles masks of both time series
    for i, ts1 in enumerate(ts_bunch1):
        if joint_distr:
            # Reseeding per row makes each row's shifts independent of row order
            np.random.seed(seed)
            # Combine shuffle masks from ts1 and all ts in tsbunch2
            combined_shuffle_mask = ts1.shuffle_mask.copy()
            for ts2 in ts_bunch2:
                combined_shuffle_mask = combined_shuffle_mask & ts2.shuffle_mask
            # move shuffle mask according to optimal shift
            combined_shuffle_mask = np.roll(combined_shuffle_mask, int(optimal_delays[i, 0]))
            indices_to_select = np.arange(t)[combined_shuffle_mask]
            # NOTE(review): np.random.choice raises if the combined mask selects no indices
            random_shifts[i, 0, :] = np.random.choice(indices_to_select, size=nsh) // ds

        else:
            for j, ts2 in enumerate(ts_bunch2):
                # Reseeding per pair makes each pair's shifts independent of pair order
                np.random.seed(seed)
                combined_shuffle_mask = ts1.shuffle_mask & ts2.shuffle_mask
                # move shuffle mask according to optimal shift
                combined_shuffle_mask = np.roll(combined_shuffle_mask, int(optimal_delays[i, j]))
                indices_to_select = np.arange(t)[combined_shuffle_mask]
                random_shifts[i, j, :] = np.random.choice(indices_to_select, size=nsh)//ds

    # calculate similarity metric arrays
    for i, ts1 in tqdm.tqdm(enumerate(ts_bunch1),
                            total=len(ts_bunch1),
                            position=0,
                            leave=True,
                            disable=not enable_progressbar):

        np.random.seed(seed)

        # TODO: deprecate this branch, it is unnecessary with MultiTimeSeries
        if joint_distr:
            if metric != 'mi':
                raise ValueError("joint_distr mode works with metric = 'mi' only")
            if mask[i,0] == 1:
                # default metric without shuffling, minus due to different order
                # (get_multi_mi takes the features first and ts1 second)
                me0 = get_multi_mi(ts_bunch2, ts1, ds=ds, shift=-optimal_delays[i, 0]//ds)
                me_table[i,0] = me0 + np.random.random()*noise_const  # add small noise for better fitting

                np.random.seed(seed)
                random_noise = np.random.random(size=len(random_shifts[i, 0, :])) * noise_const  # add small noise for better fitting
                for k, shift in enumerate(random_shifts[i, 0, :]):
                    mi = get_multi_mi(ts_bunch2, ts1, ds=ds, shift=shift)
                    me_table_shuffles[i,0,k] = mi + random_noise[k]

            else:
                # Skipped pair: None is stored as NaN in the float arrays
                me_table[i,0] = None
                me_table_shuffles[i,0,:] = np.full(shape=nsh, fill_value=None)

        else:
            for j, ts2 in enumerate(ts_bunch2):
                if mask[i,j] == 1:
                    me0 = get_sim(ts1,
                                  ts2,
                                  metric,
                                  ds=ds,
                                  shift=optimal_delays[i, j]//ds,
                                  check_for_coincidence=True)  # default metric without shuffling

                    np.random.seed(seed)
                    me_table[i,j] = me0 + np.random.random()*noise_const  # add small noise for better fitting

                    np.random.seed(seed)
                    random_noise = np.random.random(
                        size=len(random_shifts[i, j, :])) * noise_const  # add small noise for better fitting

                    for k, shift in enumerate(random_shifts[i,j,:]):
                        # Reseed before each call in case get_sim consumes global randomness
                        np.random.seed(seed)
                        #mi = get_1d_mi(ts1, ts2, shift=shift, ds=ds)
                        me = get_sim(ts1,
                                     ts2,
                                     metric,
                                     ds=ds,
                                     shift=shift)

                        me_table_shuffles[i,j,k] = me + random_noise[k]

                else:
                    # Skipped pair: None is stored as NaN in the float arrays
                    me_table[i,j] = None
                    me_table_shuffles[i,j,:] = np.array([None for _ in range(nsh)])

    # Stack true values (depth 0) in front of the nsh shuffles along the last axis
    me_total = np.dstack((me_table, me_table_shuffles))

    return random_shifts, me_total
def scan_pairs_parallel(ts_bunch1,
                        ts_bunch2,
                        metric,
                        nsh,
                        optimal_delays,
                        joint_distr=False,
                        allow_mixed_dimensions=False,
                        ds=1,
                        mask=None,
                        noise_const=1e-3,
                        seed=None,
                        n_jobs=-1):
    """
    Calculate metric values and shuffles for time series pairs using parallel processing.

    ts_bunch1 is split into contiguous chunks; each chunk is processed by
    scan_pairs in a joblib worker with the same seed, and the results are
    reassembled in the original row order.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series.
    ts_bunch2 : list of TimeSeries
        Second set of time series.
    metric : str
        Similarity metric to compute:
        - 'mi': Mutual information
        - 'spearman': Spearman correlation
        - Other metrics supported by get_sim function
    nsh : int
        Number of shuffles to perform.
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2)), or (len(ts_bunch1), 1) if joint_distr=True
        Pre-computed optimal delays for each pair.
    joint_distr : bool, optional
        If True, treats all ts_bunch2 as components of a single multifeature.
        Default: False.
    allow_mixed_dimensions : bool, optional
        If True, TimeSeries and MultiTimeSeries may be mixed. Default: False.
    ds : int, optional
        Downsampling factor. Default: 1.
    mask : array-like, optional
        Binary mask with the same shape as optimal_delays.
        0 = skip computation, 1 = compute. Default: all ones.
    noise_const : float, optional
        Small noise added to improve numerical stability. Default: 1e-3.
    seed : int, optional
        Random seed for reproducibility. Default: None.
    n_jobs : int, optional
        Number of parallel jobs. Default: -1 (use all cores). Other negative
        values follow joblib's convention (cpu_count + 1 + n_jobs). The
        effective value is capped at len(ts_bunch1) so no worker gets an
        empty chunk.

    Returns
    -------
    random_shifts : np.ndarray of shape (n1, n2, nsh)
        Random shifts used for shuffling (n2 = 1 if joint_distr=True).
    me_total : np.ndarray of shape (n1, n2, nsh+1)
        Metric values. [:,:,0] contains true values, [:,:,1:] contains shuffles.

    Raises
    ------
    ValueError
        If inputs fail validation, optimal_delays has the wrong shape, or n_jobs == 0.

    See Also
    --------
    scan_pairs : Sequential version of this function
    scan_pairs_router : Wrapper that chooses between parallel and sequential
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=allow_mixed_dimensions)
    validate_metric(metric)
    validate_common_parameters(ds=ds, nsh=nsh, noise_const=noise_const)

    n1 = len(ts_bunch1)
    n2 = 1 if joint_distr else len(ts_bunch2)

    # Validate optimal_delays shape
    if optimal_delays.shape != (n1, n2):
        raise ValueError(f"optimal_delays shape {optimal_delays.shape} doesn't match expected ({n1}, {n2})")

    # Initialize mask if None; accept any array-like mask otherwise.
    mask = np.ones((n1, n2)) if mask is None else np.asarray(mask)

    if n_jobs == 0:
        raise ValueError("n_jobs must be non-zero")
    if n_jobs < 0:
        # joblib convention: -1 -> all cores, -2 -> all but one, ...
        n_jobs = max(1, multiprocessing.cpu_count() + 1 + n_jobs)
    # More chunks than rows would give empty chunks, which scan_pairs rejects
    # ("ts_bunch1 cannot be empty").
    n_jobs = min(n_jobs, n1)

    chunk_inds = np.array_split(np.arange(n1), n_jobs)
    # Index the sequence directly instead of np.array(ts_bunch1): NumPy may try
    # to unpack sequence-like elements into a multi-dimensional array.
    chunks = [[ts_bunch1[k] for k in idxs.tolist()] for idxs in chunk_inds]

    parallel_result = Parallel(n_jobs=n_jobs, verbose=True)(
        delayed(scan_pairs)(chunk,
                            ts_bunch2,
                            metric,
                            nsh,
                            optimal_delays[idxs],
                            joint_distr=joint_distr,
                            allow_mixed_dimensions=allow_mixed_dimensions,
                            ds=ds,
                            mask=mask[idxs],
                            noise_const=noise_const,
                            seed=seed,
                            enable_progressbar=False)
        for chunk, idxs in zip(chunks, chunk_inds))

    random_shifts = np.zeros((n1, n2, nsh), dtype=int)
    me_total = np.zeros((n1, n2, nsh + 1))
    # Scatter each worker's rows back into their original positions.
    for idxs, (chunk_shifts, chunk_me) in zip(chunk_inds, parallel_result):
        random_shifts[idxs, :, :] = chunk_shifts
        me_total[idxs, :, :] = chunk_me

    return random_shifts, me_total
def scan_pairs_router(ts_bunch1,
                      ts_bunch2,
                      metric,
                      nsh,
                      optimal_delays,
                      joint_distr=False,
                      allow_mixed_dimensions=False,
                      ds=1,
                      mask=None,
                      noise_const=1e-3,
                      seed=None,
                      enable_parallelization=True,
                      n_jobs=-1):
    """
    Dispatch a pairwise metric scan to scan_pairs_parallel or scan_pairs.

    All arguments except enable_parallelization and n_jobs are forwarded
    unchanged to the chosen implementation. n_jobs is forwarded only to the
    parallel implementation.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series.
    ts_bunch2 : list of TimeSeries
        Second set of time series.
    metric : str
        Similarity metric to compute ('mi', 'spearman', or another metric
        supported by get_sim).
    nsh : int
        Number of shuffles to perform.
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2))
        Pre-computed optimal delays for each pair.
    joint_distr : bool, optional
        If True, treats all ts_bunch2 as components of a single multifeature.
        Default: False.
    allow_mixed_dimensions : bool, optional
        If True, TimeSeries and MultiTimeSeries may be mixed. Default: False.
    ds : int, optional
        Downsampling factor. Default: 1.
    mask : np.ndarray, optional
        Binary mask of shape (len(ts_bunch1), len(ts_bunch2)).
        0 = skip computation, 1 = compute. Default: all ones.
    noise_const : float, optional
        Small noise added to improve numerical stability. Default: 1e-3.
    seed : int, optional
        Random seed for reproducibility. Default: None.
    enable_parallelization : bool, optional
        True -> scan_pairs_parallel, False -> scan_pairs. Default: True.
    n_jobs : int, optional
        Number of parallel jobs if parallelization enabled. Default: -1 (use all cores).

    Returns
    -------
    random_shifts : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh)
        Random shifts used for shuffling.
    me_total : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh+1)
        Metric values. [:,:,0] contains true values, [:,:,1:] contains shuffles.

    See Also
    --------
    scan_pairs : Sequential implementation
    scan_pairs_parallel : Parallel implementation
    """
    common_kwargs = dict(joint_distr=joint_distr,
                         allow_mixed_dimensions=allow_mixed_dimensions,
                         ds=ds,
                         mask=mask,
                         noise_const=noise_const,
                         seed=seed)
    positional = (ts_bunch1, ts_bunch2, metric, nsh, optimal_delays)

    if not enable_parallelization:
        shifts, me_values = scan_pairs(*positional, **common_kwargs)
    else:
        shifts, me_values = scan_pairs_parallel(*positional, n_jobs=n_jobs, **common_kwargs)

    return shifts, me_values
class IntenseResults:
    """
    Attribute container for the outputs of an INTENSE computation.

    The container starts empty; results are attached as plain instance
    attributes through ``update`` / ``update_multiple`` and can be dumped
    with ``save_to_hdf5``.

    Typical attributes
    ------------------
    info : dict
        Metadata about the computation (optimal delays, thresholds, etc.).
    intense_params : dict
        Parameters used for the INTENSE computation.
    stats : dict
        Statistical results (p-values, metric values, etc.).
    significance : dict
        Significance test results for each neuron-feature pair.

    Methods
    -------
    update(property_name, data)
        Attach (or overwrite) a single attribute.
    update_multiple(datadict)
        Attach (or overwrite) one attribute per key of a mapping.
    save_to_hdf5(fname)
        Write every attached attribute to an HDF5 file.

    Examples
    --------
    >>> results = IntenseResults()
    >>> results.update('stats', computed_stats)
    >>> results.update('info', {'optimal_delays': delays})
    >>> results.save_to_hdf5('intense_results.h5')
    """

    def __init__(self):
        """Create an empty container with no attributes attached."""

    def update(self, property_name, data):
        """Store ``data`` as the attribute named ``property_name``."""
        setattr(self, property_name, data)

    def update_multiple(self, datadict):
        """Store every value of ``datadict`` under its key as an attribute."""
        for attr_name in datadict:
            setattr(self, attr_name, datadict[attr_name])

    def save_to_hdf5(self, fname):
        """Write the instance attribute dictionary to the HDF5 file ``fname``."""
        write_dict_to_hdf5(vars(self), fname)
885def compute_me_stats(ts_bunch1,
886 ts_bunch2,
887 names1=None,
888 names2=None,
889 mode='two_stage',
890 metric='mi',
891 precomputed_mask_stage1=None,
892 precomputed_mask_stage2=None,
893 n_shuffles_stage1=100,
894 n_shuffles_stage2=10000,
895 joint_distr=False,
896 allow_mixed_dimensions=False,
897 metric_distr_type='gamma',
898 noise_ampl=1e-3,
899 ds=1,
900 topk1=1,
901 topk2=5,
902 multicomp_correction='holm',
903 pval_thr=0.01,
904 find_optimal_delays=False,
905 skip_delays=[],
906 shift_window=100,
907 verbose=True,
908 seed=None,
909 enable_parallelization=True,
910 n_jobs=-1,
911 duplicate_behavior='ignore'):
913 """
914 Calculates similarity metric statistics for TimeSeries or MultiTimeSeries pairs
916 Parameters
917 ----------
918 ts_bunch1: list of TimeSeries objects
920 ts_bunch2: list of TimeSeries objects
922 names1: list of str
923 names than will be given to time series from tsbunch1 in final results
925 names2: list of str
926 names than will be given to time series from tsbunch2 in final results
928 mode: str
929 Computation mode. 3 modes are available:
930 'stage1': perform preliminary scanning with "n_shuffles_stage1" shuffles only.
931 Rejects strictly non-significant neuron-feature pairs, does not give definite results
932 about significance of the others.
933 'stage2': skip stage 1 and perform full-scale scanning ("n_shuffles_stage2" shuffles) of all neuron-feature pairs.
934 Gives definite results, but can be very time-consuming. Also reduces statistical power
935 of multiple comparison tests, since the number of hypotheses is very high.
936 'two_stage': prune non-significant pairs during stage 1 and perform thorough testing for the rest during stage 2.
937 Recommended mode.
938 default: 'two-stage'
940 metric: similarity metric between TimeSeries
941 default: 'mi'
943 precomputed_mask_stage1: np.array of shape (len(ts_bunch1), len(ts_bunch2)) or (len(ts_bunch), 1) if joint_distr=True
944 precomputed mask for skipping some of possible pairs in stage 1.
945 0 in mask values means calculation will be skipped.
946 1 in mask values means calculation will proceed.
948 precomputed_mask_stage2: np.array of shape (len(ts_bunch1), len(ts_bunch2)) or (len(ts_bunch), 1) if joint_distr=True
949 precomputed mask for skipping some of possible pairs in stage 2.
950 0 in mask values means calculation will be skipped.
951 1 in mask values means calculation will proceed.
953 n_shuffles_stage1: int
954 number of shuffles for first stage
955 default: 100
957 n_shuffles_stage2: int
958 number of shuffles for second stage
959 default: 10000
961 joint_distr: bool
962 if joint_distr=True, ALL features in feat_bunch will be treated as components of a single multifeature
963 For example, 'x' and 'y' features will be put together into ('x','y') multifeature.
964 default: False
966 allow_mixed_dimensions: bool
967 if True, both TimeSeries and MultiTimeSeries can be provided as signals.
968 This parameter overrides "joint_distr"
969 default: False
971 metric_distr_type: str
972 Distribution type for shuffled metric distribution fit. Supported options are distributions from scipy.stats
973 Note: While 'gamma' is theoretically appropriate for MI distributions, empirical testing shows
974 that 'norm' (normal distribution) often performs better due to its conservative p-values when
975 fitting poorly to the skewed MI data. This conservatism reduces false positives.
976 default: "gamma"
978 noise_ampl: float
979 Small noise amplitude, which is added to metrics to improve numerical fit
980 default: 1e-3
982 ds: int
983 Downsampling constant. Every "ds" point will be taken from the data time series.
984 default: 1
986 topk1: int
987 true MI for stage 1 should be among topk1 MI shuffles
988 default: 1
990 topk2: int
991 true MI for stage 2 should be among topk2 MI shuffles
992 default: 5
994 multicomp_correction: str or None
995 type of multiple comparisons correction. Supported types are None (no correction),
996 "bonferroni", "holm", and "fdr_bh".
997 default: 'holm'
999 pval_thr: float
1000 pvalue threshold. if multicomp_correction=None, this is a p-value for a single pair.
1001 For FWER methods (bonferroni, holm), this is the family-wise error rate.
1002 For FDR methods (fdr_bh), this is the false discovery rate.
1004 find_optimal_delays: bool
1005 Allows slight shifting (not more than +- shift_window) of time series,
1006 selects a shift with the highest MI as default.
1007 default: True
1009 skip_delays: list
1010 List of indices from ts_bunch2 for which delays are not applied (set to 0).
1011 Has no effect if find_optimal_delays = False
1013 shift_window: int
1014 Window for optimal shift search (frames). Optimal shift will lie in the range
1015 -shift_window <= opt_shift <= shift_window
1017 verbose: bool
1018 whether to print intermediate information
1020 seed: int
1021 random seed for reproducibility
1023 duplicate_behavior: str
1024 How to handle duplicate TimeSeries in ts_bunch1 or ts_bunch2.
1025 - 'ignore': Process duplicates normally (default)
1026 - 'raise': Raise an error if duplicates are found
1027 - 'warn': Print a warning but continue processing
1029 Returns
1030 -------
1031 stats: dict of dict of dicts
1032 Outer dict keys: indices of tsbunch1 or names1, if given
1033 Inner dict keys: indices or tsbunch2 or names2, if given
1034 Last dict: dictionary of stats variables.
1035 Can be easily converted to pandas DataFrame by pd.DataFrame(stats)
1037 significance: dict of dict of dicts
1038 Outer dict keys: indices of tsbunch1 or names1, if given
1039 Inner dict keys: indices or tsbunch2 or names2, if given
1040 Last dict: dictionary of significance-related variables.
1041 Can be easily converted to pandas DataFrame by pd.DataFrame(significance)
1043 accumulated_info: dict
1044 Data collected during computation.
1045 """
1047 # TODO: add automatic min_shifts from autocorrelation time
1049 # Validate inputs
1050 validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=allow_mixed_dimensions)
1051 validate_metric(metric)
1052 validate_common_parameters(shift_window=shift_window, ds=ds, noise_const=noise_ampl)
1054 # Validate mode
1055 if mode not in ['stage1', 'stage2', 'two_stage']:
1056 raise ValueError(f"mode must be 'stage1', 'stage2', or 'two_stage', got '{mode}'")
1058 # Validate multicomp_correction
1059 if multicomp_correction not in [None, 'bonferroni', 'holm', 'fdr_bh']:
1060 raise ValueError(f"Unknown multiple comparison correction method: '{multicomp_correction}'")
1062 # Validate pval_thr
1063 if not 0 < pval_thr < 1:
1064 raise ValueError(f"pval_thr must be between 0 and 1, got {pval_thr}")
1066 # Validate stage-specific parameters
1067 validate_common_parameters(nsh=n_shuffles_stage1)
1068 validate_common_parameters(nsh=n_shuffles_stage2)
1070 accumulated_info = dict()
1072 # Check if we're comparing the same bunch with itself
1073 same_data_bunch = ts_bunch1 is ts_bunch2
1075 n1 = len(ts_bunch1)
1076 n2 = len(ts_bunch2)
1077 if not allow_mixed_dimensions:
1078 n2 = 1 if joint_distr else len(ts_bunch2)
1080 tsbunch1_is_1d = np.all([isinstance(ts, TimeSeries) for ts in ts_bunch1])
1081 tsbunch2_is_1d = np.all([isinstance(ts, TimeSeries) for ts in ts_bunch2])
1082 if not (tsbunch1_is_1d and tsbunch2_is_1d):
1083 raise ValueError('Multiple time series types found, but allow_mixed_dimensions=False.'
1084 'Consider setting it to True')
1086 if precomputed_mask_stage1 is None:
1087 precomputed_mask_stage1 = np.ones((n1, n2))
1088 if precomputed_mask_stage2 is None:
1089 precomputed_mask_stage2 = np.ones((n1, n2))
1091 # If comparing the same bunch with itself, mask out the diagonal
1092 # to avoid computing MI of a TimeSeries with itself at zero shift
1093 if same_data_bunch:
1094 np.fill_diagonal(precomputed_mask_stage1, 0)
1095 np.fill_diagonal(precomputed_mask_stage2, 0)
1097 # Handle duplicate TimeSeries based on duplicate_behavior parameter
1098 if duplicate_behavior in ['raise', 'warn']:
1099 # Check for duplicates in ts_bunch1
1100 ts1_ids = []
1101 for ts in ts_bunch1:
1102 ts_id = id(ts.data) if hasattr(ts, 'data') else id(ts)
1103 ts1_ids.append(ts_id)
1105 if len(set(ts1_ids)) < len(ts1_ids):
1106 msg = "Duplicate TimeSeries objects found in ts_bunch1"
1107 if duplicate_behavior == 'raise':
1108 raise ValueError(msg)
1109 else: # warn
1110 print(f"Warning: {msg}")
1112 # Check for duplicates in ts_bunch2 (if not joint_distr)
1113 if not joint_distr:
1114 ts2_ids = []
1115 for ts in ts_bunch2:
1116 ts_id = id(ts.data) if hasattr(ts, 'data') else id(ts)
1117 ts2_ids.append(ts_id)
1119 if len(set(ts2_ids)) < len(ts2_ids):
1120 msg = "Duplicate TimeSeries objects found in ts_bunch2"
1121 if duplicate_behavior == 'raise':
1122 raise ValueError(msg)
1123 else: # warn
1124 print(f"Warning: {msg}")
1126 optimal_delays = np.zeros((n1, n2), dtype=int)
1127 ts_with_delays = [ts for _, ts in enumerate(ts_bunch2) if _ not in skip_delays]
1128 ts_with_delays_inds = np.array([_ for _, ts in enumerate(ts_bunch2) if _ not in skip_delays])
1130 if find_optimal_delays:
1131 if enable_parallelization:
1132 optimal_delays_res = calculate_optimal_delays_parallel(ts_bunch1,
1133 ts_with_delays,
1134 metric,
1135 shift_window,
1136 ds,
1137 verbose=verbose,
1138 n_jobs=n_jobs)
1139 else:
1140 optimal_delays_res = calculate_optimal_delays(ts_bunch1,
1141 ts_with_delays,
1142 metric,
1143 shift_window,
1144 ds,
1145 verbose=verbose)
1147 optimal_delays[:, ts_with_delays_inds] = optimal_delays_res
1149 accumulated_info['optimal_delays'] = optimal_delays
1151 # Initialize masks based on mode
1152 if mode == 'stage2':
1153 # For stage2-only mode, assume all pairs pass stage 1
1154 mask_from_stage1 = np.ones((n1, n2))
1155 else:
1156 mask_from_stage1 = np.zeros((n1, n2))
1158 mask_from_stage2 = np.zeros((n1, n2))
1159 nhyp = n1*n2
1161 if mode in ['two_stage', 'stage1']:
1162 npairs_to_check1 = int(np.sum(precomputed_mask_stage1))
1163 if verbose:
1164 print(f'Starting stage 1 scanning for {npairs_to_check1}/{nhyp} possible pairs')
1166 # STAGE 1 - primary scanning
1167 random_shifts1, me_total1 = scan_pairs_router(ts_bunch1,
1168 ts_bunch2,
1169 metric,
1170 n_shuffles_stage1,
1171 optimal_delays,
1172 joint_distr=joint_distr,
1173 allow_mixed_dimensions=allow_mixed_dimensions,
1174 ds=ds,
1175 mask=precomputed_mask_stage1,
1176 noise_const=noise_ampl,
1177 seed=seed,
1178 enable_parallelization=enable_parallelization,
1179 n_jobs=n_jobs)
1181 # turn computed data tables from stage 1 and precomputed data into dict of stats dicts
1182 stage_1_stats = get_table_of_stats(me_total1,
1183 optimal_delays,
1184 metric_distr_type=metric_distr_type,
1185 nsh=n_shuffles_stage1,
1186 precomputed_mask=precomputed_mask_stage1,
1187 stage=1)
1189 stage_1_stats_per_quantity = nested_dict_to_seq_of_tables(stage_1_stats,
1190 ordered_names1=range(n1),
1191 ordered_names2=range(n2))
1192 #print(stage_1_stats_per_quantity)
1194 # select potentially significant pairs for stage 2
1195 # 0 in mask values means the pair MI is definitely insignificant, stage 2 calculation will be skipped.
1196 # 1 in mask values means the pair MI is potentially significant, stage 2 calculation will proceed.
1198 if verbose:
1199 print('Computing significance for all pairs in stage 1...')
1201 stage_1_significance = populate_nested_dict(dict(), range(n1), range(n2))
1202 for i in range(n1):
1203 for j in range(n2):
1204 pair_passes_stage1 = criterion1(stage_1_stats[i][j],
1205 n_shuffles_stage1,
1206 topk=topk1)
1207 if pair_passes_stage1:
1208 mask_from_stage1[i, j] = 1
1210 sig1 = {'stage1': pair_passes_stage1}
1211 stage_1_significance[i][j].update(sig1)
1213 stage_1_significance_per_quantity = nested_dict_to_seq_of_tables(stage_1_significance,
1214 ordered_names1=range(n1),
1215 ordered_names2=range(n2))
1217 #print(stage_1_significance_per_quantity)
1218 accumulated_info.update(
1219 {
1220 'stage_1_significance': stage_1_significance_per_quantity,
1221 'stage_1_stats': stage_1_stats_per_quantity,
1222 'random_shifts1': random_shifts1,
1223 'me_total1': me_total1
1224 }
1225 )
1227 nhyp = int(np.sum(mask_from_stage1)) # number of hypotheses for further statistical testing
1228 if verbose:
1229 print('Stage 1 results:')
1230 print(f'{nhyp/n1/n2*100:.2f}% ({nhyp}/{n1*n2}) of possible pairs identified as candidates')
1232 if mode == 'stage1' or nhyp == 0:
1233 final_stats = add_names_to_nested_dict(stage_1_stats, names1, names2)
1234 final_significance = add_names_to_nested_dict(stage_1_significance, names1, names2)
1236 return final_stats, final_significance, accumulated_info
1238 elif mode == 'stage2':
1239 # For stage2-only mode, create empty stage 1 structures
1240 stage_1_stats = populate_nested_dict(dict(), range(n1), range(n2))
1241 stage_1_significance = populate_nested_dict(dict(), range(n1), range(n2))
1242 # Set all pairs as passing stage 1 with placeholder values
1243 for i in range(n1):
1244 for j in range(n2):
1245 stage_1_stats[i][j] = {'pre_rval': None, 'pre_pval': None}
1246 stage_1_significance[i][j]['stage1'] = True
1248 # Now proceed with stage 2
1249 if mode in ['two_stage', 'stage2']:
1250 # STAGE 2 - full-scale scanning
1251 combined_mask_for_stage_2 = np.ones((n1, n2))
1252 combined_mask_for_stage_2[np.where(mask_from_stage1 == 0)] = 0 # exclude non-significant pairs from stage1
1253 combined_mask_for_stage_2[np.where(precomputed_mask_stage2 == 0)] = 0 # exclude precomputed stage 2 pairs
1255 npairs_to_check2 = int(np.sum(combined_mask_for_stage_2))
1256 if verbose:
1257 print(f'Starting stage 2 scanning for {npairs_to_check2}/{nhyp} possible pairs')
1259 random_shifts2, me_total2 = scan_pairs_router(ts_bunch1,
1260 ts_bunch2,
1261 metric,
1262 n_shuffles_stage2,
1263 optimal_delays,
1264 joint_distr=joint_distr,
1265 allow_mixed_dimensions=allow_mixed_dimensions,
1266 ds=ds,
1267 mask=combined_mask_for_stage_2,
1268 noise_const=noise_ampl,
1269 seed=seed,
1270 enable_parallelization=enable_parallelization,
1271 n_jobs=n_jobs)
1273 # turn data tables from stage 2 to array of stats dicts
1274 stage_2_stats = get_table_of_stats(me_total2,
1275 optimal_delays,
1276 metric_distr_type=metric_distr_type,
1277 nsh=n_shuffles_stage2,
1278 precomputed_mask=combined_mask_for_stage_2,
1279 stage=2)
1281 stage_2_stats_per_quantity = nested_dict_to_seq_of_tables(stage_2_stats,
1282 ordered_names1=range(n1),
1283 ordered_names2=range(n2))
1284 #print(stage_2_stats_per_quantity)
1286 # select significant pairs after stage 2
1287 if verbose:
1288 print('Computing significance for all pairs in stage 2...')
1289 all_pvals = None
1290 if multicomp_correction in ['holm', 'fdr_bh']: # these procedures require all p-values
1291 all_pvals = get_all_nonempty_pvals(stage_2_stats, range(n1), range(n2))
1293 multicorr_thr = get_multicomp_correction_thr(pval_thr,
1294 mode=multicomp_correction,
1295 all_pvals=all_pvals,
1296 nhyp=nhyp)
1298 stage_2_significance = populate_nested_dict(dict(), range(n1), range(n2))
1299 for i in range(n1):
1300 for j in range(n2):
1301 pair_passes_stage2 = criterion2(stage_2_stats[i][j],
1302 n_shuffles_stage2,
1303 multicorr_thr,
1304 topk=topk2)
1305 if pair_passes_stage2:
1306 mask_from_stage2[i,j] = 1
1308 sig2 = {'stage2': pair_passes_stage2}
1309 stage_2_significance[i][j] = sig2
1311 stage_2_significance_per_quantity = nested_dict_to_seq_of_tables(stage_2_significance,
1312 ordered_names1=range(n1),
1313 ordered_names2=range(n2))
1315 #print(stage_2_significance_per_quantity)
1316 accumulated_info.update(
1317 {
1318 'stage_2_significance': stage_2_significance_per_quantity,
1319 'stage_2_stats': stage_2_stats_per_quantity,
1320 'random_shifts2': random_shifts2,
1321 'me_total2': me_total2,
1322 'corrected_pval_thr': multicorr_thr,
1323 'group_pval_thr': pval_thr,
1324 }
1325 )
1327 num2 = int(np.sum(mask_from_stage2))
1328 if verbose:
1329 print('Stage 2 results:')
1330 print(f'{num2/n1/n2*100:.2f}% ({num2}/{n1*n2}) of possible pairs identified as significant')
1332 # Always merge stats for consistency
1333 merged_stats = merge_stage_stats(stage_1_stats, stage_2_stats)
1334 merged_significance = merge_stage_significance(stage_1_significance, stage_2_significance)
1335 final_stats = add_names_to_nested_dict(merged_stats, names1, names2)
1336 final_significance = add_names_to_nested_dict(merged_significance, names1, names2)
1337 return final_stats, final_significance, accumulated_info
def get_multicomp_correction_thr(fwer, mode='holm', **multicomp_kwargs):
    """
    Calculate p-value threshold for multiple hypothesis correction.

    Parameters
    ----------
    fwer : float
        Family-wise error rate (for 'bonferroni' / 'holm') or false discovery
        rate (for 'fdr_bh'), e.g. 0.05.
    mode : str or None, optional
        Multiple comparison correction method. Default: 'holm'.
        - None: No correction, threshold = fwer
        - 'bonferroni': Bonferroni correction (FWER control)
        - 'holm': Holm-Bonferroni correction (FWER control, more powerful)
        - 'fdr_bh': Benjamini-Hochberg FDR correction
    **multicomp_kwargs : dict
        Additional arguments for correction method:
        - For 'bonferroni': nhyp (int) - number of hypotheses, must be >= 1
        - For 'holm': all_pvals (sequence of float) - all p-values to be tested
        - For 'fdr_bh': all_pvals (sequence of float) - all p-values to be tested
        A value of None is treated exactly like a missing argument, so callers
        may forward e.g. ``all_pvals=None`` unconditionally. Arguments that the
        selected method does not use are ignored.

    Returns
    -------
    threshold : float
        Adjusted p-value threshold for individual hypothesis testing.
        For 'holm' and 'fdr_bh', hypotheses whose p-value is <= threshold are
        the ones rejected by the procedure; the threshold is 0 when no
        hypothesis can be rejected (including when ``all_pvals`` is empty).

    Raises
    ------
    ValueError
        If a required argument is missing (or None), if ``nhyp`` < 1 for
        Bonferroni, or if an unknown method is specified.

    Notes
    -----
    - FWER methods (bonferroni, holm) control probability of ANY false positive
    - FDR methods control expected proportion of false positives among rejections
    - Holm is uniformly more powerful than Bonferroni
    - FDR typically allows more discoveries but with controlled false positive rate

    Examples
    --------
    >>> pvals = [0.001, 0.01, 0.02, 0.03, 0.04]
    >>> # Holm correction (default)
    >>> round(get_multicomp_correction_thr(0.05, mode='holm', all_pvals=pvals), 4)
    0.0125
    >>> # FDR correction
    >>> get_multicomp_correction_thr(0.05, mode='fdr_bh', all_pvals=pvals)
    0.04
    """
    if mode is None:
        return fwer

    if mode == 'bonferroni':
        # .get() so that an explicitly forwarded nhyp=None counts as missing
        nhyp = multicomp_kwargs.get('nhyp')
        if nhyp is None:
            raise ValueError('Number of hypotheses for Bonferroni correction not provided')
        if nhyp < 1:
            # Guards against ZeroDivisionError / negative thresholds
            raise ValueError(
                f'Number of hypotheses for Bonferroni correction must be >= 1, got {nhyp}'
            )
        return fwer / nhyp

    if mode == 'holm':
        all_pvals = multicomp_kwargs.get('all_pvals')
        if all_pvals is None:
            raise ValueError('List of p-values for Holm correction not provided')
        sorted_pvals = sorted(all_pvals)
        nhyp = len(sorted_pvals)
        threshold = 0  # Default if no discoveries
        # Step-down procedure: the i-th smallest p-value (0-based) is compared
        # with fwer / (m - i); stop at the first one that fails. The cut-off
        # of the last passing p-value becomes the threshold, and it is strictly
        # below the first failing p-value, so "p <= threshold" reproduces the
        # Holm rejection set.
        for i, pval in enumerate(sorted_pvals):
            cthr = fwer / (nhyp - i)
            if pval > cthr:
                break
            threshold = cthr
        return threshold

    if mode == 'fdr_bh':
        all_pvals = multicomp_kwargs.get('all_pvals')
        if all_pvals is None:
            raise ValueError('List of p-values for FDR correction not provided')
        sorted_pvals = sorted(all_pvals)
        nhyp = len(sorted_pvals)
        # Benjamini-Hochberg step-up procedure: find the largest rank k
        # (1-based) with p_(k) <= fwer * k / m; that p-value is the threshold.
        for i in range(nhyp - 1, -1, -1):
            if sorted_pvals[i] <= fwer * (i + 1) / nhyp:
                return sorted_pvals[i]
        return 0.0

    raise ValueError('Unknown multiple comparisons correction method')