Coverage for src/driada/intense/pipelines.py: 74.91%

283 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1from .stats import * 

2from .intense_base import compute_me_stats, IntenseResults 

3from ..information.info_base import TimeSeries, MultiTimeSeries 

4from .disentanglement import disentangle_all_selectivities, DEFAULT_MULTIFEATURE_MAP 

5 

6 

def compute_cell_feat_significance(exp,
                                   cell_bunch=None,
                                   feat_bunch=None,
                                   data_type='calcium',
                                   metric='mi',
                                   mode='two_stage',
                                   n_shuffles_stage1=100,
                                   n_shuffles_stage2=10000,
                                   joint_distr=False,
                                   allow_mixed_dimensions=False,
                                   metric_distr_type='norm',
                                   noise_ampl=1e-3,
                                   ds=1,
                                   use_precomputed_stats=True,
                                   save_computed_stats=True,
                                   force_update=False,
                                   topk1=1,
                                   topk2=5,
                                   multicomp_correction='holm',
                                   pval_thr=0.01,
                                   find_optimal_delays=True,
                                   skip_delays=None,
                                   shift_window=5,
                                   verbose=True,
                                   enable_parallelization=True,
                                   n_jobs=-1,
                                   seed=42,
                                   with_disentanglement=False,
                                   multifeature_map=None,
                                   duplicate_behavior='ignore'):
    """
    Calculate significant neuron-feature pairs.

    Parameters
    ----------
    exp : Experiment
        Experiment object to read data from and write results to.

    cell_bunch : int, iterable or None
        Neuron indices. By default (cell_bunch=None), all neurons are taken.

    feat_bunch : str, iterable or None
        Feature names. By default (feat_bunch=None), all single features are taken.

    data_type : str
        Data type used for INTENSE computations: 'calcium' or 'spikes'.

    metric : str
        Similarity metric between TimeSeries. Default: 'mi'.

    mode : str
        Computation mode, one of:
        'stage1': preliminary scan with "n_shuffles_stage1" shuffles only. Rejects
            clearly non-significant neuron-feature pairs, but gives no definite
            result about the significance of the others.
        'stage2': skip stage 1 and run the full-scale scan ("n_shuffles_stage2"
            shuffles) on all neuron-feature pairs. Gives definite results, but can
            be very time-consuming; the large number of hypotheses also reduces
            the statistical power of the multiple comparison correction.
        'two_stage': prune non-significant pairs during stage 1 and test the rest
            thoroughly during stage 2. Recommended.
        Default: 'two_stage'.

    n_shuffles_stage1 : int
        Number of shuffles for stage 1. Default: 100.

    n_shuffles_stage2 : int
        Number of shuffles for stage 2. Default: 10000.

    joint_distr : bool
        If True, ALL features in feat_bunch are treated as components of a single
        multifeature; e.g. 'x' and 'y' are put together into the ('x', 'y')
        multifeature. Default: False.

    allow_mixed_dimensions : bool
        If True, both TimeSeries and MultiTimeSeries can be used as features
        (tuple feature ids are assembled into a MultiTimeSeries). In this case the
        feature ids are not regrouped by "joint_distr" (the flag is still
        forwarded to compute_me_stats). Default: False.

    metric_distr_type : str
        Distribution used to fit the shuffled metric distribution; supported
        options are distributions from scipy.stats. While 'gamma' is
        theoretically appropriate for MI distributions, empirical testing shows
        that 'norm' often performs better because its poorer fit to skewed MI
        data yields conservative p-values, which reduces false positives.
        Default: 'norm'.

    noise_ampl : float
        Small noise amplitude added to MI and shuffled MI to improve the
        numerical fit. Default: 1e-3.

    ds : int
        Downsampling constant: every "ds"-th point of the time series is used.
        Reduces the computational load, but large values may lose important
        information; the value is validated by exp.check_ds. Default: 1.

    use_precomputed_stats : bool
        Whether to reuse stats saved in the Experiment instance. Stats are stored
        separately for stage 1 and stage 2. Notes on overwriting saved stats
        (when save_computed_stats=True):
        - to recalculate stage 1 results only, use use_precomputed_stats=False
          and mode='stage1'; stage 2 stats are erased since they become
          irrelevant;
        - to recalculate stage 2 results only, use use_precomputed_stats=True and
          mode='stage2' or mode='two_stage';
        - to recalculate everything, use use_precomputed_stats=False and
          mode='two_stage'.
        Default: True.

    save_computed_stats : bool
        Whether to save computed stats and significance to the Experiment
        instance. Default: True.

    force_update : bool
        Whether to force an update of saved statistics when the actual data
        hashes differ from the hashes stored with them (for example, if neuronal
        or behavioral data were changed externally). Default: False.

    topk1 : int
        True MI for stage 1 should be among the topk1 MI shuffles. Default: 1.

    topk2 : int
        True MI for stage 2 should be among the topk2 MI shuffles. Default: 5.

    multicomp_correction : str or None
        Type of multiple comparison correction: None (no correction),
        'bonferroni' or 'holm'. Default: 'holm'.

    pval_thr : float
        P-value threshold. If multicomp_correction=None, this is the p-value for
        a single pair; otherwise it is the FWER significance level.
        Default: 0.01.

    find_optimal_delays : bool
        Allows slight shifting (not more than +-shift_window) of the time series
        and selects the shift with the highest MI. Default: True.

    skip_delays : list or None
        Feature ids for which delays are not applied (set to 0). Every entry must
        be one of the analyzed feature ids. Has no effect if
        find_optimal_delays=False. Default: None (no features skipped).

    shift_window : int
        Window for the optimal shift search, in seconds. It is converted to
        frames as shift_window * exp.fps, so the optimal shift lies in the range
        -shift_window*fps <= opt_shift <= shift_window*fps.
        Has no effect if find_optimal_delays=False. Default: 5.

    verbose : bool
        Whether to print progress information. Default: True.

    enable_parallelization : bool
        Whether to use parallel processing (forwarded to compute_me_stats).
        Default: True.

    n_jobs : int
        Number of parallel jobs, forwarded to compute_me_stats. Default: -1.

    seed : int
        Random seed, forwarded to compute_me_stats. Default: 42.

    with_disentanglement : bool
        If True, performs a full INTENSE pipeline with mixed selectivity analysis:
        1. computes behavioral feature-feature significance;
        2. computes neuron-feature significance;
        3. disentangles mixed selectivities using behavioral correlations.
        Default: False.

    multifeature_map : dict or None
        Mapping from multifeature tuples to aggregated names for disentanglement.
        If None, DEFAULT_MULTIFEATURE_MAP from the disentanglement module is
        used. Only used when with_disentanglement=True. Default: None.

    duplicate_behavior : str
        How to handle duplicate TimeSeries in neuron or feature bunches:
        - 'ignore': process duplicates normally (default)
        - 'raise': raise an error if duplicates are found
        - 'warn': print a warning but continue processing

    Returns
    -------
    stats : dict of dict of dicts
        Outer dict: cells, inner dict: features, last dict: stats.
        Can be converted to a pandas DataFrame with pd.DataFrame(stats).
    significance : dict of dict of bools
        Significance results for each neuron-feature pair, indexed the same way.
    info : dict
        Additional information from compute_me_stats.
    intense_res : IntenseResults
        Results object holding 'info', 'intense_params' and, when requested,
        'disentanglement'.
    disentanglement_results : dict (only if with_disentanglement=True)
        Contains:
        - 'feat_feat_significance': feature-feature significance matrix
        - 'disent_matrix': disentanglement results matrix
        - 'count_matrix': count matrix from disentanglement
        - 'feature_names': feature names indexing the matrices above
        - 'summary': summary statistics from disentanglement

    Raises
    ------
    ValueError
        If mode or data_type is invalid, if a requested feature is unknown, if
        skip_delays contains an id that is not analyzed, or if data hashes are
        missing for a feature when use_precomputed_stats=False.
    """
    # Validate arguments before touching `exp`: the precomputed-stats lookup below
    # can register multifeatures in the experiment, so a bad call must fail first.
    if mode not in ('stage1', 'stage2', 'two_stage'):
        raise ValueError(f"Wrong mode! Expected 'stage1', 'stage2' or 'two_stage', got {mode!r}")
    if data_type not in ('calcium', 'spikes'):
        raise ValueError('"data_type" can be either "calcium" or "spikes"')
    if skip_delays is None:
        skip_delays = []

    exp.check_ds(ds)

    cell_ids = exp._process_cbunch(cell_bunch)
    feat_ids = exp._process_fbunch(feat_bunch, allow_multifeatures=True, mode=data_type)
    cells = [exp.neurons[cell_id] for cell_id in cell_ids]

    if data_type == 'calcium':
        signals = [cell.ca for cell in cells]
    else:
        signals = [cell.sp for cell in cells]

    if not allow_mixed_dimensions:
        # NOTE(review): ids that are not keys of exp.dynamic_features are skipped
        # here, which would misalign `feats` with `feat_ids`; confirm that
        # _process_fbunch never yields such ids in this mode.
        feats = [exp.dynamic_features[feat_id] for feat_id in feat_ids
                 if feat_id in exp.dynamic_features]
        if joint_distr:
            # All requested features become components of a single multifeature.
            feat_ids = [tuple(sorted(feat_ids))]
    else:
        feats = []
        for feat_id in feat_ids:
            if isinstance(feat_id, str):
                if feat_id not in exp.dynamic_features:
                    raise ValueError(f"Feature '{feat_id}' not found in experiment. Available features: {list(exp.dynamic_features.keys())}")
                feats.append(exp.dynamic_features[feat_id])
            elif isinstance(feat_id, tuple):
                for part in feat_id:
                    if part not in exp.dynamic_features:
                        raise ValueError(f"Feature '{part}' not found in experiment. Available features: {list(exp.dynamic_features.keys())}")
                parts = [exp.dynamic_features[part] for part in feat_id]
                feats.append(MultiTimeSeries(parts))
            else:
                raise ValueError('Unknown feature id type')

    n, f = len(cells), len(feats)

    # Mask convention: 1 = compute this pair, 0 = reuse precomputed results.
    precomputed_mask_stage1 = np.ones((n, f))
    precomputed_mask_stage2 = np.ones((n, f))

    if not exp.selectivity_tables_initialized:
        exp._set_selectivity_tables(data_type, cbunch=cell_ids, fbunch=feat_ids)

    if use_precomputed_stats:
        if verbose:
            print('Retrieving saved stats data...')

        for i, cell_id in enumerate(cell_ids):
            for j, feat_id in enumerate(feat_ids):
                try:
                    pair_stats = exp.get_neuron_feature_pair_stats(cell_id, feat_id, mode=data_type)
                except (ValueError, KeyError):
                    if isinstance(feat_id, str):
                        raise ValueError(f'Unknown single feature in feat_bunch: {feat_id}. Check initial data')
                    # A multifeature seen for the first time: register it so that
                    # hashes and stats can be stored for it.
                    exp._add_multifeature_to_data_hashes(feat_id, mode=data_type)
                    exp._add_multifeature_to_stats(feat_id, mode=data_type)
                    pair_stats = DEFAULT_STATS.copy()

                current_data_hash = exp._data_hashes[data_type][feat_id][cell_id]

                if stats_not_empty(pair_stats, current_data_hash, stage=1):
                    precomputed_mask_stage1[i, j] = 0
                if stats_not_empty(pair_stats, current_data_hash, stage=2):
                    precomputed_mask_stage2[i, j] = 0

    # A pair is "fully precomputed" (and therefore not re-saved) only if every
    # stage required by the chosen mode was found in the saved stats.
    combined_precomputed_mask = np.ones((n, f))
    if mode == 'stage1':
        combined_precomputed_mask[precomputed_mask_stage1 == 0] = 0
    else:
        combined_precomputed_mask[(precomputed_mask_stage1 == 0) & (precomputed_mask_stage2 == 0)] = 0

    # compute_me_stats expects positional indices of the features to skip.
    unknown_skips = [name for name in skip_delays if name not in feat_ids]
    if unknown_skips:
        raise ValueError(f'"skip_delays" contains features that are not analyzed: {unknown_skips}. Analyzed features: {feat_ids}')
    skip_delay_ids = [feat_ids.index(name) for name in skip_delays]

    computed_stats, computed_significance, info = compute_me_stats(signals,
                                                                   feats,
                                                                   mode=mode,
                                                                   names1=cell_ids,
                                                                   names2=feat_ids,
                                                                   metric=metric,
                                                                   precomputed_mask_stage1=precomputed_mask_stage1,
                                                                   precomputed_mask_stage2=precomputed_mask_stage2,
                                                                   n_shuffles_stage1=n_shuffles_stage1,
                                                                   n_shuffles_stage2=n_shuffles_stage2,
                                                                   joint_distr=joint_distr,
                                                                   allow_mixed_dimensions=allow_mixed_dimensions,
                                                                   metric_distr_type=metric_distr_type,
                                                                   noise_ampl=noise_ampl,
                                                                   ds=ds,
                                                                   topk1=topk1,
                                                                   topk2=topk2,
                                                                   multicomp_correction=multicomp_correction,
                                                                   pval_thr=pval_thr,
                                                                   find_optimal_delays=find_optimal_delays,
                                                                   skip_delays=skip_delay_ids,
                                                                   shift_window=shift_window * exp.fps,
                                                                   verbose=verbose,
                                                                   enable_parallelization=enable_parallelization,
                                                                   n_jobs=n_jobs,
                                                                   seed=seed,
                                                                   duplicate_behavior=duplicate_behavior)

    exp.optimal_nf_delays = info['optimal_delays']

    # Attach data hashes, derive relative MI and update the Experiment's saved
    # statistics and significance where required.
    stage2_only = mode == 'stage2'
    for i, cell_id in enumerate(cell_ids):
        for j, feat_id in enumerate(feat_ids):
            # Without the precomputed-stats pass, multifeatures were never
            # registered, so their hashes may be missing.
            if not use_precomputed_stats and feat_id not in exp._data_hashes[data_type]:
                raise ValueError(f"Feature '{feat_id}' not found in data hashes. This may indicate the feature was not properly initialized.")

            pair_stats = computed_stats[cell_id][feat_id]
            pair_stats['data_hash'] = exp._data_hashes[data_type][feat_id][cell_id]

            me_val = pair_stats.get('me')
            if me_val is not None and metric == 'mi':
                feat_entropy = exp.get_feature_entropy(feat_id, ds=ds)
                # NOTE(review): normalized by the calcium entropy even when
                # data_type='spikes' — confirm this is intended.
                ca_entropy = cells[i].ca.get_entropy(ds=ds)
                pair_stats['rel_me_beh'] = me_val / feat_entropy
                pair_stats['rel_me_ca'] = me_val / ca_entropy

            if save_computed_stats:
                if combined_precomputed_mask[i, j]:
                    exp.update_neuron_feature_pair_stats(pair_stats,
                                                         cell_id,
                                                         feat_id,
                                                         mode=data_type,
                                                         force_update=force_update,
                                                         stage2_only=stage2_only)

                sig = computed_significance[cell_id][feat_id]
                exp.update_neuron_feature_pair_significance(sig, cell_id, feat_id, mode=data_type)

    # Save all results to a single object.
    intense_params = {
        'neurons': dict(enumerate(cell_ids)),
        'feat_bunch': dict(enumerate(feat_ids)),
        'data_type': data_type,
        'mode': mode,
        'metric': metric,
        'n_shuffles_stage1': n_shuffles_stage1,
        'n_shuffles_stage2': n_shuffles_stage2,
        'joint_distr': joint_distr,
        'metric_distr_type': metric_distr_type,
        'noise_ampl': noise_ampl,
        'ds': ds,
        'topk1': topk1,
        'topk2': topk2,
        'multicomp_correction': multicomp_correction,
        'pval_thr': pval_thr,
        'find_optimal_delays': find_optimal_delays,
        'shift_window': shift_window
    }

    intense_res = IntenseResults()
    intense_res.update('info', info)
    intense_res.update('intense_params', intense_params)

    if not with_disentanglement:
        # Return multiple values for backward compatibility.
        return computed_stats, computed_significance, info, intense_res

    if verbose:
        print("\nPerforming mixed selectivity disentanglement analysis...")

    # Step 1: feature-feature significance (with fewer stage-2 shuffles).
    _, feat_feat_significance, _, feat_names, _ = compute_feat_feat_significance(
        exp,
        feat_bunch=feat_bunch if feat_bunch is not None else 'all',
        metric=metric,
        mode=mode,
        n_shuffles_stage1=n_shuffles_stage1,
        n_shuffles_stage2=n_shuffles_stage2 // 10,
        metric_distr_type=metric_distr_type,
        noise_ampl=noise_ampl,
        ds=ds,
        topk1=topk1,
        topk2=topk2,
        multicomp_correction=multicomp_correction,
        pval_thr=pval_thr,
        verbose=verbose,
        enable_parallelization=enable_parallelization,
        n_jobs=n_jobs,
        seed=seed
    )

    # Step 2: default multifeature map if none was provided.
    if multifeature_map is None:
        multifeature_map = DEFAULT_MULTIFEATURE_MAP

    # Step 3: disentanglement analysis.
    disent_matrix, count_matrix = disentangle_all_selectivities(
        exp,
        feat_names,
        ds=ds,
        multifeature_map=multifeature_map,
        feat_feat_significance=feat_feat_significance,
        cell_bunch=cell_ids
    )

    # Step 4: summary statistics.
    from .disentanglement import get_disentanglement_summary
    summary = get_disentanglement_summary(
        disent_matrix,
        count_matrix,
        feat_names,
        feat_feat_significance
    )

    disentanglement_results = {
        'feat_feat_significance': feat_feat_significance,
        'disent_matrix': disent_matrix,
        'count_matrix': count_matrix,
        'feature_names': feat_names,
        'summary': summary
    }
    intense_res.update('disentanglement', disentanglement_results)

    if verbose:
        overall = summary['overall_stats']
        print("\nDisentanglement analysis complete!")
        print(f"Total mixed selectivity pairs analyzed: {overall['total_neuron_pairs']}")
        print(f"Redundancy rate: {overall['redundancy_rate']:.1f}%")
        print(f"Independence rate: {overall['independence_rate']:.1f}%")
        if 'true_mixed_selectivity_rate' in overall:
            print(f"True mixed selectivity rate: {overall['true_mixed_selectivity_rate']:.1f}%")

    return computed_stats, computed_significance, info, intense_res, disentanglement_results

