Coverage for src/driada/intense/intense_base.py: 87.19%

406 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import numpy as np 

2import tqdm 

3from joblib import Parallel, delayed 

4import multiprocessing 

5import scipy.stats 

6 

7from .stats import * 

8from ..information.info_base import TimeSeries, MultiTimeSeries, get_1d_mi, get_multi_mi, get_mi, get_sim 

9from ..utils.data import write_dict_to_hdf5, nested_dict_to_seq_of_tables 

10 

11 

def validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=False):
    """
    Validate time series bunches for INTENSE computations.

    Checks that both bunches are non-empty, that (unless mixed dimensions are
    allowed) they contain only TimeSeries objects, and that every time series
    in both bunches has the same number of time points.

    Parameters
    ----------
    ts_bunch1 : list
        First set of time series.
    ts_bunch2 : list
        Second set of time series.
    allow_mixed_dimensions : bool, optional
        Whether to allow MultiTimeSeries alongside TimeSeries. Default: False.

    Raises
    ------
    ValueError
        If a bunch is empty, contains a disallowed type (when
        allow_mixed_dimensions=False), or the time series lengths differ.
    """
    if len(ts_bunch1) == 0:
        raise ValueError("ts_bunch1 cannot be empty")
    if len(ts_bunch2) == 0:
        raise ValueError("ts_bunch2 cannot be empty")

    # Check time series types
    if not allow_mixed_dimensions:
        for bunch_name, bunch in (("ts_bunch1", ts_bunch1), ("ts_bunch2", ts_bunch2)):
            # MultiTimeSeries is checked first so it gets its dedicated message;
            # isinstance() (unlike type() ==) also accepts TimeSeries subclasses.
            if any(isinstance(ts, MultiTimeSeries) for ts in bunch):
                raise ValueError(f"MultiTimeSeries found in {bunch_name} but allow_mixed_dimensions=False")
            if not all(isinstance(ts, TimeSeries) for ts in bunch):
                raise ValueError(f"{bunch_name} must contain TimeSeries objects")

    def n_timepoints(ts):
        # 1D TimeSeries: length of .data; anything else (MultiTimeSeries) keeps
        # time along axis 1 of .data.
        if isinstance(ts, TimeSeries) and not isinstance(ts, MultiTimeSeries):
            return len(ts.data)
        return ts.data.shape[1]

    # Check lengths match
    lengths1 = [n_timepoints(ts) for ts in ts_bunch1]
    lengths2 = [n_timepoints(ts) for ts in ts_bunch2]

    if len(set(lengths1)) > 1:
        raise ValueError(f"All time series in ts_bunch1 must have same length, got {set(lengths1)}")
    if len(set(lengths2)) > 1:
        raise ValueError(f"All time series in ts_bunch2 must have same length, got {set(lengths2)}")
    if lengths1[0] != lengths2[0]:
        raise ValueError(f"Time series lengths don't match: {lengths1[0]} vs {lengths2[0]}")

62 

63 

def validate_metric(metric, allow_scipy=True):
    """
    Validate metric name and check if it's supported.

    Parameters
    ----------
    metric : str
        Metric name to validate.
    allow_scipy : bool, optional
        Whether to accept arbitrary callables from scipy.stats (looked up by
        name). The explicitly listed scipy names 'spearmanr', 'pearsonr' and
        'kendalltau' are accepted regardless. Default: True.

    Returns
    -------
    metric_type : str
        Type of metric: 'mi', 'correlation', 'special', or 'scipy'.

    Raises
    ------
    ValueError
        If metric is not a string or is not supported.
    """
    # Only strings can name a metric; without this guard a non-string value
    # would make getattr/hasattr raise TypeError instead of ValueError.
    if isinstance(metric, str):
        # Built-in metrics
        if metric == 'mi':
            return 'mi'

        # Special metrics
        if metric in ('av', 'fast_pearsonr'):
            return 'special'

        # Common correlation metrics (shorthand names)
        if metric in ('spearman', 'pearson', 'kendall'):
            return 'correlation'

        # Full scipy names
        if metric in ('spearmanr', 'pearsonr', 'kendalltau'):
            return 'scipy'

        # Any other public callable in scipy.stats (submodules and private
        # attributes cannot be used as a similarity function).
        if (allow_scipy and not metric.startswith('_')
                and callable(getattr(scipy.stats, metric, None))):
            return 'scipy'

    # If we get here, metric is not supported
    raise ValueError(f"Unsupported metric: {metric}. Supported metrics include: "
                     f"'mi', 'av', 'fast_pearsonr', 'spearman', 'pearson', 'kendall', "
                     f"'spearmanr', 'pearsonr', 'kendalltau', and other scipy.stats functions.")

116 

117 

def validate_common_parameters(shift_window=None, ds=None, nsh=None, noise_const=None):
    """
    Check the numeric parameters shared by the INTENSE routines.

    Any parameter left as None is not checked. Parameters are checked in the
    order shift_window, ds, nsh, noise_const; the first failure is reported.

    Parameters
    ----------
    shift_window : int, optional
        Maximum shift window in frames; must be >= 0.
    ds : int, optional
        Downsampling factor; must be > 0.
    nsh : int, optional
        Number of shuffles; must be > 0.
    noise_const : float, optional
        Noise amplitude; must be >= 0.

    Raises
    ------
    ValueError
        If a provided parameter is outside its allowed range.
    """
    # (value, predicate that flags an invalid value, error message template)
    rules = (
        (shift_window, lambda v: v < 0, "shift_window must be non-negative, got {}"),
        (ds, lambda v: v <= 0, "ds must be positive, got {}"),
        (nsh, lambda v: v <= 0, "nsh must be positive, got {}"),
        (noise_const, lambda v: v < 0, "noise_const must be non-negative, got {}"),
    )
    for value, is_invalid, template in rules:
        if value is not None and is_invalid(value):
            raise ValueError(template.format(value))

149 

150 

