Coverage for src/driada/experiment/exp_base.py: 60.05%
368 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import numpy as np
2import warnings
3import tqdm
4from itertools import combinations
5import pickle
7from ..information.info_base import TimeSeries
8from ..information.info_base import MultiTimeSeries
9from .neuron import DEFAULT_MIN_BEHAVIOUR_TIME, Neuron
10from .wavelet_event_detection import WVT_EVENT_DETECTION_PARAMS, extract_wvt_events, events_to_ts_array, ridges_to_containers
11from ..utils.data import get_hash, populate_nested_dict
12from ..information.info_base import get_1d_mi
# Per neuron-feature pair statistics collected by selectivity analysis.
STATS_VARS = ['data_hash', 'opt_delay', 'pre_pval', 'pre_rval', 'pval', 'rval', 'me', 'rel_me_beh', 'rel_me_ca']
# Per neuron-feature pair significance-testing bookkeeping.
SIGNIFICANCE_VARS = ['stage1', 'shuffles1', 'stage2', 'shuffles2', 'final_p_thr', 'multicomp_corr', 'pairwise_pval_thr']
# Templates with every field set to None; callers copy them per pair.
DEFAULT_STATS = dict.fromkeys(STATS_VARS)
DEFAULT_SIGNIFICANCE = dict.fromkeys(SIGNIFICANCE_VARS)
def check_dynamic_features(dynamic_features):
    '''
    Verify that all dynamic features have the same number of time points.

    Parameters
    ----------
    dynamic_features : dict
        Feature id -> data. The length is taken from:
        numpy arrays (last axis if 2D+, ``len`` if 1D), ``TimeSeries``
        (length of ``.data``), objects exposing a 2D ``.data`` array such as
        ``MultiTimeSeries`` (second axis, components x time), or ``len`` of
        any other sized sequence.

    Raises
    ------
    ValueError
        If the features do not all share one length. An empty dict is
        accepted: there is nothing to compare.
    '''
    dfeat_lengths = {}
    for feat_id, current_ts in dynamic_features.items():
        # ndarray is tested first: it is the most common raw input and needs
        # no project types for the check.
        if isinstance(current_ts, np.ndarray):
            len_ts = current_ts.shape[-1] if current_ts.ndim > 1 else len(current_ts)
        elif isinstance(current_ts, TimeSeries):
            len_ts = len(current_ts.data)
        elif hasattr(current_ts, 'data') and hasattr(current_ts.data, 'shape'):
            # MultiTimeSeries or similar: data is (n_components, n_timepoints)
            len_ts = current_ts.data.shape[1]
        else:
            len_ts = len(current_ts)
        dfeat_lengths[feat_id] = len_ts

    # More than one distinct length is an error; zero features is fine.
    if len(set(dfeat_lengths.values())) > 1:
        raise ValueError(f'Dynamic features have different lengths! {dfeat_lengths}')