426 

427 

def compute_feat_feat_significance(exp,
                                   feat_bunch='all',
                                   metric='mi',
                                   mode='two_stage',
                                   n_shuffles_stage1=100,
                                   n_shuffles_stage2=1000,
                                   metric_distr_type='gamma',
                                   noise_ampl=1e-3,
                                   ds=1,
                                   topk1=1,
                                   topk2=5,
                                   multicomp_correction='holm',
                                   pval_thr=0.01,
                                   verbose=True,
                                   enable_parallelization=True,
                                   n_jobs=-1,
                                   seed=42,
                                   duplicate_behavior='ignore'):
    """
    Compute pairwise significance between all behavioral features.

    Pairwise similarity (e.g. mutual information) between behavioral features
    is computed with the INTENSE machinery (compute_me_stats). Only the upper
    triangle (i < j) is evaluated: the diagonal and the lower triangle are
    excluded through the precomputed masks. The lower triangle is then filled
    by mirroring the upper one. The diagonal is set explicitly to similarity 0,
    significance 0 and p-value 1.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing behavioral data.
    feat_bunch : str, list or None
        Feature names to analyze. Default: 'all' (all features including
        multifeatures). Can be a list (or other iterable accepted by
        exp._process_fbunch) of specific feature names.
    metric : str, optional
        Similarity metric to use. Default: 'mi' (mutual information).
    mode : str, optional
        Computation mode: 'two_stage', 'stage1', or 'stage2'. Default: 'two_stage'.
    n_shuffles_stage1 : int, optional
        Number of shuffles for stage 1. Default: 100.
    n_shuffles_stage2 : int, optional
        Number of shuffles for stage 2. Default: 1000.
    metric_distr_type : str, optional
        Distribution type for metric null distribution. Default: 'gamma'.
    noise_ampl : float, optional
        Small noise amplitude for numerical stability. Default: 1e-3.
    ds : int, optional
        Downsampling factor. Default: 1.
    topk1 : int, optional
        Top-k criterion for stage 1. Default: 1.
    topk2 : int, optional
        Top-k criterion for stage 2. Default: 5.
    multicomp_correction : str or None, optional
        Multiple comparison correction method. Default: 'holm'.
    pval_thr : float, optional
        P-value threshold for significance. Default: 0.01.
    verbose : bool, optional
        Whether to print progress information. Default: True.
    enable_parallelization : bool, optional
        Whether to use parallel processing. Default: True.
    n_jobs : int, optional
        Number of parallel jobs. -1 means use all processors. Default: -1.
    seed : int, optional
        Random seed for reproducibility. Default: 42.
    duplicate_behavior : str, optional
        How to handle duplicate TimeSeries in the analyzed feature bunch;
        forwarded to compute_me_stats.
        - 'ignore': Process duplicates normally (default)
        - 'raise': Raise an error if duplicates are found
        - 'warn': Print a warning but continue processing
        Default: 'ignore'.

    Returns
    -------
    similarity_matrix : ndarray
        Matrix of similarity values between features. Element [i,j] contains
        the similarity between feature i and feature j. Diagonal is zero.
    significance_matrix : ndarray
        Matrix of binary significance values. 1 indicates significant similarity.
    p_value_matrix : ndarray
        Matrix of p-values for each comparison (1 on the diagonal).
    feature_names : list
        List of feature names corresponding to matrix indices.
        May include tuples for multifeatures (e.g., ('x', 'y')).
    info : dict
        Dictionary containing additional information from compute_me_stats
        (empty when there are no features to analyze).

    Notes
    -----
    - Uses the two-stage INTENSE approach for efficient significance testing
    - Only n*(n-1)/2 unique pairs are tested; matrices are symmetric
    - Supports MultiTimeSeries (tuple feature ids are aggregated with
      aggregate_multiple_ts, e.g. place fields from x,y coordinates)
    - No optimal delay search is performed (delays are set to 0)

    Examples
    --------
    >>> # Compute MI between all behavioral variables (default)
    >>> sim_mat, sig_mat, pval_mat, features, info = compute_feat_feat_significance(exp)
    >>>
    >>> # Analyze only specific features
    >>> sim_mat, sig_mat, pval_mat, features, info = compute_feat_feat_significance(
    ...     exp,
    ...     feat_bunch=['speed', 'head_direction', ('x', 'y')]
    ... )
    """
    import numpy as np

    # Only the literal string 'all' means "every feature" (None for
    # _process_fbunch). The isinstance guard keeps array-like bunches from
    # triggering an element-wise, ambiguous `==` comparison.
    if isinstance(feat_bunch, str) and feat_bunch == 'all':
        feat_bunch = None
    feat_ids = exp._process_fbunch(feat_bunch, allow_multifeatures=True, mode='calcium')
    n_features = len(feat_ids)

    if n_features == 0:
        if verbose:
            print("No features to analyze - returning empty results")
        return (
            np.array([]).reshape(0, 0),  # similarity_matrix
            np.array([]).reshape(0, 0),  # significance_matrix
            np.array([]).reshape(0, 0),  # p_value_matrix
            [],                          # feature_names
            {}                           # info
        )

    if verbose:
        print(f"Computing behavioral similarity matrix for {n_features} features...")
        print(f"Features: {feat_ids}")

    from ..information.info_base import aggregate_multiple_ts

    # TimeSeries for single features, aggregated MultiTimeSeries for tuples.
    feature_ts = []
    for feat_id in feat_ids:
        if isinstance(feat_id, tuple):
            ts = aggregate_multiple_ts(*[exp.dynamic_features[part] for part in feat_id])
        else:
            ts = exp.dynamic_features[feat_id]
        feature_ts.append(ts)

    # Mask convention: 1 = compute, 0 = skip. Only the strict upper triangle is
    # computed, which excludes self-comparisons and symmetric duplicates.
    upper_mask = np.triu(np.ones((n_features, n_features)), k=1)

    stats, significance, info = compute_me_stats(
        feature_ts,
        feature_ts,
        names1=feat_ids,
        names2=feat_ids,
        metric=metric,
        mode=mode,
        precomputed_mask_stage1=upper_mask.copy(),
        precomputed_mask_stage2=upper_mask.copy(),
        n_shuffles_stage1=n_shuffles_stage1,
        n_shuffles_stage2=n_shuffles_stage2,
        joint_distr=False,
        allow_mixed_dimensions=True,  # allow MultiTimeSeries features
        metric_distr_type=metric_distr_type,
        noise_ampl=noise_ampl,
        ds=ds,
        topk1=topk1,
        topk2=topk2,
        multicomp_correction=multicomp_correction,
        pval_thr=pval_thr,
        find_optimal_delays=False,  # no delay optimization
        shift_window=0,
        verbose=verbose,
        enable_parallelization=enable_parallelization,
        n_jobs=n_jobs,
        seed=seed,
        duplicate_behavior=duplicate_behavior
    )

    def _pair_entry(table, name1, name2):
        """Return table[name1][name2] (or {}), trying raw names, then their str()."""
        # Results are keyed by the names passed to compute_me_stats; tuple
        # (multifeature) names are tried as-is first and then in str() form.
        for key1 in (name1, str(name1)):
            row = table.get(key1)
            if not row:
                continue
            for key2 in (name2, str(name2)):
                entry = row.get(key2)
                if entry:
                    return entry
        return {}

    similarity_matrix = np.zeros((n_features, n_features))
    significance_matrix = np.zeros((n_features, n_features))
    p_value_matrix = np.ones((n_features, n_features))

    # Fill the computed upper triangle from the result dictionaries.
    for i, j in zip(*np.triu_indices(n_features, k=1)):
        feat1, feat2 = feat_ids[i], feat_ids[j]

        pair_stats = _pair_entry(stats, feat1, feat2)
        if pair_stats:
            similarity_matrix[i, j] = pair_stats.get('me', 0)
            p_value_matrix[i, j] = pair_stats.get('p', 1)

        # Prefer the stage-2 verdict; fall back to stage 1 if stage 2 was not run.
        pair_sig = _pair_entry(significance, feat1, feat2)
        if pair_sig.get('stage2') is not None:
            significance_matrix[i, j] = float(pair_sig['stage2'])
        elif pair_sig.get('stage1') is not None:
            significance_matrix[i, j] = float(pair_sig['stage1'])

    # Mirror the upper triangle into the lower one for symmetric results.
    lower = np.tril_indices(n_features, k=-1)
    for matrix in (similarity_matrix, significance_matrix, p_value_matrix):
        matrix[lower] = matrix.T[lower]

    # Self-comparisons are never computed; set the diagonal explicitly.
    np.fill_diagonal(similarity_matrix, 0)
    np.fill_diagonal(significance_matrix, 0)
    np.fill_diagonal(p_value_matrix, 1)

    if verbose:
        n_pairs = n_features * (n_features - 1) // 2
        print("\nBehavioral similarity matrix computation complete!")
        print(f"Feature pairs analyzed: {n_pairs}")
        print(f"Significant pairs (stage 1): {info.get('n_significant_stage1', 0)}")
        print(f"Significant pairs (final): {np.sum(significance_matrix)}")
        # Count unique significant pairs (upper triangle only)
        unique_sig = np.sum(np.triu(significance_matrix, k=1))
        print(f"Unique significant pairs: {unique_sig}")

    return similarity_matrix, significance_matrix, p_value_matrix, feat_ids, info