def calculate_optimal_delays(ts_bunch1, ts_bunch2, metric,
                             shift_window, ds, verbose=True, enable_progressbar=True):
    """
    Calculate optimal temporal delays between pairs of time series.

    For each pair (ts_bunch1[i], ts_bunch2[j]) the similarity metric is
    evaluated at a range of shifts, and the shift giving the largest value is
    kept.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series (typically neural signals).
    ts_bunch2 : list of TimeSeries
        Second set of time series (typically behavioral variables).
    metric : str
        Similarity metric to maximize (see validate_metric), e.g. 'mi' or
        'spearman'; evaluated through get_sim.
    shift_window : int
        Maximum shift to test (frames). Candidate shifts are
        np.arange(-shift_window, shift_window, ds), so +shift_window itself
        is excluded. If shift_window is 0, no search is done and all delays
        are 0.
    ds : int
        Downsampling factor. Every ds-th point is used from the time series.
    verbose : bool, optional
        Whether to print progress information. Default: True.
    enable_progressbar : bool, optional
        Whether to show progress bar. Default: True.

    Returns
    -------
    optimal_delays : np.ndarray of int, shape (len(ts_bunch1), len(ts_bunch2))
        Optimal delay (in frames) for each pair. The sign follows the `shift`
        argument of get_sim. NOTE(review): earlier docs stated positive values
        mean ts2 leads ts1; confirm against get_sim.

    Notes
    -----
    - get_sim is called len(ts_bunch1) * len(ts_bunch2) * n_shifts times,
      with n_shifts = len(np.arange(-shift_window, shift_window, ds)).
    - Only the final optimal delays are stored, not all tested values.

    Examples
    --------
    >>> neurons = [neuron1.ca, neuron2.ca]  # calcium signals
    >>> behaviors = [speed_ts, direction_ts]  # behavioral variables
    >>> delays = calculate_optimal_delays(neurons, behaviors, 'mi',
    ...                                   shift_window=100, ds=1)
    >>> print(f"Neuron 1 optimal delay with speed: {delays[0, 0]} frames")
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=False)
    validate_metric(metric)
    validate_common_parameters(shift_window=shift_window, ds=ds)

    if verbose:
        print('Calculating optimal delays:')

    optimal_delays = np.zeros((len(ts_bunch1), len(ts_bunch2)), dtype=int)
    # Candidate shifts in downsampled samples; converted back to frames (* ds) below.
    shifts = np.arange(-shift_window, shift_window, ds) // ds
    if len(shifts) == 0:
        # shift_window == 0: zero is the only admissible delay, and np.argmax
        # would raise on an empty sequence.
        return optimal_delays

    for i, ts1 in tqdm.tqdm(enumerate(ts_bunch1), total=len(ts_bunch1), disable=not enable_progressbar):
        for j, ts2 in enumerate(ts_bunch2):
            shifted_me = [get_sim(ts1, ts2, metric, ds=ds, shift=int(shift)) for shift in shifts]
            best_shift = shifts[np.argmax(shifted_me)]
            optimal_delays[i, j] = int(best_shift * ds)

    return optimal_delays

225 

226 

def calculate_optimal_delays_parallel(ts_bunch1, ts_bunch2, metric,
                                      shift_window, ds, verbose=True, n_jobs=-1):
    """
    Calculate optimal temporal delays between pairs of time series using parallel processing.

    Parallel version of calculate_optimal_delays: ts_bunch1 is split into
    contiguous chunks, and each chunk is processed against all of ts_bunch2 by
    a joblib worker.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series (typically neural signals).
    ts_bunch2 : list of TimeSeries
        Second set of time series (typically behavioral variables).
    metric : str
        Similarity metric to maximize (see validate_metric), e.g. 'mi' or 'spearman'.
    shift_window : int
        Maximum shift to test (frames); see calculate_optimal_delays.
    ds : int
        Downsampling factor. Every ds-th point is used from the time series.
    verbose : bool, optional
        Whether to print progress information. Default: True.
    n_jobs : int, optional
        Number of parallel jobs. Negative values follow the joblib convention
        (-1 = all cores, -2 = all but one, ...). The effective number is capped
        at len(ts_bunch1) so that no worker receives an empty chunk.
        Default: -1.

    Returns
    -------
    optimal_delays : np.ndarray of int, shape (len(ts_bunch1), len(ts_bunch2))
        Optimal delay (in frames) for each pair, as computed by
        calculate_optimal_delays.

    Raises
    ------
    ValueError
        On invalid inputs or n_jobs == 0.

    See Also
    --------
    calculate_optimal_delays : Sequential version of this function

    Examples
    --------
    >>> neurons = [neuron.ca for neuron in exp.neurons[:100]]
    >>> behaviors = [exp.speed, exp.direction]
    >>> # Use 8 cores for faster computation
    >>> delays = calculate_optimal_delays_parallel(neurons, behaviors, 'mi',
    ...                                            shift_window=100, ds=1, n_jobs=8)
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=False)
    validate_metric(metric)
    validate_common_parameters(shift_window=shift_window, ds=ds)

    if verbose:
        print('Calculating optimal delays in parallel mode:')

    n_series = len(ts_bunch1)
    optimal_delays = np.zeros((n_series, len(ts_bunch2)), dtype=int)

    # Resolve the number of chunks. More chunks than series would produce empty
    # chunks, which calculate_optimal_delays rejects during validation.
    if n_jobs == 0:
        raise ValueError("n_jobs must be non-zero")
    if n_jobs < 0:
        n_jobs = max(1, multiprocessing.cpu_count() + 1 + n_jobs)
    n_jobs = min(n_jobs, n_series)

    split_inds = np.array_split(np.arange(n_series), n_jobs)
    # Index the original sequence directly: np.array() on a list of arbitrary
    # objects may try to unpack them if they look like sequences.
    split_ts_bunch1 = [[ts_bunch1[k] for k in idxs] for idxs in split_inds]

    parallel_delays = Parallel(n_jobs=n_jobs, verbose=True)(
        delayed(calculate_optimal_delays)(small_ts_bunch,
                                          ts_bunch2,
                                          metric,
                                          shift_window,
                                          ds,
                                          verbose=False,
                                          enable_progressbar=False)
        for small_ts_bunch in split_ts_bunch1)

    for idxs, chunk_delays in zip(split_inds, parallel_delays):
        optimal_delays[idxs, :] = chunk_delays

    return optimal_delays

312 

313 