class Experiment():
    '''
    Container for a single Ca2+ imaging experiment.

    Stores neuronal activity (calcium traces and, optionally, spike trains),
    behavioural time series ("dynamic features"), static metadata, the
    bookkeeping tables used for selectivity data from INTENSE and stored
    dimensionality-reduction embeddings.

    Attributes
    ----------
    signature : str
        Experiment name used in messages.
    exp_identificators : dict
        Experiment identifiers; every key is also set as an attribute.
    n_cells, n_frames : int
        Number of neurons and of time frames (after optional trimming).
    neurons : list of Neuron
        One ``Neuron`` per row of the calcium array.
    calcium, spikes : MultiTimeSeries
        Built from the neurons' ``ca`` / ``sp`` series (an all-zero discrete
        series stands in for neurons without spikes).
    dynamic_features : dict
        Feature id (str, or tuple for multifeatures) -> TimeSeries /
        MultiTimeSeries. String-keyed features are also set as attributes.
    filtered_flag : bool
        True once bad frames have been removed by ``_trim_data``.
    stats_tables, significance_tables : dict
        ``{mode: {feat_id: {cell_id: dict}}}`` with mode 'calcium'/'spikes'.
    embeddings : dict
        ``{'calcium': {...}, 'spikes': {...}}`` of stored DR embeddings.

    Every key of ``static_features`` (e.g. ``fps``) is set as an attribute.
    '''

    def __init__(self, signature, calcium, spikes, exp_identificators,
                 static_features, dynamic_features, **kwargs):
        '''
        Parameters
        ----------
        signature : str
            Experiment name.
        calcium : np.ndarray
            Calcium traces indexed as (n_cells, n_frames).
        spikes : np.ndarray or None
            Spike trains with the same layout; overridden when spike
            reconstruction is enabled.
        exp_identificators : dict
            Identifiers stored in ``self.exp_identificators`` and as attributes.
        static_features : dict
            Static metadata set as attributes; ``fps``, ``t_rise_sec`` and
            ``t_off_sec`` are also forwarded to every ``Neuron``.
        dynamic_features : dict
            Feature id -> data. 1D arrays become ``TimeSeries``, 2D arrays
            (components x frames) become ``MultiTimeSeries``. The caller's
            dict is not modified.
        **kwargs
            fit_individual_t_off (bool, default False);
            reconstruct_spikes (str, callable or None, default 'wavelet');
            bad_frames_mask (boolean mask over frames, True = drop, default None);
            spike_kwargs (dict or None) passed to spike reconstruction.

        Raises
        ------
        AttributeError
            If ``calcium`` is None.
        ValueError
            If dynamic features have inconsistent lengths or unsupported
            dimensionality.
        '''
        fit_individual_t_off = kwargs.get('fit_individual_t_off', False)
        reconstruct_spikes = kwargs.get('reconstruct_spikes', 'wavelet')
        bad_frames_mask = kwargs.get('bad_frames_mask', None)
        spike_kwargs = kwargs.get('spike_kwargs', None)

        # Fail fast: nothing below is meaningful without calcium data.
        if calcium is None:
            raise AttributeError('No calcium data provided')

        check_dynamic_features(dynamic_features)
        self.exp_identificators = exp_identificators
        self.signature = signature

        for idx in exp_identificators:
            setattr(self, idx, exp_identificators[idx])

        if reconstruct_spikes is None:
            if spikes is None:
                warnings.warn('No spike data provided, spikes reconstruction from Ca2+ data disabled')
        else:
            if spikes is not None:
                warnings.warn(f'Spike data will be overridden by reconstructed spikes from Ca2+ data with method={reconstruct_spikes}')

            # Store the reconstruction method for potential future use
            self.spike_reconstruction_method = reconstruct_spikes
            spikes = self._reconstruct_spikes(calcium, reconstruct_spikes, static_features.get('fps'), spike_kwargs)

        # Work on a new dict so the caller's mapping is left untouched.
        dynamic_features = {feat_id: self._to_feature_ts(feat_id, feat_data)
                            for feat_id, feat_data in dynamic_features.items()}

        self.filtered_flag = False
        if bad_frames_mask is not None:
            calcium, spikes, dynamic_features = self._trim_data(calcium,
                                                                spikes,
                                                                dynamic_features,
                                                                bad_frames_mask)

        self.n_cells = calcium.shape[0]
        self.n_frames = calcium.shape[1]

        # Store raw calcium and spikes arrays temporarily
        self._calcium_raw = calcium
        self._spikes_raw = spikes if spikes is not None else np.zeros(calcium.shape)

        self.neurons = []

        print('Building neurons...')
        for i in tqdm.tqdm(np.arange(self.n_cells), position=0, leave=True):
            cell = Neuron(str(i),
                          self._calcium_raw[i, :],
                          self._spikes_raw[i, :],
                          default_t_rise=static_features.get('t_rise_sec'),
                          default_t_off=static_features.get('t_off_sec'),
                          fps=static_features.get('fps'),
                          fit_individual_t_off=fit_individual_t_off)
            self.neurons.append(cell)

        # Build MultiTimeSeries from the neurons' own TimeSeries so that the
        # per-neuron shuffle masks created by each Neuron are preserved.
        calcium_ts_list = [neuron.ca for neuron in self.neurons]
        spikes_ts_list = [neuron.sp if neuron.sp is not None else TimeSeries(np.zeros(self.n_frames), discrete=True)
                          for neuron in self.neurons]
        self.calcium = MultiTimeSeries(calcium_ts_list)
        self.spikes = MultiTimeSeries(spikes_ts_list)

        self.dynamic_features = dynamic_features
        for feat_id, feat_ts in dynamic_features.items():
            # Tuple ids (multifeatures) cannot be attribute names.
            if isinstance(feat_id, str):
                setattr(self, feat_id, feat_ts)

        for sfeat_name in static_features:
            setattr(self, sfeat_name, static_features[sfeat_name])

        # for selectivity data from INTENSE
        self.stats_tables = {}
        self.significance_tables = {}
        self.selectivity_tables_initialized = False

        # for dimensionality reduction embeddings
        self.embeddings = {'calcium': {}, 'spikes': {}}

        print('Building data hashes...')
        self._build_data_hashes(mode='calcium')
        if reconstruct_spikes is not None or spikes is not None:
            self._build_data_hashes(mode='spikes')

        print('Final checkpoint...')
        self._checkpoint()

        print(f'Experiment "{self.signature}" constructed successfully with {self.n_cells} neurons and {len(self.dynamic_features)} features')

    @staticmethod
    def _to_feature_ts(feat_id, feat_data):
        '''
        Convert raw feature data to TimeSeries / MultiTimeSeries.

        TimeSeries and MultiTimeSeries are returned unchanged. A 1D ndarray
        becomes a TimeSeries (discreteness auto-detected by TimeSeries). A 2D
        ndarray becomes a MultiTimeSeries of continuous components, one per
        row. Any other object is passed to TimeSeries as 1D data.

        Raises
        ------
        ValueError
            For ndarrays with more than 2 dimensions.
        '''
        if isinstance(feat_data, (TimeSeries, MultiTimeSeries)):
            return feat_data
        if isinstance(feat_data, np.ndarray):
            if feat_data.ndim == 1:
                return TimeSeries(feat_data)
            if feat_data.ndim == 2:
                return MultiTimeSeries([TimeSeries(feat_data[i, :], discrete=False)
                                        for i in range(feat_data.shape[0])])
            raise ValueError(f"Feature {feat_id} has unsupported dimensionality: {feat_data.ndim}D")
        return TimeSeries(feat_data)

    @staticmethod
    def _normalize_feat_id(feat_id):
        '''
        Canonical form of a feature id used as a table key.

        A string is returned as is. A one-element iterable collapses to its
        only element. Any other iterable becomes a sorted tuple (the key form
        used for multifeatures in all tables).
        '''
        if isinstance(feat_id, str):
            return feat_id
        feat_id = tuple(feat_id)
        if len(feat_id) == 1:
            return feat_id[0]
        return tuple(sorted(feat_id))

    @staticmethod
    def _shuffled_to_array(sh_data):
        '''
        Return shuffled data as an array, whether a neuron returned an ndarray
        or a TimeSeries-like object with ``.data``.
        '''
        # ndarray also has a ``.data`` attribute (its raw buffer), so it must
        # be recognised before looking at ``.data``.
        if isinstance(sh_data, np.ndarray):
            return sh_data
        return np.asarray(getattr(sh_data, 'data', sh_data))

    def check_ds(self, ds):
        '''
        Warn (via print) if downsampling by ``ds`` frames creates time gaps
        longer than DEFAULT_MIN_BEHAVIOUR_TIME.

        Raises
        ------
        ValueError
            If ``fps`` is missing or None.
        '''
        fps = getattr(self, 'fps', None)
        if fps is None:
            raise ValueError(f'fps not set for {self.signature}')

        time_step = 1.0/fps
        if time_step*ds > DEFAULT_MIN_BEHAVIOUR_TIME:
            print('Downsampling constant is too high: some behaviour acts may be skipped. '
                  f'Current minimal behaviour time interval is set to {DEFAULT_MIN_BEHAVIOUR_TIME} sec, '
                  f'downsampling {ds} will create time gaps of {time_step*ds} sec')

    def _set_selectivity_tables(self, mode, fbunch=None, cbunch=None):
        '''
        (Re)initialise the stats and significance tables for ``mode`` with
        default (None) entries for every requested feature-cell pair.
        '''
        # neuron-feature pair statistics
        stats_table = self._populate_cell_feat_dict(DEFAULT_STATS, fbunch=fbunch, cbunch=cbunch, mode=mode)

        # neuron-feature pair significance-related data
        significance_table = self._populate_cell_feat_dict(DEFAULT_SIGNIFICANCE,
                                                           fbunch=fbunch,
                                                           cbunch=cbunch,
                                                           mode=mode)
        self.stats_tables[mode] = stats_table
        self.significance_tables[mode] = significance_table
        self.selectivity_tables_initialized = True

    def _build_pair_hash(self, cell_id, feat_id, mode='calcium'):
        '''
        Build a hash-based fingerprint of an activity-feature pair.

        feat_id is a string or an iterable of strings (joint MI calculation).
        The result is ``(activity_hash, feature_hash, ...)``; multifeature
        components are hashed in sorted name order. If the components are
        not individual features but the multifeature itself is a key of
        ``dynamic_features``, that series' data is hashed instead.

        Raises
        ------
        ValueError
            For an unknown ``mode``.
        KeyError
            If the feature cannot be found.
        '''
        if mode == 'calcium':
            act = self.neurons[cell_id].ca.data
        elif mode == 'spikes':
            act = self.neurons[cell_id].sp.data
        else:
            raise ValueError('"mode" can be either "calcium" or "spikes"')

        act_hash = get_hash(act)

        if (not isinstance(feat_id, str)) and len(feat_id) == 1:
            feat_id = feat_id[0]

        if isinstance(feat_id, str):
            return (act_hash, get_hash(self.dynamic_features[feat_id].data))

        ordered_fnames = tuple(sorted(list(feat_id)))
        if all(fname in self.dynamic_features for fname in ordered_fnames):
            list_of_hashes = [act_hash]
            for fname in ordered_fnames:
                list_of_hashes.append(get_hash(self.dynamic_features[fname].data))
            return tuple(list_of_hashes)

        # Multifeature stored directly under a tuple key (see _checkpoint).
        for key in (tuple(feat_id), ordered_fnames):
            if key in self.dynamic_features:
                return (act_hash, get_hash(self.dynamic_features[key].data))

        missing = [fname for fname in ordered_fnames if fname not in self.dynamic_features]
        raise KeyError(f'Unknown feature(s) {missing} in multifeature {feat_id}')

    def _build_data_hashes(self, mode='calcium'):
        '''
        Build hash fingerprints of all activity-feature pairs for ``mode``.

        ``self._data_hashes`` maps mode -> feature -> cell -> hash. Each mode
        has its own table; rebuilding one mode leaves the other untouched
        (previously both modes shared one dict, so building spike hashes
        silently replaced the calcium ones).
        '''
        if mode not in ('calcium', 'spikes'):
            raise ValueError('"mode" can be either "calcium" or "spikes"')

        def _empty_table():
            return {dfeat: dict.fromkeys(range(self.n_cells)) for dfeat in self.dynamic_features}

        if getattr(self, '_data_hashes', None) is None:
            self._data_hashes = {'calcium': _empty_table(), 'spikes': _empty_table()}

        table = _empty_table()
        for feat_id in self.dynamic_features:
            for cell_id in range(self.n_cells):
                table[feat_id][cell_id] = self._build_pair_hash(cell_id, feat_id, mode=mode)
        self._data_hashes[mode] = table

    def _trim_data(self, calcium, spikes, dynamic_features, bad_frames_mask, force_filter=False):
        '''
        Drop frames flagged in ``bad_frames_mask`` from all data.

        Parameters
        ----------
        calcium, spikes : np.ndarray
            (n_cells, n_frames) arrays; ``spikes`` may be None.
        dynamic_features : dict
            Feature id -> raw data, TimeSeries or MultiTimeSeries.
        bad_frames_mask : array-like
            One entry per frame; truthy entries are removed. Converted to a
            boolean array, so 0/1 integer masks work as masks rather than
            as fancy indices.
        force_filter : bool
            Allow filtering data that has already been filtered.

        Returns
        -------
        tuple
            (filtered calcium, filtered spikes or None, filtered features dict)

        Raises
        ------
        AttributeError
            If data is already filtered and ``force_filter`` is False.
        '''
        if not force_filter and getattr(self, 'filtered_flag', False):
            raise AttributeError('Data is already filtered, if you want to force filtering it again, set "force_filter = True"')

        bad_frames_mask = np.asarray(bad_frames_mask, dtype=bool)
        keep = ~bad_frames_mask

        f_calcium = calcium[:, keep]
        f_spikes = spikes[:, keep] if spikes is not None else None

        f_dynamic_features = {}
        for feat_id, feat_data in dynamic_features.items():
            current_ts = self._to_feature_ts(feat_id, feat_data)
            if isinstance(current_ts, MultiTimeSeries):
                # data is (n_components, n_frames): trim along the time axis.
                # NOTE(review): components are rebuilt as continuous series,
                # matching how 2D arrays are converted; discreteness of the
                # original components is not carried over.
                comps = np.asarray(current_ts.data)[:, keep]
                f_ts = MultiTimeSeries([TimeSeries(comp, discrete=False) for comp in comps])
            else:
                f_ts = TimeSeries(current_ts.data[keep], discrete=current_ts.discrete)
            f_dynamic_features[feat_id] = f_ts

        self.filtered_flag = True
        self.bad_frames_mask = bad_frames_mask

        return f_calcium, f_spikes, f_dynamic_features

    def _checkpoint(self):
        '''
        Check the built experiment for common errors.

        Raises
        ------
        UserWarning
            If there are more cells than frames (data looks transposed).
        ValueError
            If activity or a dynamic feature has no axis of length n_frames.
        '''
        if self.n_cells > self.n_frames:
            raise UserWarning('Number of cells > number of time frames, looks like the data is transposed')

        for dfeat in ['calcium', 'spikes']:
            shape = getattr(self, dfeat).data.shape
            if self.n_frames not in shape:
                raise ValueError(f'"{dfeat}" feature has inappropriate shape: {shape} '
                                 f'inconsistent with data length {self.n_frames}')

        # Read features from the dict, not via attributes: a static feature
        # with the same name may have overwritten the attribute.
        for dfeat, feat_data in self.dynamic_features.items():
            if hasattr(feat_data, 'data') and self.n_frames not in feat_data.data.shape:
                raise ValueError(f'"{dfeat}" feature has inappropriate shape: {feat_data.data.shape} '
                                 f'inconsistent with data length {self.n_frames}')

    def _populate_cell_feat_dict(self, content, fbunch=None, cbunch=None, mode='calcium'):
        '''
        Helper function. Creates a nested dictionary of feature-cell pairs and
        populates every cell with the ``content`` variable.
        Outer dict: dynamic features, inner dict: cells.
        ``mode`` selects which stats table defines the default feature set.
        '''
        cell_ids = self._process_cbunch(cbunch)
        feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True, mode=mode)
        return populate_nested_dict(content, feat_ids, cell_ids)

    def _process_cbunch(self, cbunch):
        '''
        Helper function. Turns cell indices (int, iterable or None) into a
        list of cell numbers. None means all cells.
        '''
        # numpy integer scalars (e.g. from np.arange) are single indices too.
        if isinstance(cbunch, (int, np.integer)):
            return [cbunch]
        if cbunch is None:
            return list(np.arange(self.n_cells))
        return list(cbunch)

    def _process_fbunch(self, fbunch, allow_multifeatures=False, mode='calcium'):
        '''
        Helper function. Turns feature names (str, iterable or None) into a
        list of feature ids.

        None gives the current stats table keys for ``mode`` (if any and
        multifeatures are allowed), else all dynamic features. In an
        iterable, unknown single names are skipped and non-string items are
        treated as multifeatures (sorted tuples).
        '''
        if isinstance(fbunch, str):
            return [fbunch]

        if fbunch is None:  # default set of features
            if allow_multifeatures:
                try:
                    # stats table contains up-to-date set of features, including multifeatures
                    return list(self.stats_tables[mode].keys())
                except KeyError:
                    # if stats is not available, take pre-defined full set of features
                    pass
            return list(self.dynamic_features.keys())

        feat_ids = []
        for fname in fbunch:
            if isinstance(fname, str):
                if fname in self.dynamic_features:
                    feat_ids.append(fname)
            elif allow_multifeatures:
                feat_ids.append(tuple(sorted(list(fname))))
            else:
                raise ValueError('Multifeature detected in "allow_multifeatures=False" mode')

        return feat_ids

    def _process_sbunch(self, sbunch, significance_mode=False):
        '''
        Helper function. Turns stats type names (str, iterable or None) into
        a list of stats types. Names not in the relevant variable list are
        dropped from iterables; a single string is returned unchecked.
        '''
        default_list = SIGNIFICANCE_VARS if significance_mode else STATS_VARS

        if isinstance(sbunch, str):
            return [sbunch]
        if sbunch is None:
            # Return a copy so callers cannot mutate the module-level list.
            return list(default_list)
        return [st for st in sbunch if st in default_list]

    def _add_multifeature_to_data_hashes(self, feat_id, mode='calcium'):
        '''
        Add previously unseen multifeature (e.g. ['x','y']) to the data hash
        table for ``mode``. Multifeatures already in the table are ignored.

        Raises
        ------
        ValueError
            If ``feat_id`` is a single feature.
        '''
        feat_id = self._normalize_feat_id(feat_id)
        if isinstance(feat_id, str):
            raise ValueError('This method is for multifeature update only')

        if feat_id not in self._data_hashes[mode]:
            # Hashes must come from the requested mode's activity.
            self._data_hashes[mode][feat_id] = {cell_id: self._build_pair_hash(cell_id, feat_id, mode=mode)
                                                for cell_id in range(self.n_cells)}

    def _add_multifeature_to_stats(self, feat_id, mode='calcium'):
        '''
        Add previously unseen multifeature (e.g. ['x','y']) to statistics and
        significance tables. Multifeatures already in the table are ignored.

        Raises
        ------
        ValueError
            If ``feat_id`` is a single feature.
        '''
        feat_id = self._normalize_feat_id(feat_id)
        if isinstance(feat_id, str):
            raise ValueError('This method is for multifeature update only')

        if feat_id not in self.stats_tables[mode]:
            print(f'Multifeature {feat_id} is new, it will be added to stats table')
            self.stats_tables[mode][feat_id] = {cell_id: DEFAULT_STATS.copy()
                                                for cell_id in range(self.n_cells)}
            self.significance_tables[mode][feat_id] = {cell_id: DEFAULT_SIGNIFICANCE.copy()
                                                       for cell_id in range(self.n_cells)}

    def _check_stats_relevance(self, cell_id, feat_id, mode='calcium'):
        '''
        Guard against reading missing or stale statistics.

        Returns True if no stats have been stored yet for the pair, or if the
        stored data hash matches the current data hash; otherwise prints a
        notice and returns False. Refers to the stats table only, since
        stats and significance tables are always updated together.

        Raises
        ------
        ValueError
            If the feature is not present in the stats table.
        '''
        feat_id = self._normalize_feat_id(feat_id)

        if feat_id not in self.stats_tables[mode]:
            raise ValueError(f'Feature {feat_id} is not present in stats. \n If this is a single feature, '
                             'check the input data, since all single features are processed automatically. '
                             'If this is a multifeature (e.g. ["x", "y"]), compute MI significance to create stats')

        pair_hash = self._data_hashes[mode][feat_id][cell_id]
        existing_hash = self.stats_tables[mode][feat_id][cell_id]['data_hash']

        # (stats does not exist yet) or (stats exists and data is the same)
        if existing_hash is None or pair_hash == existing_hash:
            return True

        print(f'Looks like the data for the pair (cell {cell_id}, feature {feat_id}) '
              'has been changed since the last calculation')
        return False

    def _update_stats_and_significance(self, stats, mode, cell_id, feat_id, stage2_only):
        '''
        Update the stats table and invalidate the linked significance data:
        all of it after a stage-1 update, only stage-2 entries otherwise.
        '''
        feat_id = self._normalize_feat_id(feat_id)
        self.stats_tables[mode][feat_id][cell_id].update(stats)
        if not stage2_only:
            # erase significance data completely since stats for stage 1 has been modified
            self.significance_tables[mode][feat_id][cell_id].update(DEFAULT_SIGNIFICANCE.copy())
        else:
            # erase significance data for stage 2 since stats for stage 2 has been modified
            self.significance_tables[mode][feat_id][cell_id].update({'stage2': None, 'shuffles2': None})

    def update_neuron_feature_pair_stats(self, stats, cell_id, feat_id, mode='calcium', force_update=False, stage2_only=False):
        '''
        Update activity-feature pair statistics.

        feat_id is a string or an iterable of strings (joint MI calculation);
        new multifeatures are registered automatically. Stale data is only
        overwritten when ``force_update`` is True.
        '''
        feat_id = self._normalize_feat_id(feat_id)
        if not isinstance(feat_id, str):
            self._add_multifeature_to_data_hashes(feat_id, mode=mode)
            self._add_multifeature_to_stats(feat_id, mode=mode)

        if self._check_stats_relevance(cell_id, feat_id, mode=mode) or force_update:
            self._update_stats_and_significance(stats, mode, cell_id, feat_id, stage2_only=stage2_only)
        else:
            print('To forcefully update the stats, set "force_update=True"')

    def update_neuron_feature_pair_significance(self, sig, cell_id, feat_id, mode='calcium'):
        '''
        Update activity-feature pair significance data.

        feat_id is a string or an iterable of strings (joint MI calculation).

        Raises
        ------
        ValueError
            If the stored stats are stale (data hashes differ).
        '''
        feat_id = self._normalize_feat_id(feat_id)
        if not isinstance(feat_id, str):
            self._add_multifeature_to_data_hashes(feat_id, mode=mode)
            self._add_multifeature_to_stats(feat_id, mode=mode)

        if not self._check_stats_relevance(cell_id, feat_id, mode=mode):
            raise ValueError('Can not update significance table until the collision between actual data hashes and '
                             'saved stats data hashes is resolved. Use update_neuron_feature_pair_stats '
                             'with "force_update=True" to forcefully rewrite statistics')
        self.significance_tables[mode][feat_id][cell_id].update(sig)

    def get_neuron_feature_pair_stats(self, cell_id, feat_id, mode='calcium'):
        '''
        Return activity-feature pair statistics for ``mode``, or None if the
        stored stats are stale. Multifeatures are allowed.
        '''
        feat_id = self._normalize_feat_id(feat_id)
        if self._check_stats_relevance(cell_id, feat_id, mode=mode):
            return self.stats_tables[mode][feat_id][cell_id]
        print('Consider recalculating stats')
        return None

    def get_neuron_feature_pair_significance(self, cell_id, feat_id, mode='calcium'):
        '''
        Return activity-feature pair significance data for ``mode``, or None
        if the stored stats are stale. Multifeatures are allowed.
        '''
        feat_id = self._normalize_feat_id(feat_id)
        if self._check_stats_relevance(cell_id, feat_id, mode=mode):
            return self.significance_tables[mode][feat_id][cell_id]
        print('Consider recalculating stats')
        return None

    def get_multicell_shuffled_calcium(self, cbunch=None, method='roll_based', no_ts=True, **kwargs):
        '''
        Shuffle calcium traces of several cells.

        Args:
            cbunch: cell index, iterable of indices or None (all cells)
            method: shuffling method passed to ``Neuron.get_shuffled_calcium``
            no_ts: if True, intermediate TimeSeries objects are not created,
                which speeds up shuffling
            **kwargs: extra arguments for the neuron shuffling method

        Returns:
            np.ndarray of shape (len(cells), n_frames); row k holds the k-th
            requested cell.
        '''
        cell_list = self._process_cbunch(cbunch)
        agg_sh_data = np.zeros((len(cell_list), self.n_frames))
        # Rows are positions in cell_list, not cell ids (ids may exceed the row count).
        for row, cell_id in enumerate(cell_list):
            sh_data = self.neurons[cell_id].get_shuffled_calcium(method=method, **kwargs, no_ts=no_ts)
            agg_sh_data[row, :] = self._shuffled_to_array(sh_data)

        return agg_sh_data

    def get_multicell_shuffled_spikes(self, cbunch=None, method='isi_based', no_ts=True, **kwargs):
        '''
        Shuffle spike trains of several cells.

        Args:
            cbunch: cell index, iterable of indices or None (all cells)
            method: shuffling method passed to ``Neuron.get_shuffled_spikes``
            no_ts: kept for API symmetry; the neuron's result is converted
                to an array whether it is an ndarray or a TimeSeries
            **kwargs: extra arguments for the neuron shuffling method

        Returns:
            np.ndarray of shape (len(cells), n_frames); row k holds the k-th
            requested cell.

        Raises:
            AttributeError: if all spike data is zero.
        '''
        # Check if spikes data is meaningful (not all zeros)
        if np.allclose(self.spikes.data, 0):
            raise AttributeError('Unable to shuffle spikes without meaningful spikes data')

        cell_list = self._process_cbunch(cbunch)
        agg_sh_data = np.zeros((len(cell_list), self.n_frames))
        for row, cell_id in enumerate(cell_list):
            sh_data = self.neurons[cell_id].get_shuffled_spikes(method=method, **kwargs)
            agg_sh_data[row, :] = self._shuffled_to_array(sh_data)

        return agg_sh_data

    def get_stats_slice(self,
                        table_to_scan=None,
                        cbunch=None,
                        fbunch=None,
                        sbunch=None,
                        significance_mode=False,
                        mode='calcium'):
        '''
        Return a slice of accumulated statistics data (or significance data
        if ``significance_mode=True``) as {feat_id: {cell_id: {stat: value}}}.
        ``table_to_scan`` overrides the table stored for ``mode``.
        '''
        cell_ids = self._process_cbunch(cbunch)
        feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True, mode=mode)
        slist = self._process_sbunch(sbunch, significance_mode=significance_mode)

        if table_to_scan is not None:
            full_table = table_to_scan
        elif significance_mode:
            full_table = self.significance_tables[mode]
        else:
            full_table = self.stats_tables[mode]

        # Output keys follow the same feature/cell sets (for this mode) that are read.
        out_table = {}
        for feat_id in feat_ids:
            out_table[feat_id] = {}
            for cell_id in cell_ids:
                out_table[feat_id][cell_id] = {s: full_table[feat_id][cell_id][s] for s in slist}

        return out_table

    def get_significance_slice(self, cbunch=None, fbunch=None, sbunch=None, mode='calcium'):
        '''Shortcut for ``get_stats_slice(..., significance_mode=True)``.'''
        return self.get_stats_slice(cbunch=cbunch,
                                    fbunch=fbunch,
                                    sbunch=sbunch,
                                    significance_mode=True,
                                    mode=mode)

    def get_feature_entropy(self, feat_id, ds=1):
        '''
        Entropy of a single dynamic feature or a multifeature (e.g. ['x','y']).

        A multifeature's entropy is computed as sum of single entropies minus
        the sum of pairwise MIs. This is correct for 2 variables; for 3 and
        more the estimate is distorted (correct estimation of multivariate
        entropy for non-gaussian variables is non-trivial).
        '''
        if isinstance(feat_id, str):
            return self.dynamic_features[feat_id].get_entropy(ds=ds)

        ordered_fnames = tuple(sorted(list(feat_id)))
        tslist = [self.dynamic_features[dfeat] for dfeat in ordered_fnames]
        single_entropies = [fts.get_entropy(ds=ds) for fts in tslist]
        pairwise_mis = [get_1d_mi(ts1, ts2, ds=ds) for ts1, ts2 in combinations(tslist, 2)]
        return sum(single_entropies) - sum(pairwise_mis)

    def _reconstruct_spikes(self, calcium, method, fps, spike_kwargs=None):
        """
        Reconstruct spikes from calcium signals using the specified method.

        Parameters
        ----------
        calcium : np.ndarray or MultiTimeSeries
            Calcium traces, shape (n_neurons, n_timepoints) for arrays
        method : str or callable
            Reconstruction method: 'wavelet' or a callable function
        fps : float
            Sampling rate in frames per second
        spike_kwargs : dict, optional
            Method-specific parameters

        Returns
        -------
        spikes : np.ndarray
            Reconstructed spike trains (``.data`` of the returned series);
            reconstruction metadata is stored in ``self._reconstruction_metadata``.
        """
        from .spike_reconstruction import reconstruct_spikes

        # ndarray exposes a ``.data`` buffer attribute, so hasattr(calcium, 'data')
        # cannot tell arrays from MultiTimeSeries; check the type explicitly.
        if isinstance(calcium, np.ndarray):
            ts_list = [TimeSeries(calcium[i, :], discrete=False) for i in range(calcium.shape[0])]
            calcium_mts = MultiTimeSeries(ts_list)
        else:
            calcium_mts = calcium

        spikes_mts, metadata = reconstruct_spikes(
            calcium_mts,
            method=method,
            fps=fps,
            params=spike_kwargs
        )

        self._reconstruction_metadata = metadata

        # Return numpy array for backward compatibility
        return spikes_mts.data

    def get_significant_neurons(self, min_nspec=1, cbunch=None, fbunch=None, mode='calcium'):
        '''
        Returns a dict with neuron ids as keys and their significantly
        correlated features as values. Only neurons with ``min_nspec`` or
        more significantly correlated features are returned.

        Parameters
        ----------
        min_nspec : int
            Minimum number of significantly correlated features required
        cbunch : int, list or None
            Cell indices to analyze. By default (None), all neurons are checked
        fbunch : str, list or None
            Feature names to check. By default (None), all features are checked
        mode : str
            Data type: 'calcium' or 'spikes'

        Returns
        -------
        dict
            Neuron IDs -> lists of features whose 'stage2' significance is truthy

        Raises
        ------
        ValueError
            If stats for any requested pair are stale.
        '''
        cell_ids = self._process_cbunch(cbunch)
        feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True, mode=mode)

        # Check relevance only for requested cells and features
        relevance = [self._check_stats_relevance(cell_id, feat_id, mode=mode)
                     for cell_id in cell_ids for feat_id in feat_ids]
        if not np.all(np.array(relevance)):
            raise ValueError('Stats relevance error')

        # TODO: add significance update and pval_thr argument
        cell_feat_dict = {cell_id: [] for cell_id in cell_ids}
        for cell_id in cell_ids:
            for feat_id in feat_ids:
                if self.significance_tables[mode][feat_id][cell_id]['stage2']:
                    cell_feat_dict[cell_id].append(feat_id)

        # filter out cells without enough specializations
        return {cell_id: cell_feat_dict[cell_id]
                for cell_id in cell_ids
                if len(cell_feat_dict[cell_id]) >= min_nspec}

    #===================================================================================
    # not active: these methods use the legacy ``mi_significance_table``
    # attribute, which __init__ does not create.

    def save_mi_significance_to_file(self, fname):
        with open(fname, 'wb') as f:
            pickle.dump(self.mi_significance_table, f)

    def clear_cells_mi_significance_data(self, cbunch, path_to_save = None):
        for cell_id in cbunch:
            for feat in self.dynamic_features:
                self.mi_significance_table[feat][cell_id] = {}

        if path_to_save is not None:
            self.save_mi_significance_to_file(path_to_save)

    def clear_features_mi_significance_data(self, feat_list, save_to_file = False):
        pass

    def clear_cell_feat_mi_significance_data(self, cell, feat, save_to_file = False):
        pass

    def _load_precomputed_data(self, **kwargs):
        if 'mi_significance' in kwargs:
            # The legacy table may not exist yet; start from an empty one.
            existing = getattr(self, 'mi_significance_table', {})
            self.mi_significance_table = {**existing, **kwargs['mi_significance']}

    def store_embedding(self, embedding, method_name, data_type='calcium', metadata=None):
        """
        Store a dimensionality reduction embedding in the experiment.

        Parameters
        ----------
        embedding : np.ndarray
            The embedding array, shape (n_timepoints, n_components)
        method_name : str
            Name of the DR method (e.g., 'pca', 'umap', 'isomap')
        data_type : str
            Type of data used ('calcium' or 'spikes')
        metadata : dict, optional
            Additional metadata (e.g., parameters, quality metrics). Its
            'ds' entry (default 1) is the downsampling factor.

        Raises
        ------
        ValueError
            For an unknown data_type, ds < 1, or a row count that does not
            match n_frames downsampled by ds.
        """
        if data_type not in ['calcium', 'spikes']:
            raise ValueError("data_type must be 'calcium' or 'spikes'")

        ds = metadata.get('ds', 1) if metadata else 1
        if ds < 1:
            raise ValueError(f"ds must be >= 1, got {ds}")

        # Slicing with [::ds] keeps ceil(n_frames / ds) points; the previous
        # check only accepted floor(n_frames / ds). Accept both.
        n_floor = self.n_frames // ds
        n_ceil = -(-self.n_frames // ds)
        if embedding.shape[0] not in (n_floor, n_ceil):
            raise ValueError(
                f"Embedding timepoints ({embedding.shape[0]}) must match expected frames "
                f"({n_floor} or {n_ceil} for {self.n_frames} frames with ds={ds})"
            )

        self.embeddings[data_type][method_name] = {
            'data': embedding,
            'metadata': metadata or {},
            'timestamp': np.datetime64('now'),
            'shape': embedding.shape
        }

    def get_embedding(self, method_name, data_type='calcium'):
        """
        Retrieve a stored embedding.

        Parameters
        ----------
        method_name : str
            Name of the DR method
        data_type : str
            Type of data used ('calcium' or 'spikes')

        Returns
        -------
        dict
            Entry with 'data', 'metadata', 'timestamp' and 'shape'

        Raises
        ------
        ValueError
            For an unknown data_type.
        KeyError
            If no embedding is stored for the method.
        """
        if data_type not in ['calcium', 'spikes']:
            raise ValueError("data_type must be 'calcium' or 'spikes'")

        if method_name not in self.embeddings[data_type]:
            raise KeyError(f"No embedding found for method '{method_name}' with data_type '{data_type}'")

        return self.embeddings[data_type][method_name]