658 

659 

660def compute_cell_cell_significance(exp, 

661 cell_bunch=None, 

662 data_type='calcium', 

663 metric='mi', 

664 mode='two_stage', 

665 n_shuffles_stage1=100, 

666 n_shuffles_stage2=1000, 

667 metric_distr_type='gamma', 

668 noise_ampl=1e-3, 

669 ds=1, 

670 topk1=1, 

671 topk2=5, 

672 multicomp_correction='holm', 

673 pval_thr=0.01, 

674 verbose=True, 

675 enable_parallelization=True, 

676 n_jobs=-1, 

677 seed=42, 

678 duplicate_behavior='ignore'): 

679 """ 

680 Compute pairwise functional correlations between neurons using INTENSE. 

681  

682 This function calculates pairwise similarity (e.g., mutual information) between 

683 all neurons using the two-stage INTENSE approach. This can reveal functionally 

684 correlated neurons that may form assemblies or functional modules. 

685  

686 Parameters 

687 ---------- 

688 exp : Experiment 

689 Experiment object containing neural data. 

690 cell_bunch : int, list or None, optional 

691 Neuron indices to analyze. Default: None (all neurons). 

692 data_type : str, optional 

693 Type of neural data: 'calcium' or 'spikes'. Default: 'calcium'. 

694 metric : str, optional 

695 Similarity metric to use. Default: 'mi' (mutual information). 

696 mode : str, optional 

697 Computation mode: 'two_stage', 'stage1', or 'stage2'. Default: 'two_stage'. 

698 n_shuffles_stage1 : int, optional 

699 Number of shuffles for stage 1. Default: 100. 

700 n_shuffles_stage2 : int, optional 

701 Number of shuffles for stage 2. Default: 1000. 

702 metric_distr_type : str, optional 

703 Distribution type for metric null distribution. Default: 'gamma'. 

704 noise_ampl : float, optional 

705 Small noise amplitude for numerical stability. Default: 1e-3. 

706 ds : int, optional 

707 Downsampling factor. Default: 1. 

708 topk1 : int, optional 

709 Top-k criterion for stage 1. Default: 1. 

710 topk2 : int, optional 

711 Top-k criterion for stage 2. Default: 5. 

712 multicomp_correction : str or None, optional 

713 Multiple comparison correction method. Default: 'holm'. 

714 pval_thr : float, optional 

715 P-value threshold for significance. Default: 0.01. 

716 verbose : bool, optional 

717 Whether to print progress information. Default: True. 

718 enable_parallelization : bool, optional 

719 Whether to use parallel processing. Default: True. 

720 n_jobs : int, optional 

721 Number of parallel jobs. -1 means use all processors. Default: -1. 

722 seed : int, optional 

723 Random seed for reproducibility. Default: 42. 

724 duplicate_behavior : str, optional 

725 How to handle duplicate TimeSeries in ts_bunch1 or ts_bunch2. 

726 - 'ignore': Process duplicates normally (default) 

727 - 'raise': Raise an error if duplicates are found 

728 - 'warn': Print a warning but continue processing 

729 Default: 'ignore'. 

730  

731 Returns 

732 ------- 

733 similarity_matrix : ndarray 

734 Matrix of similarity values between neurons. Element [i,j] contains 

735 the similarity between neuron i and neuron j. Diagonal is zero. 

736 significance_matrix : ndarray 

737 Matrix of binary significance values. 1 indicates significant similarity. 

738 p_value_matrix : ndarray 

739 Matrix of p-values for each comparison. 

740 cell_ids : list 

741 List of cell IDs corresponding to matrix indices. 

742 info : dict 

743 Dictionary containing additional information from compute_me_stats. 

744  

745 Notes 

746 ----- 

747 - Uses the two-stage INTENSE approach for efficient significance testing 

748 - Diagonal elements are zero (self-similarity check prevents computation) 

749 - For calcium imaging data, considers temporal dynamics 

750 - For spike data, uses discrete MI formulation 

751 - Can identify functional assemblies through graph analysis of significant pairs 

752 - No optimal delay search is performed (synchronous activity assumed) 

753  

754 Examples 

755 -------- 

756 >>> # Compute functional correlations between all neurons 

757 >>> sim_mat, sig_mat, pval_mat, cells, info = compute_cell_cell_significance(exp) 

758 >>>  

759 >>> # Analyze only specific neurons 

760 >>> sim_mat, sig_mat, pval_mat, cells, info = compute_cell_cell_significance( 

761 ... exp,  

762 ... cell_bunch=[0, 5, 10, 15, 20], 

763 ... data_type='spikes' 

764 ... ) 

765 """ 