def get_calcium_feature_me_profile(exp, cell_id=None, feat_id=None, cbunch=None, fbunch=None,
                                   window=1000, ds=1, metric='mi', data_type='calcium'):
    """
    Compute metric profile between neurons and behavioral features across time shifts.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neurons and behavioral features.
    cell_id : int, optional
        Index of a single neuron in exp.neurons. Deprecated - use cbunch instead.
    feat_id : str or tuple of str, optional
        Single feature name(s) to analyze. Deprecated - use fbunch instead.
    cbunch : int, iterable or None, optional
        Neuron indices, resolved by exp._process_cbunch. Takes precedence over
        cell_id if both provided.
    fbunch : str, iterable or None, optional
        Feature names, resolved by exp._process_fbunch (multifeatures allowed).
        Takes precedence over feat_id if both provided.
    window : int, optional
        Maximum shift to test (frames); must be positive. Default: 1000.
    ds : int, optional
        Downsampling factor. Default: 1 (no downsampling).
    metric : str, optional
        Similarity metric to compute (see validate_metric). Default: 'mi'.
    data_type : str, optional
        Neural signal to use: 'calcium' uses cell.ca; any other value uses
        cell.spikes. Also passed to exp._process_fbunch as `mode`.
        Default: 'calcium'.

    Returns
    -------
    tuple or dict
        If cell_id and feat_id are given and cbunch/fbunch are not (backward
        compatibility): a tuple ``(me0, shifted_me)``, where me0 is the
        unshifted metric value and shifted_me is the list of shifted values.
        Otherwise: a nested dictionary
        ``{cell_id: {feat_id: {'me0': float, 'shifted_me': list}}}``.
        shifted_me has one value per shift in
        ``np.arange(-window, window, ds) // ds`` (from -window up to, but
        excluding, +window).

    Raises
    ------
    ValueError
        On invalid ds/metric/window, out-of-range cell indices, or a
        multi-feature (tuple) feature with metric other than 'mi'.

    Notes
    -----
    - Number of shifts tested: len(np.arange(-window, window, ds)).
    - Multi-feature analysis (tuple feat_id) only supported for metric='mi'.
    - A tqdm progress bar shows computation progress.

    Examples
    --------
    >>> # Backward compatibility - single cell and feature
    >>> mi_zero, mi_profile = get_calcium_feature_me_profile(exp, 0, 'speed')
    >>>
    >>> # New usage - analyze multiple cells and features
    >>> results = get_calcium_feature_me_profile(exp, cbunch=[0, 1, 2], fbunch=['speed', 'head_direction'])
    >>> # Access specific result: results[cell_id][feat_id]['me0'] and ['shifted_me']
    >>>
    >>> # Analyze all cells with all features
    >>> results = get_calcium_feature_me_profile(exp, cbunch=None, fbunch=None)
    >>>
    >>> # Multi-feature joint mutual information
    >>> results = get_calcium_feature_me_profile(exp, cbunch=[0], fbunch=[('x', 'y')])
    """
    # Validate inputs
    validate_common_parameters(ds=ds)
    validate_metric(metric)

    if window <= 0:
        raise ValueError(f"window must be positive, got {window}")

    # Check if single cell/feature mode (backward compatibility)
    single_mode = (cell_id is not None and feat_id is not None and
                   cbunch is None and fbunch is None)

    # Handle backward compatibility - if old-style single cell_id/feat_id provided
    if cbunch is None and cell_id is not None:
        cbunch = cell_id
    if fbunch is None and feat_id is not None:
        fbunch = feat_id

    # Process cbunch and fbunch using experiment's methods
    cell_ids = exp._process_cbunch(cbunch)
    feat_ids = exp._process_fbunch(fbunch, allow_multifeatures=True, mode=data_type)

    # Validate cell indices
    for cid in cell_ids:
        if not (0 <= cid < len(exp.neurons)):
            raise ValueError(f"cell_id {cid} out of range [0, {len(exp.neurons)-1}]")

    # The shift grid (in downsampled samples) is the same for every pair.
    shifts = np.arange(-window, window, ds) // ds

    results = {}
    total_combinations = len(cell_ids) * len(feat_ids)
    # Context manager guarantees the bar is closed even if a computation raises.
    with tqdm.tqdm(total=total_combinations, desc="Computing ME profiles") as pbar:
        for cid in cell_ids:
            cell = exp.neurons[cid]
            ts1 = cell.ca if data_type == 'calcium' else cell.spikes
            results[cid] = {}

            for fid in feat_ids:
                if isinstance(fid, str):
                    # Single feature
                    ts2 = exp.dynamic_features[fid]
                    me0 = get_sim(ts1, ts2, metric, ds=ds)
                    shifted_me = [get_sim(ts1, ts2, metric, ds=ds, shift=shift) for shift in shifts]
                else:
                    # Multi-feature (tuple)
                    if metric != 'mi':
                        raise ValueError(f"Multi-feature analysis only supported for metric='mi', got '{metric}'")
                    feats = [exp.dynamic_features[f] for f in fid]
                    me0 = get_multi_mi(feats, ts1, ds=ds)
                    # NOTE(review): get_multi_mi takes (features, neuron), the reverse
                    # of get_sim's order; scan_pairs negates the shift for this call
                    # ("minus due to different order") but this profile does not.
                    # Confirm the intended sign convention.
                    shifted_me = [get_multi_mi(feats, ts1, ds=ds, shift=shift) for shift in shifts]

                results[cid][fid] = {'me0': me0, 'shifted_me': shifted_me}
                pbar.update(1)

    # Return format based on usage mode
    if single_mode:
        # Backward compatibility - return (me0, shifted_me) tuple
        first = results[cell_ids[0]][feat_ids[0]]
        return first['me0'], first['shifted_me']
    # New format - return full results dictionary
    return results

451 

452 

