Coverage for src/driada/intense/pipelines.py: 74.91%
283 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1from .stats import *
2from .intense_base import compute_me_stats, IntenseResults
3from ..information.info_base import TimeSeries, MultiTimeSeries
4from .disentanglement import disentangle_all_selectivities, DEFAULT_MULTIFEATURE_MAP
7def compute_cell_feat_significance(exp,
8 cell_bunch=None,
9 feat_bunch=None,
10 data_type='calcium',
11 metric='mi',
12 mode='two_stage',
13 n_shuffles_stage1=100,
14 n_shuffles_stage2=10000,
15 joint_distr=False,
16 allow_mixed_dimensions=False,
17 metric_distr_type='norm',
18 noise_ampl=1e-3,
19 ds=1,
20 use_precomputed_stats=True,
21 save_computed_stats=True,
22 force_update=False,
23 topk1=1,
24 topk2=5,
25 multicomp_correction='holm',
26 pval_thr=0.01,
27 find_optimal_delays=True,
28 skip_delays=[],
29 shift_window=5,
30 verbose=True,
31 enable_parallelization=True,
32 n_jobs=-1,
33 seed=42,
34 with_disentanglement=False,
35 multifeature_map=None,
36 duplicate_behavior='ignore'):
38 """
39 Calculates significant neuron-feature pairs
41 Parameters
42 ----------
43 exp: Experiment instance
44 Experiment object to read and write data from
46 cell_bunch: int, iterable or None
47 Neuron indices. By default, (cell_bunch=None), all neurons will be taken
49 feat_bunch: str, iterable or None
50 Feature names. By default, (feat_bunch=None), all single features will be taken
52 data_type: str
53 Data type used for INTENSE computations. Can be 'calcium' or 'spikes'
55 metric: similarity metric between TimeSeries
56 default: 'mi'
58 mode: str
59 Computation mode. 3 modes are available:
60 'stage1': perform preliminary scanning with "n_shuffles_stage1" shuffles only.
61 Rejects strictly non-significant neuron-feature pairs, does not give definite results
62 about significance of the others.
63 'stage2': skip stage 1 and perform full-scale scanning ("n_shuffles_stage2" shuffles) of all neuron-feature pairs.
64 Gives definite results, but can be very time-consuming. Also reduces statistical power
65 of multiple comparison tests, since the number of hypotheses is very high.
66 'two_stage': prune non-significant pairs during stage 1 and perform thorough testing for the rest during stage 2.
67 Recommended mode.
68 default: 'two-stage'
70 n_shuffles_stage1: int
71 number of shuffles for first stage
72 default: 100
74 n_shuffles_stage2: int
75 number of shuffles for second stage
76 default: 10000
78 joint_distr: bool
79 if True, ALL features in feat_bunch will be treated as components of a single multifeature
80 For example, 'x' and 'y' features will be put together into ('x','y') multifeature.
81 default: False
83 allow_mixed_dimensions: bool
84 if True, both TimeSeries and MultiTimeSeries can be provided as signals.
85 This parameter overrides "joint_distr"
87 metric_distr_type: str
88 Distribution type for shuffled metric distribution fit. Supported options are distributions from scipy.stats
89 Note: While 'gamma' is theoretically appropriate for MI distributions, empirical testing shows
90 that 'norm' (normal distribution) often performs better due to its conservative p-values when
91 fitting poorly to the skewed MI data. This conservatism reduces false positives.
92 default: "gamma"
94 noise_ampl: float
95 Small noise amplitude, which is added to MI and shuffled MI to improve numerical fit
96 default: 1e-3
98 ds: int
99 Downsampling constant. Every "ds" point will be taken from the data time series.
100 Reduces the computational load, but needs caution since with large "ds" some important information may be lost.
101 Experiment class performs an internal check for this effect.
102 default: 1
104 use_precomputed_stats: bool
105 Whether to use stats saved in Experiment instance. Stats are accumulated separately for stage1 and stage2.
106 Notes on stats data rewriting (if save_computed_stats=True):
107 If you want to recalculate stage1 results only, use "use_precomputed_stats=False" and "mode='stage1'".
108 Stage 2 stats will be erased since they will become irrelevant.
109 If you want to recalculate stage2 results only, use "use_precomputed_stats=True" and "mode='stage2'" or "mode='two-stage'"
110 If you want to recalculate everything, use "use_precomputed_stats=False" and "mode='two-stage'"
111 default: True
113 save_computed_stats: bool
114 Whether to save computed stats to Experiment instance
115 default: True
117 force_update: bool
118 Whether to force saved statistics data update in case the collision between actual data hashes and
119 saved stats data hashes is found (for example, if neuronal or behavior data has been changed externally).
120 default: False
122 topk1: int
123 true MI for stage 1 should be among topk1 MI shuffles
124 default: 1
126 topk2: int
127 true MI for stage 2 should be among topk2 MI shuffles
128 default: 5
130 multicomp_correction: str or None
131 type of multiple comparison correction. Supported types are None (no correction),
132 "bonferroni" and "holm".
133 default: 'holm'
135 pval_thr: float
136 pvalue threshold. if multicomp_correction=None, this is a p-value for a single pair.
137 Otherwise it is a FWER significance level.
139 find_optimal_delays: bool
140 Allows slight shifting (not more than +- shift_window) of time series,
141 selects a shift with the highest MI as default.
142 default: True
144 skip_delays: list
145 List of features for which delays are not applied (set to 0).
146 Has no effect if find_optimal_delays = False
148 shift_window: int
149 Window for optimal shift search (seconds). Optimal shift (in frames) will lie in the range
150 -shift_window*fps <= opt_shift <= shift_window*fps
151 Has no effect if find_optimal_delays = False
153 with_disentanglement: bool
154 If True, performs a full INTENSE pipeline with mixed selectivity analysis:
155 1. Computes behavioral feature-feature significance
156 2. Computes neuron-feature significance
157 3. Disentangles mixed selectivities using behavioral correlations
158 default: False
160 multifeature_map: dict or None
161 Mapping from multifeature tuples to aggregated names for disentanglement.
162 If None, uses DEFAULT_MULTIFEATURE_MAP from disentanglement module.
163 Only used when with_disentanglement=True.
164 default: None
166 duplicate_behavior: str
167 How to handle duplicate TimeSeries in neuron or feature bunches.
168 - 'ignore': Process duplicates normally (default)
169 - 'raise': Raise an error if duplicates are found
170 - 'warn': Print a warning but continue processing
172 Returns
173 -------
174 stats: dict of dict of dicts
175 Outer dict: dynamic features, inner dict: cells, last dict: stats.
176 Can be easily converted to pandas DataFrame by pd.DataFrame(stats)
177 significance: dict of dict of bools
178 Significance results for each neuron-feature pair
179 info: dict
180 Additional information from compute_me_stats
181 intense_res: IntenseResults
182 Complete results object
183 disentanglement_results: dict (only if with_disentanglement=True)
184 Contains:
185 - 'feat_feat_significance': Feature-feature significance matrix
186 - 'disent_matrix': Disentanglement results matrix
187 - 'count_matrix': Count matrix from disentanglement
188 - 'summary': Summary statistics from disentanglement
189 """
191 exp.check_ds(ds)
193 cell_ids = exp._process_cbunch(cell_bunch)
194 feat_ids = exp._process_fbunch(feat_bunch, allow_multifeatures=True, mode=data_type)
195 cells = [exp.neurons[cell_id] for cell_id in cell_ids]
197 if data_type == 'calcium':
198 signals = [cell.ca for cell in cells]
199 elif data_type == 'spikes':
200 signals = [cell.sp for cell in cells]
201 else:
202 raise ValueError('"data_type" can be either "calcium" or "spikes"')
204 #min_shifts = [int(cell.get_t_off() * MIN_CA_SHIFT) for cell in cells]
205 if not allow_mixed_dimensions:
206 feats = [exp.dynamic_features[feat_id] for feat_id in feat_ids if feat_id in exp.dynamic_features]
207 if joint_distr:
208 feat_ids = [tuple(sorted(feat_ids))]
209 else:
210 feats = []
211 for feat_id in feat_ids:
212 if isinstance(feat_id, str):
213 if feat_id not in exp.dynamic_features:
214 raise ValueError(f"Feature '{feat_id}' not found in experiment. Available features: {list(exp.dynamic_features.keys())}")
215 ts = exp.dynamic_features[feat_id]
216 feats.append(ts)
217 elif isinstance(feat_id, tuple):
218 for f in feat_id:
219 if f not in exp.dynamic_features:
220 raise ValueError(f"Feature '{f}' not found in experiment. Available features: {list(exp.dynamic_features.keys())}")
221 parts = [exp.dynamic_features[f] for f in feat_id]
222 mts = MultiTimeSeries(parts)
223 feats.append(mts)
224 else:
225 raise ValueError('Unknown feature id type')
227 n, t, f = len(cells), exp.n_frames, len(feats)
229 precomputed_mask_stage1 = np.ones((n,f))
230 precomputed_mask_stage2 = np.ones((n,f))
232 if not exp.selectivity_tables_initialized:
233 exp._set_selectivity_tables(data_type, cbunch=cell_ids, fbunch=feat_ids)
235 if use_precomputed_stats:
236 print('Retrieving saved stats data...')
237 # 0 in mask values means precomputed results are found, calculation will be skipped.
238 # 1 in mask values means precomputed results are not found or incomplete, calculation will proceed.
240 for i, cell_id in enumerate(cell_ids):
241 for j, feat_id in enumerate(feat_ids):
242 try:
243 pair_stats = exp.get_neuron_feature_pair_stats(cell_id, feat_id, mode=data_type)
244 except (ValueError, KeyError):
245 if isinstance(feat_id, str):
246 raise ValueError(f'Unknown single feature in feat_bunch: {feat_id}. Check initial data')
247 else:
248 exp._add_multifeature_to_data_hashes(feat_id, mode=data_type)
249 exp._add_multifeature_to_stats(feat_id, mode=data_type)
250 pair_stats = DEFAULT_STATS.copy()
252 current_data_hash = exp._data_hashes[data_type][feat_id][cell_id]
254 if stats_not_empty(pair_stats, current_data_hash, stage=1):
255 precomputed_mask_stage1[i,j] = 0
256 if stats_not_empty(pair_stats, current_data_hash, stage=2):
257 precomputed_mask_stage2[i,j] = 0
259 combined_precomputed_mask = np.ones((n, f))
260 if mode in ['stage2', 'two_stage']:
261 combined_precomputed_mask[np.where((precomputed_mask_stage1 == 0) & (precomputed_mask_stage2 == 0))] = 0
262 elif mode == 'stage1':
263 combined_precomputed_mask[np.where(precomputed_mask_stage1 == 0)] = 0
264 else:
265 raise ValueError('Wrong mode!')
267 computed_stats, computed_significance, info = compute_me_stats(signals,
268 feats,
269 mode=mode,
270 names1=cell_ids,
271 names2=feat_ids,
272 metric=metric,
273 precomputed_mask_stage1=precomputed_mask_stage1,
274 precomputed_mask_stage2=precomputed_mask_stage2,
275 n_shuffles_stage1=n_shuffles_stage1,
276 n_shuffles_stage2=n_shuffles_stage2,
277 joint_distr=joint_distr,
278 allow_mixed_dimensions=allow_mixed_dimensions,
279 metric_distr_type=metric_distr_type,
280 noise_ampl=noise_ampl,
281 ds=ds,
282 topk1=topk1,
283 topk2=topk2,
284 multicomp_correction=multicomp_correction,
285 pval_thr=pval_thr,
286 find_optimal_delays=find_optimal_delays,
287 skip_delays=[feat_ids.index(f) for f in skip_delays],
288 shift_window=shift_window*exp.fps,
289 verbose=verbose,
290 enable_parallelization=enable_parallelization,
291 n_jobs=n_jobs,
292 seed=seed,
293 duplicate_behavior=duplicate_behavior)
295 exp.optimal_nf_delays = info['optimal_delays']
296 # add hash data and update Experiment saved statistics and significance if needed
297 for i, cell_id in enumerate(cell_ids):
298 for j, feat_id in enumerate(feat_ids):
299 # Check for non-existing feature if use_precomputed_stats==False
300 if not use_precomputed_stats:
301 if feat_id not in exp._data_hashes[data_type]:
302 raise ValueError(f"Feature '{feat_id}' not found in data hashes. This may indicate the feature was not properly initialized.")
303 computed_stats[cell_id][feat_id]['data_hash'] = exp._data_hashes[data_type][feat_id][cell_id]
305 me_val = computed_stats[cell_id][feat_id].get('me')
306 if me_val is not None and metric == 'mi':
307 feat_entropy = exp.get_feature_entropy(feat_id, ds=ds)
308 ca_entropy = exp.neurons[int(cell_id)].ca.get_entropy(ds=ds)
309 computed_stats[cell_id][feat_id]['rel_me_beh'] = me_val / feat_entropy
310 computed_stats[cell_id][feat_id]['rel_me_ca'] = me_val / ca_entropy
312 if save_computed_stats:
313 stage2_only = True if mode == 'stage2' else False
314 if combined_precomputed_mask[i,j]:
315 exp.update_neuron_feature_pair_stats(computed_stats[cell_id][feat_id],
316 cell_id,
317 feat_id,
318 mode=data_type,
319 force_update=force_update,
320 stage2_only=stage2_only)
322 sig = computed_significance[cell_id][feat_id]
323 exp.update_neuron_feature_pair_significance(sig, cell_id, feat_id, mode=data_type)
325 # save all results to a single object
326 intense_params = {
327 'neurons': {i: cell_ids[i] for i in range(len(cell_ids))},
328 'feat_bunch': {i: feat_ids[i] for i in range(len(feat_ids))},
329 'data_type': data_type,
330 'mode': mode,
331 'metric': metric,
332 'n_shuffles_stage1': n_shuffles_stage1,
333 'n_shuffles_stage2': n_shuffles_stage2,
334 'joint_distr': joint_distr,
335 'metric_distr_type': metric_distr_type,
336 'noise_ampl': noise_ampl,
337 'ds': ds,
338 'topk1': topk1,
339 'topk2': topk2,
340 'multicomp_correction': multicomp_correction,
341 'pval_thr': pval_thr,
342 'find_optimal_delays': find_optimal_delays,
343 'shift_window': shift_window
344 }
346 intense_res = IntenseResults()
347 #intense_res.update('stats', computed_stats)
348 #intense_res.update('significance', computed_significance)
349 intense_res.update('info', info)
350 intense_res.update('intense_params', intense_params)
352 # Perform disentanglement analysis if requested
353 if with_disentanglement:
354 if verbose:
355 print("\nPerforming mixed selectivity disentanglement analysis...")
357 # Step 1: Compute feature-feature significance
358 _, feat_feat_significance, _, feat_names, _ = compute_feat_feat_significance(
359 exp,
360 feat_bunch=feat_bunch if feat_bunch is not None else 'all',
361 metric=metric,
362 mode=mode,
363 n_shuffles_stage1=n_shuffles_stage1,
364 n_shuffles_stage2=n_shuffles_stage2 // 10, # Reduce shuffles for feat-feat
365 metric_distr_type=metric_distr_type,
366 noise_ampl=noise_ampl,
367 ds=ds,
368 topk1=topk1,
369 topk2=topk2,
370 multicomp_correction=multicomp_correction,
371 pval_thr=pval_thr,
372 verbose=verbose,
373 enable_parallelization=enable_parallelization,
374 n_jobs=n_jobs,
375 seed=seed
376 )
378 # Step 2: Use default multifeature map if not provided
379 if multifeature_map is None:
380 multifeature_map = DEFAULT_MULTIFEATURE_MAP
382 # Step 3: Run disentanglement analysis
383 disent_matrix, count_matrix = disentangle_all_selectivities(
384 exp,
385 feat_names,
386 ds=ds,
387 multifeature_map=multifeature_map,
388 feat_feat_significance=feat_feat_significance,
389 cell_bunch=cell_ids
390 )
392 # Step 4: Get summary statistics
393 from .disentanglement import get_disentanglement_summary
394 summary = get_disentanglement_summary(
395 disent_matrix,
396 count_matrix,
397 feat_names,
398 feat_feat_significance
399 )
401 # Package disentanglement results
402 disentanglement_results = {
403 'feat_feat_significance': feat_feat_significance,
404 'disent_matrix': disent_matrix,
405 'count_matrix': count_matrix,
406 'feature_names': feat_names,
407 'summary': summary
408 }
410 # Add to IntenseResults
411 intense_res.update('disentanglement', disentanglement_results)
413 if verbose:
414 print(f"\nDisentanglement analysis complete!")
415 print(f"Total mixed selectivity pairs analyzed: {summary['overall_stats']['total_neuron_pairs']}")
416 print(f"Redundancy rate: {summary['overall_stats']['redundancy_rate']:.1f}%")
417 print(f"Independence rate: {summary['overall_stats']['independence_rate']:.1f}%")
418 if 'true_mixed_selectivity_rate' in summary['overall_stats']:
419 print(f"True mixed selectivity rate: {summary['overall_stats']['true_mixed_selectivity_rate']:.1f}%")
421 # Return with disentanglement results
422 return computed_stats, computed_significance, info, intense_res, disentanglement_results
424 # Return multiple values for backward compatibility
425 return computed_stats, computed_significance, info, intense_res
428def compute_feat_feat_significance(exp,
429 feat_bunch='all',
430 metric='mi',
431 mode='two_stage',
432 n_shuffles_stage1=100,
433 n_shuffles_stage2=1000,
434 metric_distr_type='gamma',
435 noise_ampl=1e-3,
436 ds=1,
437 topk1=1,
438 topk2=5,
439 multicomp_correction='holm',
440 pval_thr=0.01,
441 verbose=True,
442 enable_parallelization=True,
443 n_jobs=-1,
444 seed=42,
445 duplicate_behavior='ignore'):
446 """
447 Compute pairwise significance between all behavioral features.
449 This function calculates pairwise similarity (e.g., mutual information) between
450 all behavioral features using the two-stage INTENSE approach. The diagonal
451 elements are set to zero as self-similarity is prevented by the check_for_coincidence
452 mechanism in get_mi.
454 Parameters
455 ----------
456 exp : Experiment
457 Experiment object containing behavioral data.
458 feat_bunch : str, list or None
459 Feature names to analyze. Default: 'all' (all features including multifeatures).
460 Can be a list of specific feature names.
461 metric : str, optional
462 Similarity metric to use. Default: 'mi' (mutual information).
463 mode : str, optional
464 Computation mode: 'two_stage', 'stage1', or 'stage2'. Default: 'two_stage'.
465 n_shuffles_stage1 : int, optional
466 Number of shuffles for stage 1. Default: 100.
467 n_shuffles_stage2 : int, optional
468 Number of shuffles for stage 2. Default: 1000.
469 metric_distr_type : str, optional
470 Distribution type for metric null distribution. Default: 'gamma'.
471 noise_ampl : float, optional
472 Small noise amplitude for numerical stability. Default: 1e-3.
473 ds : int, optional
474 Downsampling factor. Default: 1.
475 topk1 : int, optional
476 Top-k criterion for stage 1. Default: 1.
477 topk2 : int, optional
478 Top-k criterion for stage 2. Default: 5.
479 multicomp_correction : str or None, optional
480 Multiple comparison correction method. Default: 'holm'.
481 pval_thr : float, optional
482 P-value threshold for significance. Default: 0.01.
483 verbose : bool, optional
484 Whether to print progress information. Default: True.
485 enable_parallelization : bool, optional
486 Whether to use parallel processing. Default: True.
487 n_jobs : int, optional
488 Number of parallel jobs. -1 means use all processors. Default: -1.
489 seed : int, optional
490 Random seed for reproducibility. Default: 42.
491 duplicate_behavior : str, optional
492 How to handle duplicate TimeSeries in ts_bunch1 or ts_bunch2.
493 - 'ignore': Process duplicates normally (default)
494 - 'raise': Raise an error if duplicates are found
495 - 'warn': Print a warning but continue processing
496 Default: 'ignore'.
498 Returns
499 -------
500 similarity_matrix : ndarray
501 Matrix of similarity values between features. Element [i,j] contains
502 the similarity between feature i and feature j. Diagonal is zero.
503 significance_matrix : ndarray
504 Matrix of binary significance values. 1 indicates significant similarity.
505 p_value_matrix : ndarray
506 Matrix of p-values for each comparison.
507 feature_names : list
508 List of feature names corresponding to matrix indices.
509 May include tuples for multifeatures (e.g., ('x', 'y')).
510 info : dict
511 Dictionary containing additional information from compute_me_stats.
513 Notes
514 -----
515 - Uses the two-stage INTENSE approach for efficient significance testing
516 - Diagonal elements are zero (self-similarity check prevents computation)
517 - The function handles both discrete and continuous variables
518 - Supports MultiTimeSeries (e.g., place fields from x,y coordinates)
519 - For mutual information, values are in bits
520 - No optimal delay search is performed (delays are set to 0)
522 Examples
523 --------
524 >>> # Compute MI between all behavioral variables (default)
525 >>> sim_mat, sig_mat, pval_mat, features, info = compute_feat_feat_significance(exp)
526 >>>
527 >>> # Analyze only specific features
528 >>> sim_mat, sig_mat, pval_mat, features, info = compute_feat_feat_significance(
529 ... exp,
530 ... feat_bunch=['speed', 'head_direction', ('x', 'y')]
531 ... )
532 """
533 import numpy as np
535 # Process feature bunch - default is all features
536 if feat_bunch == 'all':
537 feat_bunch = None # None means all features in _process_fbunch
538 feat_ids = exp._process_fbunch(feat_bunch, allow_multifeatures=True, mode='calcium')
539 n_features = len(feat_ids)
541 # Handle empty feature list case
542 if n_features == 0:
543 if verbose:
544 print("No features to analyze - returning empty results")
545 return (
546 np.array([]).reshape(0, 0), # similarity_matrix
547 np.array([]).reshape(0, 0), # significance_matrix
548 np.array([]).reshape(0, 0), # p_value_matrix
549 [], # feature_names
550 {} # info
551 )
553 if verbose:
554 print(f"Computing behavioral similarity matrix for {n_features} features...")
555 print(f"Features: {feat_ids}")
557 # Get TimeSeries/MultiTimeSeries objects for all features
558 from ..information.info_base import aggregate_multiple_ts
560 feature_ts = []
561 for feat_id in feat_ids:
562 if isinstance(feat_id, tuple):
563 # Create MultiTimeSeries for tuples using aggregate_multiple_ts
564 ts_list = [exp.dynamic_features[f] for f in feat_id]
565 ts = aggregate_multiple_ts(*ts_list)
566 else:
567 ts = exp.dynamic_features[feat_id]
568 feature_ts.append(ts)
570 # Create masks that exclude diagonal (self-comparisons) AND lower triangle
571 # This ensures we only compute the upper triangle for symmetric results
572 precomputed_mask_stage1 = np.triu(np.ones((n_features, n_features)), k=1)
573 precomputed_mask_stage2 = np.triu(np.ones((n_features, n_features)), k=1)
575 # Call compute_me_stats with features against themselves
576 # Note: optimal delays are disabled (set to False)
577 stats, significance, info = compute_me_stats(
578 feature_ts,
579 feature_ts,
580 names1=feat_ids,
581 names2=feat_ids,
582 metric=metric,
583 mode=mode,
584 precomputed_mask_stage1=precomputed_mask_stage1,
585 precomputed_mask_stage2=precomputed_mask_stage2,
586 n_shuffles_stage1=n_shuffles_stage1,
587 n_shuffles_stage2=n_shuffles_stage2,
588 joint_distr=False,
589 allow_mixed_dimensions=True, # Allow MultiTimeSeries
590 metric_distr_type=metric_distr_type,
591 noise_ampl=noise_ampl,
592 ds=ds,
593 topk1=topk1,
594 topk2=topk2,
595 multicomp_correction=multicomp_correction,
596 pval_thr=pval_thr,
597 find_optimal_delays=False, # No delay optimization
598 shift_window=0, # No shift window needed
599 verbose=verbose,
600 enable_parallelization=enable_parallelization,
601 n_jobs=n_jobs,
602 seed=seed,
603 duplicate_behavior='ignore' # Default behavior for feature-feature comparison
604 )
606 # Extract matrices from results
607 similarity_matrix = np.zeros((n_features, n_features))
608 significance_matrix = np.zeros((n_features, n_features))
609 p_value_matrix = np.ones((n_features, n_features))
611 # Fill matrices from stats and significance dictionaries
612 # Since we only computed upper triangle, we need to fill both upper and lower
613 for i, feat1 in enumerate(feat_ids):
614 for j, feat2 in enumerate(feat_ids):
615 if i == j:
616 # Diagonal is already 0
617 continue
619 # Convert tuples to strings for dictionary keys if needed
620 key1 = str(feat1) if isinstance(feat1, tuple) else feat1
621 key2 = str(feat2) if isinstance(feat2, tuple) else feat2
623 # We computed only upper triangle, so check if this pair was computed
624 if i < j:
625 # Upper triangle - get from stats
626 if key1 in stats and key2 in stats[key1]:
627 stats_dict = stats[key1][key2]
628 if stats_dict: # Check if dict is not empty
629 similarity_matrix[i, j] = stats_dict.get('me', 0)
630 p_value_matrix[i, j] = stats_dict.get('p', 1)
632 sig_dict = significance.get(key1, {}).get(key2, {})
633 if sig_dict.get('stage2') is not None:
634 significance_matrix[i, j] = float(sig_dict['stage2'])
635 elif sig_dict.get('stage1') is not None:
636 significance_matrix[i, j] = float(sig_dict['stage1'])
637 else:
638 # Lower triangle - copy from upper triangle for symmetry
639 similarity_matrix[i, j] = similarity_matrix[j, i]
640 p_value_matrix[i, j] = p_value_matrix[j, i]
641 significance_matrix[i, j] = significance_matrix[j, i]
643 # Ensure diagonal is zero (should already be due to coincidence check)
644 np.fill_diagonal(similarity_matrix, 0)
645 np.fill_diagonal(significance_matrix, 0)
646 np.fill_diagonal(p_value_matrix, 1)
648 if verbose:
649 print(f"\nBehavioral similarity matrix computation complete!")
650 print(f"Feature pairs analyzed: {n_features * n_features}")
651 print(f"Significant pairs (stage 1): {info.get('n_significant_stage1', 0)}")
652 print(f"Significant pairs (final): {np.sum(significance_matrix)}")
653 # Count unique significant pairs (upper triangle only)
654 unique_sig = np.sum(np.triu(significance_matrix, k=1))
655 print(f"Unique significant pairs: {unique_sig}")
657 return similarity_matrix, significance_matrix, p_value_matrix, feat_ids, info
660def compute_cell_cell_significance(exp,
661 cell_bunch=None,
662 data_type='calcium',
663 metric='mi',
664 mode='two_stage',
665 n_shuffles_stage1=100,
666 n_shuffles_stage2=1000,
667 metric_distr_type='gamma',
668 noise_ampl=1e-3,
669 ds=1,
670 topk1=1,
671 topk2=5,
672 multicomp_correction='holm',
673 pval_thr=0.01,
674 verbose=True,
675 enable_parallelization=True,
676 n_jobs=-1,
677 seed=42,
678 duplicate_behavior='ignore'):
679 """
680 Compute pairwise functional correlations between neurons using INTENSE.
682 This function calculates pairwise similarity (e.g., mutual information) between
683 all neurons using the two-stage INTENSE approach. This can reveal functionally
684 correlated neurons that may form assemblies or functional modules.
686 Parameters
687 ----------
688 exp : Experiment
689 Experiment object containing neural data.
690 cell_bunch : int, list or None, optional
691 Neuron indices to analyze. Default: None (all neurons).
692 data_type : str, optional
693 Type of neural data: 'calcium' or 'spikes'. Default: 'calcium'.
694 metric : str, optional
695 Similarity metric to use. Default: 'mi' (mutual information).
696 mode : str, optional
697 Computation mode: 'two_stage', 'stage1', or 'stage2'. Default: 'two_stage'.
698 n_shuffles_stage1 : int, optional
699 Number of shuffles for stage 1. Default: 100.
700 n_shuffles_stage2 : int, optional
701 Number of shuffles for stage 2. Default: 1000.
702 metric_distr_type : str, optional
703 Distribution type for metric null distribution. Default: 'gamma'.
704 noise_ampl : float, optional
705 Small noise amplitude for numerical stability. Default: 1e-3.
706 ds : int, optional
707 Downsampling factor. Default: 1.
708 topk1 : int, optional
709 Top-k criterion for stage 1. Default: 1.
710 topk2 : int, optional
711 Top-k criterion for stage 2. Default: 5.
712 multicomp_correction : str or None, optional
713 Multiple comparison correction method. Default: 'holm'.
714 pval_thr : float, optional
715 P-value threshold for significance. Default: 0.01.
716 verbose : bool, optional
717 Whether to print progress information. Default: True.
718 enable_parallelization : bool, optional
719 Whether to use parallel processing. Default: True.
720 n_jobs : int, optional
721 Number of parallel jobs. -1 means use all processors. Default: -1.
722 seed : int, optional
723 Random seed for reproducibility. Default: 42.
724 duplicate_behavior : str, optional
725 How to handle duplicate TimeSeries in ts_bunch1 or ts_bunch2.
726 - 'ignore': Process duplicates normally (default)
727 - 'raise': Raise an error if duplicates are found
728 - 'warn': Print a warning but continue processing
729 Default: 'ignore'.
731 Returns
732 -------
733 similarity_matrix : ndarray
734 Matrix of similarity values between neurons. Element [i,j] contains
735 the similarity between neuron i and neuron j. Diagonal is zero.
736 significance_matrix : ndarray
737 Matrix of binary significance values. 1 indicates significant similarity.
738 p_value_matrix : ndarray
739 Matrix of p-values for each comparison.
740 cell_ids : list
741 List of cell IDs corresponding to matrix indices.
742 info : dict
743 Dictionary containing additional information from compute_me_stats.
745 Notes
746 -----
747 - Uses the two-stage INTENSE approach for efficient significance testing
748 - Diagonal elements are zero (self-similarity check prevents computation)
749 - For calcium imaging data, considers temporal dynamics
750 - For spike data, uses discrete MI formulation
751 - Can identify functional assemblies through graph analysis of significant pairs
752 - No optimal delay search is performed (synchronous activity assumed)
754 Examples
755 --------
756 >>> # Compute functional correlations between all neurons
757 >>> sim_mat, sig_mat, pval_mat, cells, info = compute_cell_cell_significance(exp)
758 >>>
759 >>> # Analyze only specific neurons
760 >>> sim_mat, sig_mat, pval_mat, cells, info = compute_cell_cell_significance(
761 ... exp,
762 ... cell_bunch=[0, 5, 10, 15, 20],
763 ... data_type='spikes'
764 ... )
765 """
766 import numpy as np
768 # Check downsampling
769 exp.check_ds(ds)
771 # Process cell bunch
772 cell_ids = exp._process_cbunch(cell_bunch)
773 n_cells = len(cell_ids)
774 cells = [exp.neurons[cell_id] for cell_id in cell_ids]
776 if verbose:
777 print(f"Computing neuronal similarity matrix for {n_cells} neurons...")
778 print(f"Data type: {data_type}")
780 # Get neural signals based on data type
781 if data_type == 'calcium':
782 signals = [cell.ca for cell in cells]
783 elif data_type == 'spikes':
784 signals = [cell.sp for cell in cells]
785 # Check if spike data exists and is non-degenerate
786 if any(sig is None for sig in signals):
787 raise ValueError("Some neurons have no spike data. Use reconstruct_spikes or provide spike data.")
788 # Check if all spike data is identical (e.g., all zeros)
789 if len(signals) > 1:
790 first_data = signals[0].data
791 if all(np.array_equal(sig.data, first_data) for sig in signals[1:]):
792 import warnings
793 warnings.warn("All neurons have identical spike data. This may lead to degenerate results.")
794 else:
795 raise ValueError('"data_type" can be either "calcium" or "spikes"')
797 # Create masks that exclude diagonal (self-comparisons) AND lower triangle
798 # This ensures we only compute the upper triangle for symmetric results
799 precomputed_mask_stage1 = np.triu(np.ones((n_cells, n_cells)), k=1)
800 precomputed_mask_stage2 = np.triu(np.ones((n_cells, n_cells)), k=1)
802 # Call compute_me_stats with neurons against themselves
803 # Note: optimal delays are disabled (set to False) for synchronous analysis
804 stats, significance, info = compute_me_stats(
805 signals,
806 signals,
807 names1=cell_ids,
808 names2=cell_ids,
809 metric=metric,
810 mode=mode,
811 precomputed_mask_stage1=precomputed_mask_stage1,
812 precomputed_mask_stage2=precomputed_mask_stage2,
813 n_shuffles_stage1=n_shuffles_stage1,
814 n_shuffles_stage2=n_shuffles_stage2,
815 joint_distr=False,
816 allow_mixed_dimensions=False, # Neurons are single time series
817 metric_distr_type=metric_distr_type,
818 noise_ampl=noise_ampl,
819 ds=ds,
820 topk1=topk1,
821 topk2=topk2,
822 multicomp_correction=multicomp_correction,
823 pval_thr=pval_thr,
824 find_optimal_delays=False, # Assume synchronous activity
825 shift_window=0, # No shift window needed
826 verbose=verbose,
827 enable_parallelization=enable_parallelization,
828 n_jobs=n_jobs,
829 seed=seed,
830 duplicate_behavior='ignore' # Default behavior for cell-cell comparison
831 )
833 # Extract matrices from results
834 similarity_matrix = np.zeros((n_cells, n_cells))
835 significance_matrix = np.zeros((n_cells, n_cells))
836 p_value_matrix = np.ones((n_cells, n_cells))
838 # Fill matrices from stats and significance dictionaries
839 # Since we only computed upper triangle, we need to fill both upper and lower
840 for i, cell1 in enumerate(cell_ids):
841 for j, cell2 in enumerate(cell_ids):
842 if i == j:
843 # Diagonal is already 0
844 continue
846 # We computed only upper triangle, so check if this pair was computed
847 if i < j:
848 # Upper triangle - get from stats
849 if cell1 in stats and cell2 in stats[cell1]:
850 stats_dict = stats[cell1][cell2]
851 if stats_dict: # Check if dict is not empty
852 similarity_matrix[i, j] = stats_dict.get('me', 0)
853 p_value_matrix[i, j] = stats_dict.get('p', 1)
855 sig_dict = significance.get(cell1, {}).get(cell2, {})
856 if sig_dict.get('stage2') is not None:
857 significance_matrix[i, j] = float(sig_dict['stage2'])
858 elif sig_dict.get('stage1') is not None:
859 significance_matrix[i, j] = float(sig_dict['stage1'])
860 else:
861 # Lower triangle - copy from upper triangle for symmetry
862 similarity_matrix[i, j] = similarity_matrix[j, i]
863 p_value_matrix[i, j] = p_value_matrix[j, i]
864 significance_matrix[i, j] = significance_matrix[j, i]
866 # Ensure diagonal is zero (should already be due to coincidence check)
867 np.fill_diagonal(similarity_matrix, 0)
868 np.fill_diagonal(significance_matrix, 0)
869 np.fill_diagonal(p_value_matrix, 1)
871 if verbose:
872 print(f"\nNeuronal similarity matrix computation complete!")
873 print(f"Neuron pairs analyzed: {n_cells * n_cells}")
874 print(f"Significant pairs (stage 1): {info.get('n_significant_stage1', 0)}")
875 print(f"Significant pairs (final): {np.sum(significance_matrix)}")
876 # Count unique significant pairs (upper triangle only)
877 unique_sig = np.sum(np.triu(significance_matrix, k=1))
878 print(f"Unique significant pairs: {unique_sig}")
880 # Basic network statistics
881 if unique_sig > 0:
882 avg_connections = np.sum(significance_matrix) / n_cells
883 print(f"Average connections per neuron: {avg_connections:.2f}")
884 max_connections = np.max(np.sum(significance_matrix, axis=1))
885 print(f"Maximum connections for a single neuron: {int(max_connections)}")
887 return similarity_matrix, significance_matrix, p_value_matrix, cell_ids, info
def _summarize_embedding_significance(significance, feature_to_component, n_components):
    """Group stage-2-significant neuron/component pairs by neuron and by component.

    Parameters
    ----------
    significance : dict
        Nested mapping ``significance[neuron_id][feat_name] -> dict`` as returned by
        ``compute_cell_feat_significance``. A pair counts as significant when its
        inner dict has a truthy ``'stage2'`` entry.
    feature_to_component : dict
        Maps each embedding feature name to its integer component index.
    n_components : int
        Number of embedding components; every index in ``range(n_components)``
        gets an entry (possibly an empty list) in ``component_selectivity``.

    Returns
    -------
    significant_neurons : dict
        ``neuron_id -> list of feature names`` the neuron is significantly selective to.
    component_selectivity : dict
        ``component index -> list of neuron ids`` selective to that component.
    """
    significant_neurons = {}
    component_selectivity = {comp_idx: [] for comp_idx in range(n_components)}
    for neuron_id, neuron_significance in significance.items():
        for feat_name, comp_idx in feature_to_component.items():
            sig_info = neuron_significance.get(feat_name)
            # Only stage-2 results are treated as significant.
            if sig_info and sig_info.get('stage2', False):
                significant_neurons.setdefault(neuron_id, []).append(feat_name)
                component_selectivity[comp_idx].append(neuron_id)
    return significant_neurons, component_selectivity