766 import numpy as np 

767 

768 # Check downsampling 

769 exp.check_ds(ds) 

770 

771 # Process cell bunch 

772 cell_ids = exp._process_cbunch(cell_bunch) 

773 n_cells = len(cell_ids) 

774 cells = [exp.neurons[cell_id] for cell_id in cell_ids] 

775 

776 if verbose: 

777 print(f"Computing neuronal similarity matrix for {n_cells} neurons...") 

778 print(f"Data type: {data_type}") 

779 

780 # Get neural signals based on data type 

781 if data_type == 'calcium': 

782 signals = [cell.ca for cell in cells] 

783 elif data_type == 'spikes': 

784 signals = [cell.sp for cell in cells] 

785 # Check if spike data exists and is non-degenerate 

786 if any(sig is None for sig in signals): 

787 raise ValueError("Some neurons have no spike data. Use reconstruct_spikes or provide spike data.") 

788 # Check if all spike data is identical (e.g., all zeros) 

789 if len(signals) > 1: 

790 first_data = signals[0].data 

791 if all(np.array_equal(sig.data, first_data) for sig in signals[1:]): 

792 import warnings 

793 warnings.warn("All neurons have identical spike data. This may lead to degenerate results.") 

794 else: 

795 raise ValueError('"data_type" can be either "calcium" or "spikes"') 