def scan_pairs(ts_bunch1,
               ts_bunch2,
               metric,
               nsh,
               optimal_delays,
               joint_distr=False,
               ds=1,
               mask=None,
               noise_const=1e-3,
               seed=None,
               allow_mixed_dimensions=False,
               enable_progressbar=True):
    """
    Calculate true and shuffled similarity values for two sets of TimeSeries.

    Intended mostly for internal use, but can be called directly to inspect
    the intermediate results of the high-level INTENSE routines.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series.
    ts_bunch2 : list of TimeSeries
        Second set of time series.
    metric : str
        Similarity metric between TimeSeries (see validate_metric).
    nsh : int
        Number of shuffles.
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2)), or (len(ts_bunch1), 1) if joint_distr=True
        Best shifts (in frames) from the original alignment. They roll the
        shuffle masks and, divided by ``ds``, are the shift used for the true
        (unshuffled) value.
    joint_distr : bool, optional
        If True, ALL (sic!) TimeSeries in ts_bunch2 are treated as components
        of a single multifeature; only metric='mi' is supported. Default: False.
    ds : int, optional
        Downsampling constant: every ``ds``-th point is used. Default: 1.
    mask : array-like, optional
        Same shape as optimal_delays. Pairs whose mask value equals 1 are
        computed; any other value skips the pair and fills its results with
        NaN. Default: all ones.
    noise_const : float, optional
        Amplitude of uniform noise added to the true and shuffled values to
        improve the subsequent distribution fit. Default: 1e-3.
    seed : int, optional
        Seed for the global NumPy RNG; None is treated as 0. The RNG is
        reseeded for every pair, so a pair's shifts and noise do not depend on
        which other pairs are processed.
    allow_mixed_dimensions : bool, optional
        Passed to validate_time_series_bunches. Default: False.
    enable_progressbar : bool, optional
        Show a tqdm progress bar over ts_bunch1. Default: True.

    Returns
    -------
    random_shifts : np.ndarray of int, shape (len(ts_bunch1), n2, nsh)
        Shuffle shifts (in downsampled samples) used for the null distribution,
        where n2 = 1 if joint_distr else len(ts_bunch2).
    me_total : np.ndarray of shape (len(ts_bunch1), n2, nsh+1)
        True values in me_total[:, :, 0] and shuffled values in
        me_total[:, :, 1:]. Skipped (masked) pairs are NaN.

    Raises
    ------
    ValueError
        On invalid inputs, mismatched optimal_delays shape, joint_distr with a
        metric other than 'mi', or an empty combined shuffle mask.
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=allow_mixed_dimensions)
    validate_metric(metric)
    validate_common_parameters(ds=ds, nsh=nsh, noise_const=noise_const)

    # TODO: deprecate the joint_distr branch, it is unnecessary with MultiTimeSeries
    if joint_distr and metric != 'mi':
        raise ValueError("joint_distr mode works with metric = 'mi' only")

    # Validate optimal_delays shape
    n1 = len(ts_bunch1)
    n2 = 1 if joint_distr else len(ts_bunch2)

    if optimal_delays.shape != (n1, n2):
        raise ValueError(f"optimal_delays shape {optimal_delays.shape} doesn't match expected ({n1}, {n2})")

    if seed is None:
        seed = 0

    np.random.seed(seed)

    lengths1 = [len(ts.data) if isinstance(ts, TimeSeries) else ts.data.shape[1] for ts in ts_bunch1]
    lengths2 = [len(ts.data) if isinstance(ts, TimeSeries) else ts.data.shape[1] for ts in ts_bunch2]
    if len(set(lengths1)) == 1 and len(set(lengths2)) == 1 and set(lengths1) == set(lengths2):
        t = lengths1[0]  # full length is the same for all time series
    else:
        raise ValueError('Lengths of TimeSeries do not match!')

    mask = np.ones((n1, n2)) if mask is None else np.asarray(mask)

    me_table = np.zeros((n1, n2))
    me_table_shuffles = np.zeros((n1, n2, nsh))
    random_shifts = np.zeros((n1, n2, nsh), dtype=int)

    # Fill random shifts according to the allowed shuffle masks of both time series.
    for i, ts1 in enumerate(ts_bunch1):
        if joint_distr:
            np.random.seed(seed)
            # Combine shuffle masks from ts1 and all ts in ts_bunch2
            combined_shuffle_mask = ts1.shuffle_mask.copy()
            for ts2 in ts_bunch2:
                combined_shuffle_mask = combined_shuffle_mask & ts2.shuffle_mask
            # Move shuffle mask according to optimal shift
            combined_shuffle_mask = np.roll(combined_shuffle_mask, int(optimal_delays[i, 0]))
            indices_to_select = np.arange(t)[combined_shuffle_mask]
            if len(indices_to_select) == 0:
                raise ValueError(f"Combined shuffle mask is empty for ts_bunch1[{i}] and the joint "
                                 f"ts_bunch2: no admissible shuffle positions")
            random_shifts[i, 0, :] = np.random.choice(indices_to_select, size=nsh) // ds

        else:
            for j, ts2 in enumerate(ts_bunch2):
                np.random.seed(seed)
                combined_shuffle_mask = ts1.shuffle_mask & ts2.shuffle_mask
                # Move shuffle mask according to optimal shift
                combined_shuffle_mask = np.roll(combined_shuffle_mask, int(optimal_delays[i, j]))
                indices_to_select = np.arange(t)[combined_shuffle_mask]
                if len(indices_to_select) == 0:
                    raise ValueError(f"Combined shuffle mask is empty for pair ({i}, {j}): "
                                     f"no admissible shuffle positions")
                random_shifts[i, j, :] = np.random.choice(indices_to_select, size=nsh) // ds

    # Calculate similarity metric arrays
    for i, ts1 in tqdm.tqdm(enumerate(ts_bunch1),
                            total=len(ts_bunch1),
                            position=0,
                            leave=True,
                            disable=not enable_progressbar):

        np.random.seed(seed)

        if joint_distr:
            if mask[i, 0] == 1:
                # True value without shuffling; shift is negated because
                # get_multi_mi takes its arguments in the reverse order.
                me0 = get_multi_mi(ts_bunch2, ts1, ds=ds, shift=-optimal_delays[i, 0]//ds)
                me_table[i, 0] = me0 + np.random.random()*noise_const  # small noise for better fitting

                np.random.seed(seed)
                random_noise = np.random.random(size=nsh) * noise_const  # small noise for better fitting
                for k, shift in enumerate(random_shifts[i, 0, :]):
                    mi = get_multi_mi(ts_bunch2, ts1, ds=ds, shift=shift)
                    me_table_shuffles[i, 0, k] = mi + random_noise[k]

            else:
                me_table[i, 0] = np.nan
                me_table_shuffles[i, 0, :] = np.nan

        else:
            for j, ts2 in enumerate(ts_bunch2):
                if mask[i, j] == 1:
                    # True value without shuffling
                    me0 = get_sim(ts1,
                                  ts2,
                                  metric,
                                  ds=ds,
                                  shift=optimal_delays[i, j]//ds,
                                  check_for_coincidence=True)

                    np.random.seed(seed)
                    me_table[i, j] = me0 + np.random.random()*noise_const  # small noise for better fitting

                    np.random.seed(seed)
                    random_noise = np.random.random(size=nsh) * noise_const  # small noise for better fitting

                    for k, shift in enumerate(random_shifts[i, j, :]):
                        # Reseed so any randomness inside get_sim is identical per shuffle.
                        np.random.seed(seed)
                        me = get_sim(ts1,
                                     ts2,
                                     metric,
                                     ds=ds,
                                     shift=shift)

                        me_table_shuffles[i, j, k] = me + random_noise[k]

                else:
                    me_table[i, j] = np.nan
                    me_table_shuffles[i, j, :] = np.nan

    me_total = np.dstack((me_table, me_table_shuffles))

    return random_shifts, me_total

633 

634 

def scan_pairs_parallel(ts_bunch1,
                        ts_bunch2,
                        metric,
                        nsh,
                        optimal_delays,
                        joint_distr=False,
                        allow_mixed_dimensions=False,
                        ds=1,
                        mask=None,
                        noise_const=1e-3,
                        seed=None,
                        n_jobs=-1):
    """
    Calculate metric values and shuffles for time series pairs using parallel processing.

    ts_bunch1 (together with the matching rows of optimal_delays and mask) is
    split into contiguous chunks, and each chunk is passed to scan_pairs in a
    joblib worker. All chunks receive the same seed. Since scan_pairs reseeds
    per pair, the results match the sequential call.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series.
    ts_bunch2 : list of TimeSeries
        Second set of time series.
    metric : str
        Similarity metric to compute (see validate_metric).
    nsh : int
        Number of shuffles to perform.
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2)), or (len(ts_bunch1), 1) if joint_distr=True
        Pre-computed optimal delays for each pair.
    joint_distr : bool, optional
        If True, treats all ts_bunch2 as components of a single multifeature.
        Default: False.
    allow_mixed_dimensions : bool, optional
        Passed to validate_time_series_bunches / scan_pairs. Default: False.
    ds : int, optional
        Downsampling factor. Default: 1.
    mask : array-like, optional
        Same shape as optimal_delays. 1 = compute, anything else = skip.
        Default: all ones.
    noise_const : float, optional
        Small noise added to improve numerical fitting. Default: 1e-3.
    seed : int, optional
        Random seed for reproducibility. Default: None (treated as 0 by scan_pairs).
    n_jobs : int, optional
        Number of parallel jobs. Negative values follow the joblib convention
        (-1 = all cores, -2 = all but one, ...). Capped at len(ts_bunch1) so
        that no worker receives an empty chunk. Default: -1.

    Returns
    -------
    random_shifts : np.ndarray of int, shape (len(ts_bunch1), n2, nsh)
        Random shifts used for shuffling (n2 = 1 if joint_distr else len(ts_bunch2)).
    me_total : np.ndarray of shape (len(ts_bunch1), n2, nsh+1)
        Metric values. [:,:,0] contains true values, [:,:,1:] contains shuffles.

    Raises
    ------
    ValueError
        On invalid inputs, mismatched optimal_delays shape, or n_jobs == 0.

    See Also
    --------
    scan_pairs : Sequential version of this function
    scan_pairs_router : Wrapper that chooses between parallel and sequential
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=allow_mixed_dimensions)
    validate_metric(metric)
    validate_common_parameters(ds=ds, nsh=nsh, noise_const=noise_const)

    n1 = len(ts_bunch1)
    n2 = 1 if joint_distr else len(ts_bunch2)

    # Validate optimal_delays shape
    if optimal_delays.shape != (n1, n2):
        raise ValueError(f"optimal_delays shape {optimal_delays.shape} doesn't match expected ({n1}, {n2})")

    me_total = np.zeros((n1, n2, nsh+1))
    random_shifts = np.zeros((n1, n2, nsh), dtype=int)

    # Resolve the number of chunks; more chunks than series would create empty
    # chunks, which scan_pairs rejects during validation.
    if n_jobs == 0:
        raise ValueError("n_jobs must be non-zero")
    if n_jobs < 0:
        n_jobs = max(1, multiprocessing.cpu_count() + 1 + n_jobs)
    n_jobs = min(n_jobs, n1)

    mask = np.ones((n1, n2)) if mask is None else np.asarray(mask)

    split_inds = np.array_split(np.arange(n1), n_jobs)
    # Index the original sequence directly: np.array() on a list of arbitrary
    # objects may try to unpack them if they look like sequences.
    split_ts_bunch1 = [[ts_bunch1[k] for k in idxs] for idxs in split_inds]

    parallel_result = Parallel(n_jobs=n_jobs, verbose=True)(
        delayed(scan_pairs)(small_ts_bunch,
                            ts_bunch2,
                            metric,
                            nsh,
                            optimal_delays[idxs],
                            joint_distr=joint_distr,
                            allow_mixed_dimensions=allow_mixed_dimensions,
                            ds=ds,
                            mask=mask[idxs],
                            noise_const=noise_const,
                            seed=seed,
                            enable_progressbar=False)
        for idxs, small_ts_bunch in zip(split_inds, split_ts_bunch1))

    for idxs, (chunk_shifts, chunk_me) in zip(split_inds, parallel_result):
        random_shifts[idxs, :, :] = chunk_shifts
        me_total[idxs, :, :] = chunk_me

    return random_shifts, me_total