def compute_embedding_selectivity(exp,
                                  embedding_methods=None,
                                  cell_bunch=None,
                                  data_type='calcium',
                                  metric='mi',
                                  mode='two_stage',
                                  n_shuffles_stage1=100,
                                  n_shuffles_stage2=10000,
                                  metric_distr_type='norm',
                                  noise_ampl=1e-3,
                                  ds=1,
                                  use_precomputed_stats=True,
                                  save_computed_stats=True,
                                  force_update=False,
                                  topk1=1,
                                  topk2=5,
                                  multicomp_correction='holm',
                                  pval_thr=0.01,
                                  find_optimal_delays=True,
                                  shift_window=5,
                                  verbose=True,
                                  enable_parallelization=True,
                                  n_jobs=-1,
                                  seed=42):
    """
    Compute INTENSE selectivity between neurons and dimensionality reduction embeddings.

    Each embedding component is temporarily registered on ``exp`` as a continuous
    dynamic feature named ``"<method>_comp<idx>"`` and analysed with
    ``compute_cell_feat_significance``. All temporary modifications of ``exp``
    (dynamic features, attributes, data hashes) are undone afterwards, also when
    the analysis raises.

    Parameters
    ----------
    exp : Experiment
        Experiment object with stored embeddings
    embedding_methods : str, list or None
        Names of embedding methods to analyze. If None, analyzes all stored embeddings.
    cell_bunch : int, iterable or None
        Neuron indices. By default (None), all neurons will be taken
    data_type : str
        Data type used for embeddings and INTENSE ('calcium' or 'spikes')
    metric : str
        Similarity metric between TimeSeries (default: 'mi')
    mode : str
        Computation mode: 'stage1', 'stage2', or 'two_stage' (default).
        Note that a neuron is reported as significant only when its 'stage2'
        significance flag is truthy.
    n_shuffles_stage1 : int
        Number of shuffles for first stage (default: 100)
    n_shuffles_stage2 : int
        Number of shuffles for second stage (default: 10000)
    metric_distr_type : str
        Distribution type for shuffled metric distribution fit (default: 'norm')
    noise_ampl : float
        Small noise amplitude added to improve numerical fit (default: 1e-3)
    ds : int
        Downsampling constant (default: 1)
    use_precomputed_stats : bool
        Accepted for API symmetry but currently unused: the inner INTENSE call
        always runs with ``use_precomputed_stats=False`` because the embedding
        features are new.
    save_computed_stats : bool
        If True and ``exp.stats_tables`` has no entry for ``data_type``, the stats
        tables are initialized. The computed embedding stats themselves are never
        saved (the inner call runs with ``save_computed_stats=False``).
    force_update : bool
        Force update saved statistics if data hash collision found (default: False)
    topk1 : int
        True MI for stage 1 should be among topk1 MI shuffles (default: 1)
    topk2 : int
        True MI for stage 2 should be among topk2 MI shuffles (default: 5)
    multicomp_correction : str or None
        Multiple comparison correction type: None, 'bonferroni', or 'holm' (default)
    pval_thr : float
        P-value threshold (default: 0.01)
    find_optimal_delays : bool
        Find optimal temporal delays between neural activity and embeddings (default: True)
    shift_window : int
        Window for optimal shift search in seconds (default: 5)
    verbose : bool
        Print progress information (default: True)
    enable_parallelization : bool
        Enable parallel computation (default: True)
    n_jobs : int
        Number of parallel jobs, -1 for all cores (default: -1)
    seed : int
        Random seed (default: 42)

    Returns
    -------
    results : dict
        Dictionary with keys as embedding method names, each containing:
        - 'stats': Statistics for each neuron-component pair
        - 'significance': Significance results
        - 'info': Additional information from compute_me_stats
        - 'significant_neurons': Dict of neurons significantly selective to embedding components
        - 'n_components': Number of embedding components
        - 'component_selectivity': For each component, list of selective neurons
        - 'embedding_metadata': The embedding's 'metadata' entry ({} if absent)

    Raises
    ------
    ValueError
        If there are no embedding methods to analyze.
    """
    # Resolve the list of embedding methods to analyze
    if embedding_methods is None:
        embedding_methods = list(exp.embeddings[data_type].keys())
    elif isinstance(embedding_methods, str):
        embedding_methods = [embedding_methods]

    if not embedding_methods:
        raise ValueError(f"No embeddings found for data_type '{data_type}'. "
                         "Use exp.store_embedding() to add embeddings first.")

    # The neuron count only feeds the verbose summary and does not depend on the
    # embedding, so compute it once outside the loop.
    n_total_neurons = len(exp._process_cbunch(cell_bunch)) if verbose else None

    results = {}

    for method_name in embedding_methods:
        if verbose:
            print(f"\n{'='*60}")
            print(f"Computing selectivity for embedding: {method_name}")
            print(f"{'='*60}")

        embedding_dict = exp.get_embedding(method_name, data_type)
        embedding_data = embedding_dict['data']
        n_components = embedding_data.shape[1]

        # One continuous TimeSeries per embedding column. Keep the
        # name -> component index mapping explicit instead of re-parsing names.
        feature_to_component = {
            f"{method_name}_comp{comp_idx}": comp_idx for comp_idx in range(n_components)
        }
        embedding_features = {
            feat_name: TimeSeries(embedding_data[:, comp_idx], discrete=False)
            for feat_name, comp_idx in feature_to_component.items()
        }

        # Snapshot everything that is about to be modified so it can be restored
        # exactly. The dict is updated and restored in place so that any other
        # reference to the same dict object is restored too.
        features_dict = exp.dynamic_features
        original_features = features_dict.copy()
        missing = object()
        instance_attrs = vars(exp)
        shadowed_attrs = {name: instance_attrs.get(name, missing)
                          for name in embedding_features}

        try:
            # Temporarily expose the embedding components as dynamic features
            features_dict.update(embedding_features)
            for feat_name, feat_ts in embedding_features.items():
                setattr(exp, feat_name, feat_ts)

            # Rebuild data hashes to include the new features
            exp._build_data_hashes(mode=data_type)

            # NOTE(review): tables are initialized while the temporary features
            # are registered; whether they keep entries for them after cleanup
            # depends on _set_selectivity_tables — confirm.
            if save_computed_stats and data_type not in exp.stats_tables:
                exp._set_selectivity_tables(data_type)

            # Run INTENSE analysis
            stats, significance, info, _ = compute_cell_feat_significance(
                exp,
                cell_bunch=cell_bunch,
                feat_bunch=list(embedding_features.keys()),
                data_type=data_type,
                metric=metric,
                mode=mode,
                n_shuffles_stage1=n_shuffles_stage1,
                n_shuffles_stage2=n_shuffles_stage2,
                metric_distr_type=metric_distr_type,
                noise_ampl=noise_ampl,
                ds=ds,
                use_precomputed_stats=False,  # Must be False for new dynamic features
                save_computed_stats=False,  # Don't save stats for temporary embedding features
                force_update=force_update,
                topk1=topk1,
                topk2=topk2,
                multicomp_correction=multicomp_correction,
                pval_thr=pval_thr,
                find_optimal_delays=find_optimal_delays,
                shift_window=shift_window,
                verbose=verbose,
                enable_parallelization=enable_parallelization,
                n_jobs=n_jobs,
                seed=seed
            )

            # significance structure is significance[neuron_id][feat_name]
            significant_neurons, component_selectivity = _summarize_embedding_significance(
                significance, feature_to_component, n_components)

            results[method_name] = {
                'stats': stats,
                'significance': significance,
                'info': info,
                'significant_neurons': significant_neurons,
                'n_components': n_components,
                'component_selectivity': component_selectivity,
                'embedding_metadata': embedding_dict.get('metadata', {})
            }

            if verbose:
                n_sig_neurons = len(significant_neurons)
                # Guard against an empty neuron selection
                pct = 100 * n_sig_neurons / n_total_neurons if n_total_neurons else 0.0
                print(f"\nResults for {method_name}:")
                print(f" Embedding dimensions: {n_components}")
                print(f" Significant neurons: {n_sig_neurons}/{n_total_neurons} ({pct:.1f}%)")

                # Component-wise summary
                for comp_idx, selective in component_selectivity.items():
                    if selective:
                        print(f" Component {comp_idx}: {len(selective)} selective neurons")

        finally:
            # Restore the original dynamic features (same dict object, same content)
            features_dict.clear()
            features_dict.update(original_features)
            exp.dynamic_features = features_dict

            # Remove temporary attributes, restoring any attribute they shadowed
            for name, previous in shadowed_attrs.items():
                if previous is missing:
                    if name in vars(exp):
                        delattr(exp, name)
                else:
                    setattr(exp, name, previous)

            # Drop hashes of the temporary features
            exp._build_data_hashes(mode=data_type)

    return results