796 

797 # Create masks that exclude diagonal (self-comparisons) AND lower triangle 

798 # This ensures we only compute the upper triangle for symmetric results 

799 precomputed_mask_stage1 = np.triu(np.ones((n_cells, n_cells)), k=1) 

800 precomputed_mask_stage2 = np.triu(np.ones((n_cells, n_cells)), k=1) 

801 

802 # Call compute_me_stats with neurons against themselves 

803 # Note: optimal delays are disabled (set to False) for synchronous analysis 

804 stats, significance, info = compute_me_stats( 

805 signals, 

806 signals, 

807 names1=cell_ids, 

808 names2=cell_ids, 

809 metric=metric, 

810 mode=mode, 

811 precomputed_mask_stage1=precomputed_mask_stage1, 

812 precomputed_mask_stage2=precomputed_mask_stage2, 

813 n_shuffles_stage1=n_shuffles_stage1, 

814 n_shuffles_stage2=n_shuffles_stage2, 

815 joint_distr=False, 

816 allow_mixed_dimensions=False, # Neurons are single time series 

817 metric_distr_type=metric_distr_type, 

818 noise_ampl=noise_ampl, 

819 ds=ds, 

820 topk1=topk1, 

821 topk2=topk2, 

822 multicomp_correction=multicomp_correction, 

823 pval_thr=pval_thr, 

824 find_optimal_delays=False, # Assume synchronous activity 

825 shift_window=0, # No shift window needed 

826 verbose=verbose, 

827 enable_parallelization=enable_parallelization, 

828 n_jobs=n_jobs, 

829 seed=seed, 

830 duplicate_behavior='ignore' # Default behavior for cell-cell comparison 

831 ) 