743 

744 

def scan_pairs_router(ts_bunch1,
                      ts_bunch2,
                      metric,
                      nsh,
                      optimal_delays,
                      joint_distr=False,
                      allow_mixed_dimensions=False,
                      ds=1,
                      mask=None,
                      noise_const=1e-3,
                      seed=None,
                      enable_parallelization=True,
                      n_jobs=-1):
    """
    Dispatch a pair scan to scan_pairs_parallel or scan_pairs.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series.
    ts_bunch2 : list of TimeSeries
        Second set of time series.
    metric : str
        Similarity metric to compute (see validate_metric), e.g. 'mi' or 'spearman'.
    nsh : int
        Number of shuffles to perform.
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2))
        Pre-computed optimal delays for each pair ((len(ts_bunch1), 1) if joint_distr).
    joint_distr : bool, optional
        If True, treats all ts_bunch2 as components of a single multifeature.
        Default: False.
    allow_mixed_dimensions : bool, optional
        Whether TimeSeries and MultiTimeSeries may be mixed. Default: False.
    ds : int, optional
        Downsampling factor. Default: 1.
    mask : np.ndarray, optional
        Binary mask with the same shape as optimal_delays.
        0 = skip computation, 1 = compute. Default: all ones.
    noise_const : float, optional
        Small noise added to improve numerical stability. Default: 1e-3.
    seed : int, optional
        Random seed for reproducibility. Default: None.
    enable_parallelization : bool, optional
        If True, use scan_pairs_parallel; otherwise scan_pairs. Default: True.
    n_jobs : int, optional
        Number of parallel jobs; only used when parallelization is enabled.
        Default: -1 (use all cores).

    Returns
    -------
    random_shifts : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh)
        Random shifts used for shuffling.
    me_total : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh+1)
        Metric values. [:,:,0] contains true values, [:,:,1:] contains shuffles.

    See Also
    --------
    scan_pairs : Sequential implementation
    scan_pairs_parallel : Parallel implementation
    """
    # Keyword arguments understood identically by both implementations.
    shared_kwargs = dict(joint_distr=joint_distr,
                         allow_mixed_dimensions=allow_mixed_dimensions,
                         ds=ds,
                         mask=mask,
                         noise_const=noise_const,
                         seed=seed)

    if not enable_parallelization:
        random_shifts, me_total = scan_pairs(ts_bunch1, ts_bunch2, metric, nsh, optimal_delays,
                                             **shared_kwargs)
    else:
        random_shifts, me_total = scan_pairs_parallel(ts_bunch1, ts_bunch2, metric, nsh, optimal_delays,
                                                      n_jobs=n_jobs, **shared_kwargs)

    return random_shifts, me_total

834 

835 

class IntenseResults:
    """
    Attribute container for INTENSE computation results.

    The object starts empty; results are attached as attributes through
    update() / update_multiple(). Callers commonly store attributes such as:

    info : dict
        Metadata about the computation (optimal delays, thresholds, etc.).
    intense_params : dict
        Parameters used for the INTENSE computation.
    stats : dict
        Statistical results (p-values, metric values, etc.).
    significance : dict
        Significance test results for each neuron-feature pair.

    Methods
    -------
    update(property_name, data)
        Set a single attribute.
    update_multiple(datadict)
        Set one attribute per dictionary entry.
    save_to_hdf5(fname)
        Hand all attributes to write_dict_to_hdf5.

    Examples
    --------
    >>> results = IntenseResults()
    >>> results.update('stats', computed_stats)
    >>> results.update('info', {'optimal_delays': delays})
    >>> results.save_to_hdf5('intense_results.h5')
    """

    def __init__(self):
        # Nothing to initialise: attributes are attached dynamically.
        pass

    def update(self, property_name, data):
        """Store ``data`` as the attribute named ``property_name``."""
        setattr(self, property_name, data)

    def update_multiple(self, datadict):
        """Store each ``name: value`` entry of ``datadict`` as an attribute."""
        for attr_name, attr_value in datadict.items():
            setattr(self, attr_name, attr_value)

    def save_to_hdf5(self, fname):
        """Pass all instance attributes (as a dict) to write_dict_to_hdf5 for ``fname``."""
        write_dict_to_hdf5(vars(self), fname)

883 

884 

def _check_duplicate_timeseries(ts_bunch, bunch_name, duplicate_behavior):
    """
    Raise or warn when several entries of ``ts_bunch`` share one data object.

    Entries are compared by ``id(ts.data)`` (``id(ts)`` when an entry has no
    ``data`` attribute), so only entries wrapping the very same object are
    reported, not entries whose values merely happen to be equal.

    Parameters
    ----------
    ts_bunch : list
        Time series to inspect.
    bunch_name : str
        Name of the bunch used in the message, e.g. 'ts_bunch1'.
    duplicate_behavior : str
        'raise' raises ValueError; any other value prints a warning.

    Raises
    ------
    ValueError
        If duplicates are found and ``duplicate_behavior == 'raise'``.
    """
    data_ids = [id(ts.data) if hasattr(ts, 'data') else id(ts) for ts in ts_bunch]
    if len(set(data_ids)) == len(data_ids):
        return
    msg = f"Duplicate TimeSeries objects found in {bunch_name}"
    if duplicate_behavior == 'raise':
        raise ValueError(msg)
    print(f"Warning: {msg}")