832 

833 # Extract matrices from results 

834 similarity_matrix = np.zeros((n_cells, n_cells)) 

835 significance_matrix = np.zeros((n_cells, n_cells)) 

836 p_value_matrix = np.ones((n_cells, n_cells)) 

837 

838 # Fill matrices from stats and significance dictionaries 

839 # Since we only computed upper triangle, we need to fill both upper and lower 

840 for i, cell1 in enumerate(cell_ids): 

841 for j, cell2 in enumerate(cell_ids): 

842 if i == j: 

843 # Diagonal is already 0 

844 continue 

845 

846 # We computed only upper triangle, so check if this pair was computed 

847 if i < j: 

848 # Upper triangle - get from stats 

849 if cell1 in stats and cell2 in stats[cell1]: 

850 stats_dict = stats[cell1][cell2] 

851 if stats_dict: # Check if dict is not empty 

852 similarity_matrix[i, j] = stats_dict.get('me', 0) 

853 p_value_matrix[i, j] = stats_dict.get('p', 1) 

854 

855 sig_dict = significance.get(cell1, {}).get(cell2, {}) 

856 if sig_dict.get('stage2') is not None: 

857 significance_matrix[i, j] = float(sig_dict['stage2']) 

858 elif sig_dict.get('stage1') is not None: 

859 significance_matrix[i, j] = float(sig_dict['stage1']) 

860 else: 

861 # Lower triangle - copy from upper triangle for symmetry 

862 similarity_matrix[i, j] = similarity_matrix[j, i] 

863 p_value_matrix[i, j] = p_value_matrix[j, i] 

864 significance_matrix[i, j] = significance_matrix[j, i] 

865 

866 # Ensure diagonal is zero (should already be due to coincidence check) 

867 np.fill_diagonal(similarity_matrix, 0) 

868 np.fill_diagonal(significance_matrix, 0) 

869 np.fill_diagonal(p_value_matrix, 1) 

870 

871 if verbose: 

872 print(f"\nNeuronal similarity matrix computation complete!") 

873 print(f"Neuron pairs analyzed: {n_cells * n_cells}") 

874 print(f"Significant pairs (stage 1): {info.get('n_significant_stage1', 0)}") 

875 print(f"Significant pairs (final): {np.sum(significance_matrix)}") 

876 # Count unique significant pairs (upper triangle only) 

877 unique_sig = np.sum(np.triu(significance_matrix, k=1)) 

878 print(f"Unique significant pairs: {unique_sig}") 

879 

880 # Basic network statistics 

881 if unique_sig > 0: 

882 avg_connections = np.sum(significance_matrix) / n_cells 

883 print(f"Average connections per neuron: {avg_connections:.2f}") 

884 max_connections = np.max(np.sum(significance_matrix, axis=1)) 

885 print(f"Maximum connections for a single neuron: {int(max_connections)}") 

886 

887 return similarity_matrix, significance_matrix, p_value_matrix, cell_ids, info 

888 

889 