def compute_me_stats(ts_bunch1,
                     ts_bunch2,
                     names1=None,
                     names2=None,
                     mode='two_stage',
                     metric='mi',
                     precomputed_mask_stage1=None,
                     precomputed_mask_stage2=None,
                     n_shuffles_stage1=100,
                     n_shuffles_stage2=10000,
                     joint_distr=False,
                     allow_mixed_dimensions=False,
                     metric_distr_type='gamma',
                     noise_ampl=1e-3,
                     ds=1,
                     topk1=1,
                     topk2=5,
                     multicomp_correction='holm',
                     pval_thr=0.01,
                     find_optimal_delays=False,
                     skip_delays=None,
                     shift_window=100,
                     verbose=True,
                     seed=None,
                     enable_parallelization=True,
                     n_jobs=-1,
                     duplicate_behavior='ignore'):
    """
    Calculates similarity metric statistics for TimeSeries or MultiTimeSeries pairs

    Parameters
    ----------
    ts_bunch1: list of TimeSeries objects

    ts_bunch2: list of TimeSeries objects

    names1: list of str
        names that will be given to time series from ts_bunch1 in final results

    names2: list of str
        names that will be given to time series from ts_bunch2 in final results

    mode: str
        Computation mode. 3 modes are available:
        'stage1': perform preliminary scanning with "n_shuffles_stage1" shuffles only.
                  Rejects strictly non-significant neuron-feature pairs, does not give definite results
                  about significance of the others.
        'stage2': skip stage 1 and perform full-scale scanning ("n_shuffles_stage2" shuffles) of all neuron-feature pairs.
                  Gives definite results, but can be very time-consuming. Also reduces statistical power
                  of multiple comparison tests, since the number of hypotheses is very high.
        'two_stage': prune non-significant pairs during stage 1 and perform thorough testing for the rest during stage 2.
                     Recommended mode.
        default: 'two_stage'

    metric: similarity metric between TimeSeries
        default: 'mi'

    precomputed_mask_stage1: np.array of shape (len(ts_bunch1), len(ts_bunch2)) or (len(ts_bunch1), 1) if joint_distr=True
        precomputed mask for skipping some of possible pairs in stage 1.
        0 in mask values means calculation will be skipped.
        1 in mask values means calculation will proceed.
        The array is copied; the caller's array is never modified.

    precomputed_mask_stage2: np.array of shape (len(ts_bunch1), len(ts_bunch2)) or (len(ts_bunch1), 1) if joint_distr=True
        precomputed mask for skipping some of possible pairs in stage 2.
        0 in mask values means calculation will be skipped.
        1 in mask values means calculation will proceed.
        The array is copied; the caller's array is never modified.

    n_shuffles_stage1: int
        number of shuffles for first stage
        default: 100

    n_shuffles_stage2: int
        number of shuffles for second stage
        default: 10000

    joint_distr: bool
        if joint_distr=True, ALL features in feat_bunch will be treated as components of a single multifeature
        For example, 'x' and 'y' features will be put together into ('x','y') multifeature.
        default: False

    allow_mixed_dimensions: bool
        if True, both TimeSeries and MultiTimeSeries can be provided as signals.
        This parameter overrides "joint_distr"
        default: False

    metric_distr_type: str
        Distribution type for shuffled metric distribution fit. Supported options are distributions from scipy.stats
        Note: While 'gamma' is theoretically appropriate for MI distributions, empirical testing shows
        that 'norm' (normal distribution) often performs better due to its conservative p-values when
        fitting poorly to the skewed MI data. This conservatism reduces false positives.
        default: "gamma"

    noise_ampl: float
        Small noise amplitude, which is added to metrics to improve numerical fit
        default: 1e-3

    ds: int
        Downsampling constant. Every "ds" point will be taken from the data time series.
        default: 1

    topk1: int
        true MI for stage 1 should be among topk1 MI shuffles
        default: 1

    topk2: int
        true MI for stage 2 should be among topk2 MI shuffles
        default: 5

    multicomp_correction: str or None
        type of multiple comparisons correction. Supported types are None (no correction),
        "bonferroni", "holm", and "fdr_bh".
        default: 'holm'

    pval_thr: float
        pvalue threshold. if multicomp_correction=None, this is a p-value for a single pair.
        For FWER methods (bonferroni, holm), this is the family-wise error rate.
        For FDR methods (fdr_bh), this is the false discovery rate.

    find_optimal_delays: bool
        Allows slight shifting (not more than +- shift_window) of time series,
        selects a shift with the highest MI as default.
        default: False

    skip_delays: list or None
        List of indices from ts_bunch2 for which delays are not applied (set to 0).
        Has no effect if find_optimal_delays = False
        default: None (no index is skipped)

    shift_window: int
        Window for optimal shift search (frames). Optimal shift will lie in the range
        -shift_window <= opt_shift <= shift_window

    verbose: bool
        whether to print intermediate information

    seed: int
        random seed for reproducibility

    enable_parallelization: bool
        if True, the parallel variants of the delay search and pair scanning are used.
        default: True

    n_jobs: int
        number of jobs forwarded to the parallel routines
        (only used when enable_parallelization=True).
        default: -1

    duplicate_behavior: str
        How to handle duplicate TimeSeries in ts_bunch1 or ts_bunch2.
        - 'ignore': Process duplicates normally (default)
        - 'raise': Raise an error if duplicates are found
        - 'warn': Print a warning but continue processing

    Returns
    -------
    stats: dict of dict of dicts
        Outer dict keys: indices of ts_bunch1 or names1, if given
        Inner dict keys: indices of ts_bunch2 or names2, if given
        Last dict: dictionary of stats variables.
        Can be easily converted to pandas DataFrame by pd.DataFrame(stats)

    significance: dict of dict of dicts
        Outer dict keys: indices of ts_bunch1 or names1, if given
        Inner dict keys: indices of ts_bunch2 or names2, if given
        Last dict: dictionary of significance-related variables.
        Can be easily converted to pandas DataFrame by pd.DataFrame(significance)

    accumulated_info: dict
        Data collected during computation: 'optimal_delays' always; stage 1 entries
        ('stage_1_significance', 'stage_1_stats', 'random_shifts1', 'me_total1') when
        stage 1 runs; stage 2 entries ('stage_2_significance', 'stage_2_stats',
        'random_shifts2', 'me_total2', 'corrected_pval_thr', 'group_pval_thr') when
        stage 2 runs.

    Raises
    ------
    ValueError
        If mode, multicomp_correction, pval_thr or duplicate_behavior are invalid,
        if time series types are mixed while allow_mixed_dimensions=False, or if
        duplicates are found with duplicate_behavior='raise'.
    """

    # TODO: add automatic min_shifts from autocorrelation time

    # Validate arguments that do not depend on the data before touching it
    if duplicate_behavior not in ('ignore', 'raise', 'warn'):
        raise ValueError(f"duplicate_behavior must be 'ignore', 'raise', or 'warn', "
                         f"got '{duplicate_behavior}'")

    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2, allow_mixed_dimensions=allow_mixed_dimensions)
    validate_metric(metric)
    validate_common_parameters(shift_window=shift_window, ds=ds, noise_const=noise_ampl)

    # Validate mode
    if mode not in ['stage1', 'stage2', 'two_stage']:
        raise ValueError(f"mode must be 'stage1', 'stage2', or 'two_stage', got '{mode}'")

    # Validate multicomp_correction
    if multicomp_correction not in [None, 'bonferroni', 'holm', 'fdr_bh']:
        raise ValueError(f"Unknown multiple comparison correction method: '{multicomp_correction}'")

    # Validate pval_thr
    if not 0 < pval_thr < 1:
        raise ValueError(f"pval_thr must be between 0 and 1, got {pval_thr}")

    # Validate stage-specific parameters
    validate_common_parameters(nsh=n_shuffles_stage1)
    validate_common_parameters(nsh=n_shuffles_stage2)

    accumulated_info = dict()

    # Check if we're comparing the same bunch with itself
    same_data_bunch = ts_bunch1 is ts_bunch2

    n1 = len(ts_bunch1)
    n2 = len(ts_bunch2)
    if not allow_mixed_dimensions:
        # In joint mode all of ts_bunch2 forms a single multifeature column
        n2 = 1 if joint_distr else len(ts_bunch2)

        tsbunch1_is_1d = np.all([isinstance(ts, TimeSeries) for ts in ts_bunch1])
        tsbunch2_is_1d = np.all([isinstance(ts, TimeSeries) for ts in ts_bunch2])
        if not (tsbunch1_is_1d and tsbunch2_is_1d):
            raise ValueError('Multiple time series types found, but allow_mixed_dimensions=False. '
                             'Consider setting it to True')

    # User-supplied masks are copied: the diagonal masking below works in
    # place and must not leak into the caller's arrays.
    if precomputed_mask_stage1 is None:
        precomputed_mask_stage1 = np.ones((n1, n2))
    else:
        precomputed_mask_stage1 = np.array(precomputed_mask_stage1, copy=True)
    if precomputed_mask_stage2 is None:
        precomputed_mask_stage2 = np.ones((n1, n2))
    else:
        precomputed_mask_stage2 = np.array(precomputed_mask_stage2, copy=True)

    # If comparing the same bunch with itself, mask out the diagonal
    # to avoid computing MI of a TimeSeries with itself at zero shift.
    # Only meaningful for square masks: in joint mode (n2 == 1, n1 > 1) there is
    # no diagonal, and filling it would mask just the first row.
    if same_data_bunch and n1 == n2:
        np.fill_diagonal(precomputed_mask_stage1, 0)
        np.fill_diagonal(precomputed_mask_stage2, 0)

    # Handle duplicate TimeSeries based on duplicate_behavior parameter
    if duplicate_behavior in ('raise', 'warn'):
        _check_duplicate_timeseries(ts_bunch1, 'ts_bunch1', duplicate_behavior)
        # ts_bunch2 entries are only checked individually when not joined
        if not joint_distr:
            _check_duplicate_timeseries(ts_bunch2, 'ts_bunch2', duplicate_behavior)

    optimal_delays = np.zeros((n1, n2), dtype=int)
    skip_delays = set() if skip_delays is None else set(skip_delays)
    delayed_entries = [(j, ts) for j, ts in enumerate(ts_bunch2) if j not in skip_delays]
    ts_with_delays = [ts for _, ts in delayed_entries]
    # dtype=int keeps this a valid fancy index even when it is empty
    ts_with_delays_inds = np.array([j for j, _ in delayed_entries], dtype=int)

    # Nothing to search when every feature is in skip_delays: delays stay 0
    if find_optimal_delays and ts_with_delays:
        if enable_parallelization:
            optimal_delays_res = calculate_optimal_delays_parallel(ts_bunch1,
                                                                   ts_with_delays,
                                                                   metric,
                                                                   shift_window,
                                                                   ds,
                                                                   verbose=verbose,
                                                                   n_jobs=n_jobs)
        else:
            optimal_delays_res = calculate_optimal_delays(ts_bunch1,
                                                          ts_with_delays,
                                                          metric,
                                                          shift_window,
                                                          ds,
                                                          verbose=verbose)

        optimal_delays[:, ts_with_delays_inds] = optimal_delays_res

    accumulated_info['optimal_delays'] = optimal_delays

    # Initialize masks based on mode
    if mode == 'stage2':
        # For stage2-only mode, assume all pairs pass stage 1
        mask_from_stage1 = np.ones((n1, n2))
    else:
        mask_from_stage1 = np.zeros((n1, n2))

    mask_from_stage2 = np.zeros((n1, n2))
    nhyp = n1*n2

    if mode in ['two_stage', 'stage1']:
        npairs_to_check1 = int(np.sum(precomputed_mask_stage1))
        if verbose:
            print(f'Starting stage 1 scanning for {npairs_to_check1}/{nhyp} possible pairs')

        # STAGE 1 - primary scanning
        random_shifts1, me_total1 = scan_pairs_router(ts_bunch1,
                                                      ts_bunch2,
                                                      metric,
                                                      n_shuffles_stage1,
                                                      optimal_delays,
                                                      joint_distr=joint_distr,
                                                      allow_mixed_dimensions=allow_mixed_dimensions,
                                                      ds=ds,
                                                      mask=precomputed_mask_stage1,
                                                      noise_const=noise_ampl,
                                                      seed=seed,
                                                      enable_parallelization=enable_parallelization,
                                                      n_jobs=n_jobs)

        # turn computed data tables from stage 1 and precomputed data into dict of stats dicts
        stage_1_stats = get_table_of_stats(me_total1,
                                           optimal_delays,
                                           metric_distr_type=metric_distr_type,
                                           nsh=n_shuffles_stage1,
                                           precomputed_mask=precomputed_mask_stage1,
                                           stage=1)

        stage_1_stats_per_quantity = nested_dict_to_seq_of_tables(stage_1_stats,
                                                                  ordered_names1=range(n1),
                                                                  ordered_names2=range(n2))

        # select potentially significant pairs for stage 2
        # 0 in mask values means the pair MI is definitely insignificant, stage 2 calculation will be skipped.
        # 1 in mask values means the pair MI is potentially significant, stage 2 calculation will proceed.

        if verbose:
            print('Computing significance for all pairs in stage 1...')

        stage_1_significance = populate_nested_dict(dict(), range(n1), range(n2))
        for i in range(n1):
            for j in range(n2):
                pair_passes_stage1 = criterion1(stage_1_stats[i][j],
                                                n_shuffles_stage1,
                                                topk=topk1)
                if pair_passes_stage1:
                    mask_from_stage1[i, j] = 1

                sig1 = {'stage1': pair_passes_stage1}
                stage_1_significance[i][j].update(sig1)

        stage_1_significance_per_quantity = nested_dict_to_seq_of_tables(stage_1_significance,
                                                                         ordered_names1=range(n1),
                                                                         ordered_names2=range(n2))

        accumulated_info.update(
            {
                'stage_1_significance': stage_1_significance_per_quantity,
                'stage_1_stats': stage_1_stats_per_quantity,
                'random_shifts1': random_shifts1,
                'me_total1': me_total1
            }
        )

        nhyp = int(np.sum(mask_from_stage1))  # number of hypotheses for further statistical testing
        if verbose:
            print('Stage 1 results:')
            print(f'{nhyp/n1/n2*100:.2f}% ({nhyp}/{n1*n2}) of possible pairs identified as candidates')

        if mode == 'stage1' or nhyp == 0:
            final_stats = add_names_to_nested_dict(stage_1_stats, names1, names2)
            final_significance = add_names_to_nested_dict(stage_1_significance, names1, names2)

            return final_stats, final_significance, accumulated_info

    elif mode == 'stage2':
        # For stage2-only mode, create empty stage 1 structures
        stage_1_stats = populate_nested_dict(dict(), range(n1), range(n2))
        stage_1_significance = populate_nested_dict(dict(), range(n1), range(n2))
        # Set all pairs as passing stage 1 with placeholder values
        for i in range(n1):
            for j in range(n2):
                stage_1_stats[i][j] = {'pre_rval': None, 'pre_pval': None}
                stage_1_significance[i][j]['stage1'] = True

    # Now proceed with stage 2
    if mode in ['two_stage', 'stage2']:
        # STAGE 2 - full-scale scanning
        combined_mask_for_stage_2 = np.ones((n1, n2))
        combined_mask_for_stage_2[np.where(mask_from_stage1 == 0)] = 0  # exclude non-significant pairs from stage1
        combined_mask_for_stage_2[np.where(precomputed_mask_stage2 == 0)] = 0  # exclude precomputed stage 2 pairs

        npairs_to_check2 = int(np.sum(combined_mask_for_stage_2))
        if verbose:
            print(f'Starting stage 2 scanning for {npairs_to_check2}/{nhyp} possible pairs')

        random_shifts2, me_total2 = scan_pairs_router(ts_bunch1,
                                                      ts_bunch2,
                                                      metric,
                                                      n_shuffles_stage2,
                                                      optimal_delays,
                                                      joint_distr=joint_distr,
                                                      allow_mixed_dimensions=allow_mixed_dimensions,
                                                      ds=ds,
                                                      mask=combined_mask_for_stage_2,
                                                      noise_const=noise_ampl,
                                                      seed=seed,
                                                      enable_parallelization=enable_parallelization,
                                                      n_jobs=n_jobs)

        # turn data tables from stage 2 to array of stats dicts
        stage_2_stats = get_table_of_stats(me_total2,
                                           optimal_delays,
                                           metric_distr_type=metric_distr_type,
                                           nsh=n_shuffles_stage2,
                                           precomputed_mask=combined_mask_for_stage_2,
                                           stage=2)

        stage_2_stats_per_quantity = nested_dict_to_seq_of_tables(stage_2_stats,
                                                                  ordered_names1=range(n1),
                                                                  ordered_names2=range(n2))

        # select significant pairs after stage 2
        if verbose:
            print('Computing significance for all pairs in stage 2...')
        all_pvals = None
        if multicomp_correction in ['holm', 'fdr_bh']:  # these procedures require all p-values
            all_pvals = get_all_nonempty_pvals(stage_2_stats, range(n1), range(n2))

        multicorr_thr = get_multicomp_correction_thr(pval_thr,
                                                     mode=multicomp_correction,
                                                     all_pvals=all_pvals,
                                                     nhyp=nhyp)

        stage_2_significance = populate_nested_dict(dict(), range(n1), range(n2))
        for i in range(n1):
            for j in range(n2):
                pair_passes_stage2 = criterion2(stage_2_stats[i][j],
                                                n_shuffles_stage2,
                                                multicorr_thr,
                                                topk=topk2)
                if pair_passes_stage2:
                    mask_from_stage2[i, j] = 1

                sig2 = {'stage2': pair_passes_stage2}
                stage_2_significance[i][j] = sig2

        stage_2_significance_per_quantity = nested_dict_to_seq_of_tables(stage_2_significance,
                                                                         ordered_names1=range(n1),
                                                                         ordered_names2=range(n2))

        accumulated_info.update(
            {
                'stage_2_significance': stage_2_significance_per_quantity,
                'stage_2_stats': stage_2_stats_per_quantity,
                'random_shifts2': random_shifts2,
                'me_total2': me_total2,
                'corrected_pval_thr': multicorr_thr,
                'group_pval_thr': pval_thr,
            }
        )

        num2 = int(np.sum(mask_from_stage2))
        if verbose:
            print('Stage 2 results:')
            print(f'{num2/n1/n2*100:.2f}% ({num2}/{n1*n2}) of possible pairs identified as significant')

    # Always merge stats for consistency
    merged_stats = merge_stage_stats(stage_1_stats, stage_2_stats)
    merged_significance = merge_stage_significance(stage_1_significance, stage_2_significance)
    final_stats = add_names_to_nested_dict(merged_stats, names1, names2)
    final_significance = add_names_to_nested_dict(merged_significance, names1, names2)
    return final_stats, final_significance, accumulated_info