def compute_embedding_selectivity(exp,
                                  embedding_methods=None,
                                  cell_bunch=None,
                                  data_type='calcium',
                                  metric='mi',
                                  mode='two_stage',
                                  n_shuffles_stage1=100,
                                  n_shuffles_stage2=10000,
                                  metric_distr_type='norm',
                                  noise_ampl=1e-3,
                                  ds=1,
                                  use_precomputed_stats=True,
                                  save_computed_stats=True,
                                  force_update=False,
                                  topk1=1,
                                  topk2=5,
                                  multicomp_correction='holm',
                                  pval_thr=0.01,
                                  find_optimal_delays=True,
                                  shift_window=5,
                                  verbose=True,
                                  enable_parallelization=True,
                                  n_jobs=-1,
                                  seed=42):
    """
    Compute INTENSE selectivity between neurons and dimensionality reduction embeddings.

    Each embedding component is wrapped in a continuous TimeSeries named
    ``"{method_name}_comp{idx}"`` and analysed as a dynamic feature with
    ``compute_cell_feat_significance``. This reveals how individual neurons
    contribute to the population-level manifold structure.

    The components are attached to ``exp`` only for the duration of each
    method's analysis. ``exp.dynamic_features`` and any instance attributes
    shadowed by component names are restored afterwards, even if the analysis
    raises.

    Parameters
    ----------
    exp : Experiment
        Experiment object with stored embeddings
    embedding_methods : str, list or None
        Names of embedding methods to analyze. If None, analyzes all stored
        embeddings for ``data_type``.
    cell_bunch : int, iterable or None
        Neuron indices. By default (None), all neurons will be taken
    data_type : str
        Data type used for embeddings and INTENSE ('calcium' or 'spikes')
    metric : str
        Similarity metric between TimeSeries (default: 'mi')
    mode : str
        Computation mode: 'stage1', 'stage2', or 'two_stage' (default)
    n_shuffles_stage1 : int
        Number of shuffles for first stage (default: 100)
    n_shuffles_stage2 : int
        Number of shuffles for second stage (default: 10000)
    metric_distr_type : str
        Distribution type for shuffled metric distribution fit (default: 'norm')
    noise_ampl : float
        Small noise amplitude added to improve numerical fit (default: 1e-3)
    ds : int
        Downsampling constant (default: 1)
    use_precomputed_stats : bool
        Accepted for API compatibility but currently ignored. Embedding
        components are temporary features, so the analysis always runs with
        ``use_precomputed_stats=False``.
    save_computed_stats : bool
        If True and ``data_type`` has no entry in ``exp.stats_tables``, the
        experiment's selectivity tables are initialized. Stats computed for the
        temporary embedding features are never saved (default: True)
    force_update : bool
        Force update saved statistics if data hash collision found (default: False)
    topk1 : int
        True MI for stage 1 should be among topk1 MI shuffles (default: 1)
    topk2 : int
        True MI for stage 2 should be among topk2 MI shuffles (default: 5)
    multicomp_correction : str or None
        Multiple comparison correction type: None, 'bonferroni', or 'holm' (default)
    pval_thr : float
        P-value threshold (default: 0.01)
    find_optimal_delays : bool
        Find optimal temporal delays between neural activity and embeddings (default: True)
    shift_window : int
        Window for optimal shift search in seconds (default: 5)
    verbose : bool
        Print progress information (default: True)
    enable_parallelization : bool
        Enable parallel computation (default: True)
    n_jobs : int
        Number of parallel jobs, -1 for all cores (default: -1)
    seed : int
        Random seed (default: 42)

    Returns
    -------
    results : dict
        Keyed by embedding method name. Each value is a dict with:

        - 'stats': statistics returned by compute_cell_feat_significance
        - 'significance': significance results returned by compute_cell_feat_significance
        - 'info': additional information returned by compute_cell_feat_significance
        - 'significant_neurons': dict mapping neuron id -> list of component
          feature names whose 'stage2' significance flag is truthy
        - 'n_components': number of embedding components
        - 'component_selectivity': dict mapping component index -> list of
          selective neuron ids
        - 'embedding_metadata': the embedding's 'metadata' entry ({} if absent)

    Raises
    ------
    ValueError
        If no embedding methods are available for ``data_type``.
    """
    # Get list of embedding methods to analyze
    if embedding_methods is None:
        embedding_methods = list(exp.embeddings[data_type].keys())
    elif isinstance(embedding_methods, str):
        embedding_methods = [embedding_methods]

    if not embedding_methods:
        raise ValueError(f"No embeddings found for data_type '{data_type}'. "
                         "Use exp.store_embedding() to add embeddings first.")

    # Sentinel distinguishing "attribute did not exist" from "attribute was None".
    missing = object()
    results = {}

    for method_name in embedding_methods:
        if verbose:
            print(f"\n{'='*60}")
            print(f"Computing selectivity for embedding: {method_name}")
            print(f"{'='*60}")

        # Get embedding data
        embedding_dict = exp.get_embedding(method_name, data_type)
        embedding_data = embedding_dict['data']
        n_components = embedding_data.shape[1]

        # Map each feature name to its component index up front, so results can
        # be mapped back without parsing the name (method names may contain '_comp').
        feat_to_comp = {f"{method_name}_comp{comp_idx}": comp_idx
                        for comp_idx in range(n_components)}
        embedding_features = {
            feat_name: TimeSeries(embedding_data[:, comp_idx], discrete=False)
            for feat_name, comp_idx in feat_to_comp.items()
        }

        # Snapshot everything that is about to be overwritten, so the finally
        # clause can put exp back exactly as it was.
        original_features = exp.dynamic_features.copy()
        saved_attrs = {feat_name: getattr(exp, feat_name, missing)
                       for feat_name in embedding_features}

        try:
            # Temporarily expose the components as dynamic features and
            # attributes. This is inside the try so a failure here still
            # triggers restoration.
            exp.dynamic_features.update(embedding_features)
            for feat_name, feat_ts in embedding_features.items():
                setattr(exp, feat_name, feat_ts)

            # Rebuild data hashes to include new features
            exp._build_data_hashes(mode=data_type)

            # Initialize stats tables if not already done
            if save_computed_stats and data_type not in exp.stats_tables:
                exp._set_selectivity_tables(data_type)

            # Run INTENSE analysis
            stats, significance, info, _ = compute_cell_feat_significance(
                exp,
                cell_bunch=cell_bunch,
                feat_bunch=list(embedding_features.keys()),
                data_type=data_type,
                metric=metric,
                mode=mode,
                n_shuffles_stage1=n_shuffles_stage1,
                n_shuffles_stage2=n_shuffles_stage2,
                metric_distr_type=metric_distr_type,
                noise_ampl=noise_ampl,
                ds=ds,
                use_precomputed_stats=False,  # Must be False for new dynamic features
                save_computed_stats=False,  # Don't save stats for temporary embedding features
                force_update=force_update,
                topk1=topk1,
                topk2=topk2,
                multicomp_correction=multicomp_correction,
                pval_thr=pval_thr,
                find_optimal_delays=find_optimal_delays,
                shift_window=shift_window,
                verbose=verbose,
                enable_parallelization=enable_parallelization,
                n_jobs=n_jobs,
                seed=seed
            )

            # significance is structured as significance[neuron_id][feat_name].
            # A pair counts as selective only if its 'stage2' flag is truthy.
            # NOTE(review): confirm whether mode='stage1' ever sets 'stage2';
            # if not, no neuron is reported as significant in that mode.
            significant_neurons = {}
            for neuron_id, neuron_sig in significance.items():
                selective_feats = [
                    feat_name for feat_name in embedding_features
                    if feat_name in neuron_sig
                    and neuron_sig[feat_name].get('stage2', False)
                ]
                if selective_feats:
                    significant_neurons[neuron_id] = selective_feats

            # Organize component selectivity
            component_selectivity = {comp_idx: [] for comp_idx in range(n_components)}
            for neuron_id, features in significant_neurons.items():
                for feat_name in features:
                    component_selectivity[feat_to_comp[feat_name]].append(neuron_id)

            # Store results
            results[method_name] = {
                'stats': stats,
                'significance': significance,
                'info': info,
                'significant_neurons': significant_neurons,
                'n_components': n_components,
                'component_selectivity': component_selectivity,
                'embedding_metadata': embedding_dict.get('metadata', {})
            }

            if verbose:
                n_sig_neurons = len(significant_neurons)
                n_total_neurons = len(exp._process_cbunch(cell_bunch))
                # Guard against division by zero for an empty neuron selection.
                pct_sig = 100 * n_sig_neurons / n_total_neurons if n_total_neurons else 0.0
                print(f"\nResults for {method_name}:")
                print(f" Embedding dimensions: {n_components}")
                print(f" Significant neurons: {n_sig_neurons}/{n_total_neurons} ({pct_sig:.1f}%)")

                # Component-wise summary
                for comp_idx in range(n_components):
                    n_selective = len(component_selectivity[comp_idx])
                    if n_selective > 0:
                        print(f" Component {comp_idx}: {n_selective} selective neurons")

        finally:
            # Restore original features
            exp.dynamic_features = original_features

            # Undo the temporary attributes. Names that already existed get their
            # previous value back instead of being deleted.
            for feat_name, previous in saved_attrs.items():
                if previous is missing:
                    if hasattr(exp, feat_name):
                        delattr(exp, feat_name)
                else:
                    setattr(exp, feat_name, previous)
            # NOTE(review): data hashes rebuilt above still describe the
            # temporary features; confirm whether _build_data_hashes should be
            # re-run here after restoration.

    return results

1111