1338 

1339 

def get_multicomp_correction_thr(fwer, mode='holm', **multicomp_kwargs):
    """
    Calculate p-value threshold for multiple hypothesis correction.

    Parameters
    ----------
    fwer : float
        Family-wise error rate or false discovery rate (e.g., 0.05).
    mode : str or None, optional
        Multiple comparison correction method. Default: 'holm'.
        - None: No correction, threshold = fwer
        - 'bonferroni': Bonferroni correction (FWER control)
        - 'holm': Holm-Bonferroni correction (FWER control, more powerful)
        - 'fdr_bh': Benjamini-Hochberg FDR correction
    **multicomp_kwargs : dict
        Additional arguments for correction method:
        - For 'bonferroni': nhyp (int, > 0) - number of hypotheses
        - For 'holm': all_pvals (list) - all p-values to be tested
        - For 'fdr_bh': all_pvals (list) - all p-values to be tested
        Arguments passed explicitly as None are treated as missing.

    Returns
    -------
    threshold : float
        Adjusted p-value threshold for individual hypothesis testing.
        - 'bonferroni': fwer / nhyp.
        - 'holm': the last step-down level fwer / (m - i) that the sorted
          p-values passed before the first failure; 0 if the smallest p-value
          already fails (or no p-values are given).
        - 'fdr_bh': the largest sorted p-value p_(i) with
          p_(i) <= fwer * i / m; 0 if none qualifies.

    Raises
    ------
    ValueError
        If required arguments are missing, nhyp is not positive, or an
        unknown method is specified.

    Notes
    -----
    - FWER methods (bonferroni, holm) control probability of ANY false positive
    - FDR methods control expected proportion of false positives among rejections
    - Holm is uniformly more powerful than Bonferroni
    - FDR typically allows more discoveries but with controlled false positive rate

    Examples
    --------
    >>> # Holm correction (default)
    >>> pvals = [0.001, 0.01, 0.02, 0.03, 0.04]
    >>> thr = get_multicomp_correction_thr(0.05, mode='holm', all_pvals=pvals)
    >>>
    >>> # FDR correction
    >>> thr = get_multicomp_correction_thr(0.05, mode='fdr_bh', all_pvals=pvals)
    """
    if mode is None:
        threshold = fwer

    elif mode == 'bonferroni':
        nhyp = multicomp_kwargs.get('nhyp')
        if nhyp is None:
            raise ValueError('Number of hypotheses for Bonferroni correction not provided')
        # Guard against a ZeroDivisionError / meaningless negative threshold
        if nhyp <= 0:
            raise ValueError(f'Number of hypotheses must be positive, got {nhyp}')
        threshold = fwer / nhyp

    elif mode == 'holm':
        all_pvals = multicomp_kwargs.get('all_pvals')
        if all_pvals is None:
            raise ValueError('List of p-values for Holm correction not provided')
        sorted_pvals = sorted(all_pvals)
        nhyp = len(sorted_pvals)
        threshold = 0.0  # Default if no discoveries
        # Step-down: walk p-values upward, stop at the first one above its level
        for i, pval in enumerate(sorted_pvals):
            cthr = fwer / (nhyp - i)
            if pval > cthr:
                break
            threshold = cthr

    elif mode == 'fdr_bh':
        all_pvals = multicomp_kwargs.get('all_pvals')
        if all_pvals is None:
            raise ValueError('List of p-values for FDR correction not provided')
        sorted_pvals = sorted(all_pvals)
        nhyp = len(sorted_pvals)
        threshold = 0.0

        # Benjamini-Hochberg step-up: the largest rank whose p-value meets its level
        for i in range(nhyp - 1, -1, -1):
            if sorted_pvals[i] <= fwer * (i + 1) / nhyp:
                threshold = sorted_pvals[i]
                break

    else:
        raise ValueError('Unknown multiple comparisons correction method')

    return threshold