Coverage for src/driada/experiment/exp_base.py: 60.05%

368 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import numpy as np 

2import warnings 

3import tqdm 

4from itertools import combinations 

5import pickle 

6 

7from ..information.info_base import TimeSeries 

8from ..information.info_base import MultiTimeSeries 

9from .neuron import DEFAULT_MIN_BEHAVIOUR_TIME, Neuron 

10from .wavelet_event_detection import WVT_EVENT_DETECTION_PARAMS, extract_wvt_events, events_to_ts_array, ridges_to_containers 

11from ..utils.data import get_hash, populate_nested_dict 

12from ..information.info_base import get_1d_mi 

13 

# Field names of the per (neuron, feature) statistics record kept in the selectivity tables
STATS_VARS = [
    'data_hash',
    'opt_delay',
    'pre_pval',
    'pre_rval',
    'pval',
    'rval',
    'me',
    'rel_me_beh',
    'rel_me_ca',
]
# Field names of the per (neuron, feature) significance record
SIGNIFICANCE_VARS = [
    'stage1',
    'shuffles1',
    'stage2',
    'shuffles2',
    'final_p_thr',
    'multicomp_corr',
    'pairwise_pval_thr',
]
# Templates with every field unset (None); callers copy them before filling in values
DEFAULT_STATS = dict.fromkeys(STATS_VARS)
DEFAULT_SIGNIFICANCE = dict.fromkeys(SIGNIFICANCE_VARS)

18 

19 

def check_dynamic_features(dynamic_features):
    '''
    Verify that all dynamic features span the same number of time points.

    Parameters
    ----------
    dynamic_features : dict
        Feature id -> TimeSeries, MultiTimeSeries, numpy array or sized sequence.
        For numpy arrays with more than one dimension the last axis is treated
        as time; MultiTimeSeries data is treated as (n_features, n_timepoints).

    Raises
    ------
    ValueError
        If the features do not all share one length. The message lists the
        length found for every feature.
    '''
    dfeat_lengths = {}
    for feat_id, current_ts in dynamic_features.items():
        if isinstance(current_ts, np.ndarray):
            # Raw numpy arrays: time is the last axis
            len_ts = current_ts.shape[-1] if current_ts.ndim > 1 else len(current_ts)
        elif isinstance(current_ts, MultiTimeSeries):
            # Checked before TimeSeries so a possible subclass relation cannot
            # make us count components instead of time points
            len_ts = current_ts.data.shape[1]
        elif isinstance(current_ts, TimeSeries):
            len_ts = len(current_ts.data)
        elif hasattr(current_ts, 'data') and hasattr(current_ts.data, 'shape'):
            # MultiTimeSeries-like objects with (n_features, n_timepoints) data
            len_ts = current_ts.data.shape[1]
        else:
            len_ts = len(current_ts)

        dfeat_lengths[feat_id] = len_ts

    # Zero or one distinct length is consistent (an empty dict has nothing to compare)
    if len(set(dfeat_lengths.values())) > 1:
        raise ValueError(f'Dynamic features have different lengths! {dfeat_lengths}')

40 

41 

class Experiment():
    '''
    Container for a single Ca2+ imaging experiment.

    Holds per-neuron calcium (and optionally spike) activity, dynamic
    (time-resolved) and static features, the bookkeeping tables used for
    INTENSE selectivity results, and stored dimensionality-reduction embeddings.

    Attributes
    ----------
    signature : str
        Experiment identifier used in messages.
    exp_identificators : dict
        Identification fields; every key is also set as an attribute.
    n_cells, n_frames : int
        Number of neurons and number of (possibly trimmed) time frames.
    neurons : list of Neuron
        One Neuron per row of the calcium array.
    calcium, spikes : MultiTimeSeries
        Activity assembled from the neurons' own TimeSeries objects.
    dynamic_features : dict
        Feature id -> TimeSeries / MultiTimeSeries. String ids are also
        exposed as attributes.
    stats_tables, significance_tables : dict
        mode ('calcium' or 'spikes') -> feature id -> cell id -> record dict.
    embeddings : dict
        Data type ('calcium' or 'spikes') -> method name -> stored record.
    filtered_flag : bool
        True once bad frames have been removed.
    spike_reconstruction_method : str, callable or None
        Value of the ``reconstruct_spikes`` keyword used at construction.

    Methods
    -------
    update_neuron_feature_pair_stats, update_neuron_feature_pair_significance
        Write INTENSE results for a neuron-feature pair.
    get_neuron_feature_pair_stats, get_neuron_feature_pair_significance
        Read INTENSE results, guarded by data-hash checks.
    get_stats_slice, get_significance_slice, get_significant_neurons
        Query the selectivity tables.
    get_multicell_shuffled_calcium, get_multicell_shuffled_spikes
        Build shuffled activity matrices for a set of cells.
    store_embedding, get_embedding
        Save and retrieve dimensionality-reduction results.
    '''

    def __init__(self, signature, calcium, spikes, exp_identificators,
                 static_features, dynamic_features, **kwargs):
        '''
        Build the experiment: optionally reconstruct spikes, trim bad frames,
        create Neuron objects and data hashes, and run consistency checks.

        Parameters
        ----------
        signature : str
            Experiment identifier.
        calcium : np.ndarray
            Calcium traces indexed as (cell, frame).
        spikes : np.ndarray or None
            Spike traces with the same layout as ``calcium``. Ignored (with a
            warning) when spikes are reconstructed from calcium.
        exp_identificators : dict
            Identification fields, each set as an attribute.
        static_features : dict
            Scalar metadata set as attributes; 'fps', 't_rise_sec' and
            't_off_sec' are passed on to the neurons / spike reconstruction.
        dynamic_features : dict
            Feature id -> TimeSeries, MultiTimeSeries or array. The caller's
            dict is not modified.
        **kwargs
            fit_individual_t_off (bool, default False),
            reconstruct_spikes (str, callable or None, default 'wavelet'),
            bad_frames_mask (boolean array of length n_frames, True = drop),
            spike_kwargs (dict, passed to spike reconstruction).

        Raises
        ------
        ValueError
            If dynamic features have different lengths, or the data fails
            the final consistency checks.
        AttributeError
            If ``calcium`` is None.
        '''
        fit_individual_t_off = kwargs.get('fit_individual_t_off', False)
        reconstruct_spikes = kwargs.get('reconstruct_spikes', 'wavelet')
        bad_frames_mask = kwargs.get('bad_frames_mask', None)
        spike_kwargs = kwargs.get('spike_kwargs', None)

        check_dynamic_features(dynamic_features)
        self.exp_identificators = exp_identificators
        self.signature = signature

        for idx in exp_identificators:
            setattr(self, idx, exp_identificators[idx])

        if calcium is None:
            raise AttributeError('No calcium data provided')

        # Always recorded, so callers can tell how the spikes were obtained
        self.spike_reconstruction_method = reconstruct_spikes
        if reconstruct_spikes is None:
            if spikes is None:
                warnings.warn('No spike data provided, spikes reconstruction from Ca2+ data disabled')
        else:
            if spikes is not None:
                warnings.warn(f'Spike data will be overridden by reconstructed spikes from Ca2+ data with method={reconstruct_spikes}')
            spikes = self._reconstruct_spikes(calcium, reconstruct_spikes, static_features.get('fps'), spike_kwargs)

        self.filtered_flag = False
        if bad_frames_mask is not None:
            calcium, spikes, dynamic_features = self._trim_data(calcium,
                                                                spikes,
                                                                dynamic_features,
                                                                bad_frames_mask)
        else:
            # Build a fresh dict instead of rewriting the caller's one in place
            dynamic_features = {feat_id: self._to_time_series(feat_id, feat_data)
                                for feat_id, feat_data in dynamic_features.items()}

        self.n_cells = calcium.shape[0]
        self.n_frames = calcium.shape[1]

        # Raw calcium and spikes arrays used to build the neurons
        self._calcium_raw = calcium
        self._spikes_raw = spikes if spikes is not None else np.zeros(calcium.shape)

        self.neurons = []

        print('Building neurons...')
        for i in tqdm.tqdm(np.arange(self.n_cells), position=0, leave=True):
            cell = Neuron(str(i),
                          self._calcium_raw[i, :],
                          self._spikes_raw[i, :],
                          default_t_rise=static_features.get('t_rise_sec'),
                          default_t_off=static_features.get('t_off_sec'),
                          fps=static_features.get('fps'),
                          fit_individual_t_off=fit_individual_t_off)

            self.neurons.append(cell)

        # Assemble MultiTimeSeries from the neurons' TimeSeries objects, which
        # preserves the individual shuffle masks created by each Neuron
        calcium_ts_list = [neuron.ca for neuron in self.neurons]
        spikes_ts_list = [neuron.sp if neuron.sp is not None else TimeSeries(np.zeros(self.n_frames), discrete=True)
                          for neuron in self.neurons]
        self.calcium = MultiTimeSeries(calcium_ts_list)
        self.spikes = MultiTimeSeries(spikes_ts_list)

        self.dynamic_features = dynamic_features
        for feat_id in dynamic_features:
            # Tuple ids (multifeatures) cannot be attribute names
            if isinstance(feat_id, str):
                setattr(self, feat_id, dynamic_features[feat_id])

        for sfeat_name in static_features:
            setattr(self, sfeat_name, static_features[sfeat_name])

        # selectivity data from INTENSE
        self.stats_tables = {}
        self.significance_tables = {}
        self.selectivity_tables_initialized = False

        # dimensionality reduction embeddings
        self.embeddings = {'calcium': {}, 'spikes': {}}

        print('Building data hashes...')
        self._build_data_hashes(mode='calcium')
        if reconstruct_spikes is not None or spikes is not None:
            self._build_data_hashes(mode='spikes')

        print('Final checkpoint...')
        self._checkpoint()

        print(f'Experiment "{self.signature}" constructed successfully with {self.n_cells} neurons and {len(self.dynamic_features)} features')

    @staticmethod
    def _to_time_series(feat_id, feat_data):
        '''
        Wrap raw feature data in TimeSeries / MultiTimeSeries.

        TimeSeries and MultiTimeSeries are returned unchanged; a 1D array becomes
        a TimeSeries; a 2D array becomes a continuous MultiTimeSeries with one
        component per row; any other non-array input is wrapped as a TimeSeries.

        Raises
        ------
        ValueError
            For numpy arrays with more than two dimensions.
        '''
        if isinstance(feat_data, (TimeSeries, MultiTimeSeries)):
            return feat_data
        if isinstance(feat_data, np.ndarray):
            if feat_data.ndim == 1:
                return TimeSeries(feat_data)
            if feat_data.ndim == 2:
                ts_list = [TimeSeries(feat_data[i, :], discrete=False) for i in range(feat_data.shape[0])]
                return MultiTimeSeries(ts_list)
            raise ValueError(f"Feature {feat_id} has unsupported dimensionality: {feat_data.ndim}D")
        # Assume 1D data if not a numpy array
        return TimeSeries(feat_data)

    @staticmethod
    def _normalize_feat_id(feat_id):
        '''
        Return the canonical table key for a feature id: a string is kept as is,
        any other iterable of names becomes a sorted tuple.
        '''
        if isinstance(feat_id, str):
            return feat_id
        return tuple(sorted(feat_id))

    def check_ds(self, ds):
        '''
        Warn (via print) when downsampling by ``ds`` creates time gaps longer
        than DEFAULT_MIN_BEHAVIOUR_TIME.

        Raises
        ------
        ValueError
            If ``fps`` is missing or None.
        '''
        fps = getattr(self, 'fps', None)
        if fps is None:
            raise ValueError(f'fps not set for {self.signature}')

        time_step = 1.0/fps
        if time_step*ds > DEFAULT_MIN_BEHAVIOUR_TIME:
            print('Downsampling constant is too high: some behaviour acts may be skipped. '
                  f'Current minimal behaviour time interval is set to {DEFAULT_MIN_BEHAVIOUR_TIME} sec, '
                  f'downsampling {ds} will create time gaps of {time_step*ds} sec')

    def _set_selectivity_tables(self, mode, fbunch=None, cbunch=None):
        '''
        (Re)initialise the stats and significance tables of ``mode`` with
        unset records for the requested features and cells.
        '''
        # neuron-feature pair statistics
        stats_table = self._populate_cell_feat_dict(DEFAULT_STATS, fbunch=fbunch, cbunch=cbunch, mode=mode)

        # neuron-feature pair significance-related data
        significance_table = self._populate_cell_feat_dict(DEFAULT_SIGNIFICANCE,
                                                           fbunch=fbunch,
                                                           cbunch=cbunch,
                                                           mode=mode)
        self.stats_tables[mode] = stats_table
        self.significance_tables[mode] = significance_table
        self.selectivity_tables_initialized = True

    def _build_pair_hash(self, cell_id, feat_id, mode='calcium'):
        '''
        Build a hash-based fingerprint of an activity-feature pair.

        feat_id should be a string or an iterable of strings (joint MI case).
        Returns a tuple: activity hash followed by the feature hash(es),
        features taken in sorted order.

        Raises
        ------
        ValueError
            If ``mode`` is not 'calcium' or 'spikes'.
        '''
        if mode == 'calcium':
            act = self.neurons[cell_id].ca.data
        elif mode == 'spikes':
            act = self.neurons[cell_id].sp.data
        else:
            raise ValueError('"mode" can be either "calcium" or "spikes"')

        act_hash = get_hash(act)

        if (not isinstance(feat_id, str)) and len(feat_id) == 1:
            feat_id = feat_id[0]

        if isinstance(feat_id, str):
            return (act_hash, get_hash(self.dynamic_features[feat_id].data))

        hashes = [act_hash]
        for fname in sorted(feat_id):
            hashes.append(get_hash(self.dynamic_features[fname].data))
        return tuple(hashes)

    def _build_data_hashes(self, mode='calcium'):
        '''
        Compute the data hash of every (single feature, cell) pair for ``mode``.

        Each mode gets its own independent table, so building one mode never
        overwrites the other. Both modes exist after the first call (the
        not-yet-built one filled with None).
        '''
        if not hasattr(self, '_data_hashes'):
            self._data_hashes = {
                m: {dfeat: dict.fromkeys(range(self.n_cells)) for dfeat in self.dynamic_features}
                for m in ('calcium', 'spikes')
            }

        self._data_hashes[mode] = {
            feat_id: {cell_id: self._build_pair_hash(cell_id, feat_id, mode=mode)
                      for cell_id in range(self.n_cells)}
            for feat_id in self.dynamic_features
        }

    def _trim_data(self, calcium, spikes, dynamic_features, bad_frames_mask, force_filter=False):
        '''
        Remove the frames flagged in ``bad_frames_mask`` from activity and features.

        Parameters
        ----------
        calcium : np.ndarray
            Indexed as (cell, frame).
        spikes : np.ndarray or None
            Same layout as ``calcium``.
        dynamic_features : dict
            Feature id -> TimeSeries, MultiTimeSeries or array (time on the last axis).
        bad_frames_mask : array-like of bool
            One entry per frame; truthy entries are dropped.
        force_filter : bool
            Allow filtering again when data was already filtered.

        Returns
        -------
        tuple
            (filtered calcium, filtered spikes or None, dict of filtered features).

        Raises
        ------
        AttributeError
            If data is already filtered and ``force_filter`` is False.
        ValueError
            If the mask length differs from the number of calcium frames.
        '''
        if not force_filter and self.filtered_flag:
            raise AttributeError('Data is already filtered, if you want to force filtering it again, set "force_filter = True"')

        # Convert to bool first: "~" on an int 0/1 mask yields -1/-2, i.e. wrong indices
        bad_mask = np.asarray(bad_frames_mask, dtype=bool)
        if bad_mask.ndim != 1 or bad_mask.shape[0] != calcium.shape[1]:
            raise ValueError(f'bad_frames_mask must be 1D with length {calcium.shape[1]}, '
                             f'got shape {bad_mask.shape}')
        keep = ~bad_mask

        f_calcium = calcium[:, keep]
        f_spikes = spikes[:, keep] if spikes is not None else None

        f_dynamic_features = {}
        for feat_id, current_ts in dynamic_features.items():
            if isinstance(current_ts, MultiTimeSeries):
                # NOTE(review): assumes MultiTimeSeries may expose "discrete"; defaults to continuous
                discrete = getattr(current_ts, 'discrete', False)
                f_ts = MultiTimeSeries([TimeSeries(row[keep], discrete=discrete) for row in current_ts.data])
            elif isinstance(current_ts, TimeSeries):
                f_ts = TimeSeries(current_ts.data[keep], discrete=current_ts.discrete)
            else:
                # Raw data: mask along the time (last) axis, then wrap as usual
                f_ts = self._to_time_series(feat_id, np.asarray(current_ts)[..., keep])

            f_dynamic_features[feat_id] = f_ts

        self.filtered_flag = True
        self.bad_frames_mask = bad_mask

        return f_calcium, f_spikes, f_dynamic_features

    def _checkpoint(self):
        '''
        Check the built experiment for common errors.

        Raises
        ------
        UserWarning
            If there are more cells than frames (data probably transposed).
        ValueError
            If activity or a dynamic feature does not contain n_frames in its shape.
        '''
        if self.n_cells > self.n_frames:
            raise UserWarning('Number of cells > number of time frames, looks like the data is transposed')

        for dfeat in ['calcium', 'spikes']:
            shape = getattr(self, dfeat).data.shape
            if self.n_frames not in shape:
                raise ValueError(f'"{dfeat}" feature has inappropriate shape: {shape} '
                                 f'inconsistent with data length {self.n_frames}')

        # Read from the dict, not attributes: static features may shadow names
        for dfeat, feat_data in self.dynamic_features.items():
            data = getattr(feat_data, 'data', None)
            if data is not None and self.n_frames not in data.shape:
                raise ValueError(f'"{dfeat}" feature has inappropriate shape: {data.shape} '
                                 f'inconsistent with data length {self.n_frames}')

    def _populate_cell_feat_dict(self, content, fbunch=None, cbunch=None, mode='calcium'):
        '''
        Helper function. Create a nested dict (outer: features, inner: cells)
        with every entry populated from ``content``. ``mode`` selects which
        stats table supplies the default feature set.
        '''
        cell_ids = self._process_cbunch(cbunch)
        feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True, mode=mode)
        return populate_nested_dict(content, feat_ids, cell_ids)

    def _process_cbunch(self, cbunch):
        '''
        Helper function. Turn cell indices (int, numpy integer, iterable or
        None for all cells) into a list of cell numbers.
        '''
        if isinstance(cbunch, (int, np.integer)):
            return [cbunch]
        if cbunch is None:
            return list(np.arange(self.n_cells))
        return list(cbunch)

    def _process_fbunch(self, fbunch, allow_multifeatures=False, mode='calcium'):
        '''
        Helper function. Turn feature names (str, iterable or None) into a list
        of feature ids; multifeatures become sorted tuples.

        A single string is returned as is. From an iterable, string names not
        present in dynamic_features are skipped. With ``fbunch=None`` and
        ``allow_multifeatures``, the keys of the ``mode`` stats table are used
        when that table exists.

        Raises
        ------
        ValueError
            If a multifeature is given while ``allow_multifeatures`` is False.
        '''
        if isinstance(fbunch, str):
            return [fbunch]

        if fbunch is None:  # default set of features
            if allow_multifeatures and mode in self.stats_tables:
                # stats table contains up-to-date set of features, including multifeatures
                return list(self.stats_tables[mode].keys())
            return list(self.dynamic_features.keys())

        feat_ids = []
        for fname in fbunch:
            if isinstance(fname, str):
                if fname in self.dynamic_features:
                    feat_ids.append(fname)
            elif allow_multifeatures:
                feat_ids.append(tuple(sorted(fname)))
            else:
                raise ValueError('Multifeature detected in "allow_multifeatures=False" mode')

        return feat_ids

    def _process_sbunch(self, sbunch, significance_mode=False):
        '''
        Helper function. Turn stats field names (str, iterable or None) into a
        list of fields. A fresh list is returned so callers cannot mutate the
        module-level defaults.
        '''
        default_list = SIGNIFICANCE_VARS if significance_mode else STATS_VARS

        if isinstance(sbunch, str):
            return [sbunch]
        if sbunch is None:
            return list(default_list)
        return [st for st in sbunch if st in default_list]

    def _add_multifeature_to_data_hashes(self, feat_id, mode='calcium'):
        '''
        Add a previously unseen multifeature (e.g. ['x','y']) to the ``mode``
        data-hash table; multifeatures already present are left untouched.

        Raises
        ------
        ValueError
            If ``feat_id`` is a single feature.
        '''
        if (not isinstance(feat_id, str)) and len(feat_id) == 1:
            feat_id = feat_id[0]

        if isinstance(feat_id, str):
            raise ValueError('This method is for multifeature update only')

        ordered_fnames = tuple(sorted(feat_id))
        if ordered_fnames not in self._data_hashes[mode]:
            # hashes must be built from the activity of the same mode
            self._data_hashes[mode][ordered_fnames] = {
                cell_id: self._build_pair_hash(cell_id, ordered_fnames, mode=mode)
                for cell_id in range(self.n_cells)
            }

    def _add_multifeature_to_stats(self, feat_id, mode='calcium'):
        '''
        Add a previously unseen multifeature (e.g. ['x','y']) to the ``mode``
        statistics and significance tables; existing entries are left untouched.

        Raises
        ------
        ValueError
            If ``feat_id`` is a single feature.
        '''
        if (not isinstance(feat_id, str)) and len(feat_id) == 1:
            feat_id = feat_id[0]

        if isinstance(feat_id, str):
            raise ValueError('This method is for multifeature update only')

        ordered_fnames = tuple(sorted(feat_id))
        if ordered_fnames not in self.stats_tables[mode]:
            print(f'Multifeature {feat_id} is new, it will be added to stats table')
            self.stats_tables[mode][ordered_fnames] = {cell_id: DEFAULT_STATS.copy()
                                                       for cell_id in range(self.n_cells)}

            self.significance_tables[mode][ordered_fnames] = {cell_id: DEFAULT_SIGNIFICANCE.copy()
                                                              for cell_id in range(self.n_cells)}

    def _check_stats_relevance(self, cell_id, feat_id, mode='calcium'):
        '''
        Guard against access to non-existing or outdated data.

        Returns True when no stats were stored yet for the pair, or when the
        stored data hash equals the current one; otherwise prints a notice and
        returns False. Refers to the stats table only, which is valid for the
        significance table too since both are always updated together.

        Raises
        ------
        ValueError
            If the feature is absent from the ``mode`` stats table.
        '''
        feat_id = self._normalize_feat_id(feat_id)

        if feat_id not in self.stats_tables[mode]:
            raise ValueError(f'Feature {feat_id} is not present in stats. \n If this is a single feature, '
                             'check the input data, since all single features are processed automatically. '
                             'If this is a multifeature (e.g. ["x", "y"]), compute MI significance to create stats')

        pair_hash = self._data_hashes[mode][feat_id][cell_id]
        existing_hash = self.stats_tables[mode][feat_id][cell_id]['data_hash']

        # (stats does not exist yet) or (stats exists and data is the same)
        if existing_hash is None or pair_hash == existing_hash:
            return True

        print(f'Looks like the data for the pair (cell {cell_id}, feature {feat_id}) '
              'has been changed since the last calculation')
        return False

    def _update_stats_and_significance(self, stats, mode, cell_id, feat_id, stage2_only):
        '''
        Update the stats record and erase the dependent significance data:
        all of it, or only the stage-2 fields when ``stage2_only`` is True.
        ``feat_id`` must already be a canonical table key.
        '''
        self.stats_tables[mode][feat_id][cell_id].update(stats)
        if stage2_only:
            # stage 2 stats changed: stage 2 significance is obsolete
            self.significance_tables[mode][feat_id][cell_id].update({'stage2': None, 'shuffles2': None})
        else:
            # stage 1 stats changed: all significance data is obsolete
            self.significance_tables[mode][feat_id][cell_id].update(DEFAULT_SIGNIFICANCE.copy())

    def update_neuron_feature_pair_stats(self, stats, cell_id, feat_id, mode='calcium', force_update=False, stage2_only=False):
        '''
        Update activity-feature pair statistics.

        feat_id should be a string or an iterable of strings (joint MI case);
        multifeatures are registered automatically and may be given in any order.
        When the stored hash differs from the current data, the update only
        happens with ``force_update=True``.
        '''
        if not isinstance(feat_id, str):
            self._add_multifeature_to_data_hashes(feat_id, mode=mode)
            self._add_multifeature_to_stats(feat_id, mode=mode)
        feat_id = self._normalize_feat_id(feat_id)

        if self._check_stats_relevance(cell_id, feat_id, mode=mode):
            self._update_stats_and_significance(stats, mode, cell_id, feat_id, stage2_only=stage2_only)
        elif force_update:
            self._update_stats_and_significance(stats, mode, cell_id, feat_id, stage2_only=stage2_only)
        else:
            print('To forcefully update the stats, set "force_update=True"')

    def update_neuron_feature_pair_significance(self, sig, cell_id, feat_id, mode='calcium'):
        '''
        Update activity-feature pair significance data.

        feat_id should be a string or an iterable of strings (joint MI case);
        multifeatures may be given in any order.

        Raises
        ------
        ValueError
            If the stored stats hash does not match the current data.
        '''
        if not isinstance(feat_id, str):
            self._add_multifeature_to_data_hashes(feat_id, mode=mode)
            self._add_multifeature_to_stats(feat_id, mode=mode)
        feat_id = self._normalize_feat_id(feat_id)

        if not self._check_stats_relevance(cell_id, feat_id, mode=mode):
            raise ValueError('Can not update significance table until the collision between actual data hashes and '
                             'saved stats data hashes is resolved. Use update_neuron_feature_pair_stats '
                             'with "force_update=True" to forcefully rewrite statistics')

        self.significance_tables[mode][feat_id][cell_id].update(sig)

    def get_neuron_feature_pair_stats(self, cell_id, feat_id, mode='calcium'):
        '''
        Return the ``mode`` statistics record of a pair, or None (with a notice)
        when the stored stats are outdated. Multifeatures may be given in any order.
        '''
        feat_id = self._normalize_feat_id(feat_id)
        if self._check_stats_relevance(cell_id, feat_id, mode=mode):
            return self.stats_tables[mode][feat_id][cell_id]

        print('Consider recalculating stats')
        return None

    def get_neuron_feature_pair_significance(self, cell_id, feat_id, mode='calcium'):
        '''
        Return the ``mode`` significance record of a pair, or None (with a
        notice) when the stored stats are outdated. Multifeatures may be given
        in any order.
        '''
        feat_id = self._normalize_feat_id(feat_id)
        if self._check_stats_relevance(cell_id, feat_id, mode=mode):
            return self.significance_tables[mode][feat_id][cell_id]

        print('Consider recalculating stats')
        return None

    def get_multicell_shuffled_calcium(self, cbunch=None, method='roll_based', no_ts=True, **kwargs):
        '''
        Shuffle the calcium trace of each requested cell.

        Args:
            cbunch: cell index, iterable of indices or None for all cells.
            method: shuffling method passed to Neuron.get_shuffled_calcium.
            no_ts: if True, intermediate TimeSeries objects are not created, which speeds up shuffling.
            **kwargs: forwarded to Neuron.get_shuffled_calcium.

        Returns:
            np.ndarray of shape (len(cells), n_frames); row k holds the k-th requested cell.
        '''
        cell_list = self._process_cbunch(cbunch)
        agg_sh_data = np.zeros((len(cell_list), self.n_frames))
        for row, cell_id in enumerate(cell_list):
            sh_data = self.neurons[cell_id].get_shuffled_calcium(method=method, **kwargs, no_ts=no_ts)
            agg_sh_data[row, :] = sh_data[:] if no_ts else sh_data.data[:]

        return agg_sh_data

    def get_multicell_shuffled_spikes(self, cbunch=None, method='isi_based', no_ts=True, **kwargs):
        '''
        Shuffle the spike train of each requested cell.

        Args:
            cbunch: cell index, iterable of indices or None for all cells.
            method: shuffling method passed to Neuron.get_shuffled_spikes.
            no_ts: selects how the result is read: as an array (True) or via its ``.data`` (False).
            **kwargs: forwarded to Neuron.get_shuffled_spikes.

        Returns:
            np.ndarray of shape (len(cells), n_frames); row k holds the k-th requested cell.

        Raises:
            AttributeError: if all spike data is zero.
        '''
        # Check if spikes data is meaningful (not all zeros)
        if np.allclose(self.spikes.data, 0):
            raise AttributeError('Unable to shuffle spikes without meaningful spikes data')

        cell_list = self._process_cbunch(cbunch)
        agg_sh_data = np.zeros((len(cell_list), self.n_frames))
        for row, cell_id in enumerate(cell_list):
            sh_data = self.neurons[cell_id].get_shuffled_spikes(method=method, **kwargs)
            agg_sh_data[row, :] = sh_data[:] if no_ts else sh_data.data[:]

        return agg_sh_data

    def get_stats_slice(self,
                        table_to_scan=None,
                        cbunch=None,
                        fbunch=None,
                        sbunch=None,
                        significance_mode=False,
                        mode='calcium'):
        '''
        Return a slice of accumulated statistics (or significance data if
        ``significance_mode=True``) as {feature: {cell: {field: value}}}.
        ``table_to_scan`` overrides the table read from.
        '''
        cell_ids = self._process_cbunch(cbunch)
        feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True, mode=mode)
        slist = self._process_sbunch(sbunch, significance_mode=significance_mode)

        if table_to_scan is not None:
            full_table = table_to_scan
        elif significance_mode:
            full_table = self.significance_tables[mode]
        else:
            full_table = self.stats_tables[mode]

        # Keys come from the same feature list we read from, for any mode
        out_table = {}
        for feat_id in feat_ids:
            out_table[feat_id] = {cell_id: {s: full_table[feat_id][cell_id][s] for s in slist}
                                  for cell_id in cell_ids}

        return out_table

    def get_significance_slice(self, cbunch=None, fbunch=None, sbunch=None, mode='calcium'):
        '''Shortcut for get_stats_slice(..., significance_mode=True).'''
        return self.get_stats_slice(cbunch=cbunch,
                                    fbunch=fbunch,
                                    sbunch=sbunch,
                                    significance_mode=True,
                                    mode=mode)

    def get_feature_entropy(self, feat_id, ds=1):
        '''
        Calculate the entropy of a single dynamic feature or a multifeature (e.g. ['x','y']).

        For a multifeature the estimate is the sum of single entropies minus the
        sum of pairwise MIs. Only 2-combinations of features are correctly
        supported; for 3 and more variables the estimate is distorted (correct
        multivariate entropy estimation for non-gaussian variables is non-trivial).
        '''
        if isinstance(feat_id, str):
            return self.dynamic_features[feat_id].get_entropy(ds=ds)

        tslist = [self.dynamic_features[dfeat] for dfeat in sorted(feat_id)]
        single_entropies = [fts.get_entropy(ds=ds) for fts in tslist]
        pairwise_mis = [get_1d_mi(ts1, ts2, ds=ds) for ts1, ts2 in combinations(tslist, 2)]
        return sum(single_entropies) - sum(pairwise_mis)

    def _reconstruct_spikes(self, calcium, method, fps, spike_kwargs=None):
        """
        Reconstruct spikes from calcium signals using specified method.

        Parameters
        ----------
        calcium : np.ndarray or MultiTimeSeries
            Calcium traces, shape (n_neurons, n_timepoints)
        method : str or callable
            Reconstruction method: 'wavelet' or a callable function
        fps : float
            Sampling rate in frames per second
        spike_kwargs : dict, optional
            Method-specific parameters

        Returns
        -------
        spikes : np.ndarray
            Reconstructed spike trains (the ``.data`` of the returned MultiTimeSeries);
            metadata is stored in ``self._reconstruction_metadata``.
        """
        from .spike_reconstruction import reconstruct_spikes

        if isinstance(calcium, np.ndarray):
            # np.ndarray also has a ".data" attribute (its raw buffer), so a
            # hasattr(calcium, 'data') test cannot tell arrays apart from MultiTimeSeries
            calcium_mts = MultiTimeSeries([TimeSeries(calcium[i, :], discrete=False)
                                           for i in range(calcium.shape[0])])
        else:
            calcium_mts = calcium

        spikes_mts, metadata = reconstruct_spikes(
            calcium_mts,
            method=method,
            fps=fps,
            params=spike_kwargs
        )

        self._reconstruction_metadata = metadata

        # Return numpy array for backward compatibility
        return spikes_mts.data

    def get_significant_neurons(self, min_nspec=1, cbunch=None, fbunch=None, mode='calcium'):
        '''
        Returns a dict with neuron ids as keys and their significantly correlated features as values
        Only neurons with "min_nspec" or more significantly correlated features will be returned

        Parameters
        ----------
        min_nspec : int
            Minimum number of significantly correlated features required
        cbunch : int, list or None
            Cell indices to analyze. By default (None), all neurons will be checked
        fbunch : str, list or None
            Feature names to check. By default (None), all features will be checked
        mode : str
            Data type: 'calcium' or 'spikes'

        Returns
        -------
        dict
            Dictionary with neuron IDs as keys and lists of significant features as values

        Raises
        ------
        ValueError
            If any requested pair has outdated stats.
        '''
        cell_ids = self._process_cbunch(cbunch)
        feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True, mode=mode)

        # Check relevance only for requested cells and features (all pairs, so every mismatch is reported)
        relevance = [self._check_stats_relevance(cell_id, feat_id, mode=mode)
                     for cell_id in cell_ids for feat_id in feat_ids]
        if not all(relevance):
            raise ValueError('Stats relevance error')

        # TODO: add significance update and pval_thr argument
        cell_feat_dict = {cell_id: [feat_id for feat_id in feat_ids
                                    if self.significance_tables[mode][feat_id][cell_id]['stage2']]
                          for cell_id in cell_ids}

        # filter out cells without enough specializations
        return {cell_id: feats for cell_id, feats in cell_feat_dict.items()
                if len(feats) >= min_nspec}

    #===================================================================================
    # Legacy helpers: not called from elsewhere in this class. They operate on
    # "mi_significance_table", which is only created by _load_precomputed_data.

    def save_mi_significance_to_file(self, fname):
        '''Pickle mi_significance_table to ``fname``.'''
        with open(fname, 'wb') as f:
            pickle.dump(self.mi_significance_table, f)

    def clear_cells_mi_significance_data(self, cbunch, path_to_save = None):
        '''Reset mi_significance_table entries of the given cells; optionally save to file.'''
        for cell_id in cbunch:
            for feat in self.dynamic_features:
                self.mi_significance_table[feat][cell_id] = {}

        if path_to_save is not None:
            self.save_mi_significance_to_file(path_to_save)

    def clear_features_mi_significance_data(self, feat_list, save_to_file = False):
        '''Not implemented.'''
        pass

    def clear_cell_feat_mi_significance_data(self, cell, feat, save_to_file = False):
        '''Not implemented.'''
        pass

    def _load_precomputed_data(self, **kwargs):
        '''Merge kwargs["mi_significance"] into mi_significance_table, creating it if absent.'''
        if 'mi_significance' in kwargs:
            existing = getattr(self, 'mi_significance_table', {})
            self.mi_significance_table = {**existing, **kwargs['mi_significance']}

    def store_embedding(self, embedding, method_name, data_type='calcium', metadata=None):
        """
        Store dimensionality reduction embedding in the experiment.

        Parameters
        ----------
        embedding : np.ndarray
            The embedding array, shape (n_timepoints, n_components)
        method_name : str
            Name of the DR method (e.g., 'pca', 'umap', 'isomap')
        data_type : str
            Type of data used ('calcium' or 'spikes')
        metadata : dict, optional
            Additional metadata about the embedding (e.g., parameters, quality metrics).
            Its 'ds' entry (default 1) is the downsampling factor.

        Raises
        ------
        ValueError
            On an unknown data_type, a downsampling factor below 1, or an
            embedding length matching neither n_frames // ds nor ceil(n_frames / ds).
        """
        if data_type not in ['calcium', 'spikes']:
            raise ValueError("data_type must be 'calcium' or 'spikes'")

        ds = metadata.get('ds', 1) if metadata else 1
        if ds < 1:
            raise ValueError(f"Downsampling factor ds must be >= 1, got {ds}")

        # Stride slicing (data[::ds]) keeps ceil(n_frames / ds) frames while
        # floor division gives n_frames // ds; accept either convention
        expected_frames = self.n_frames // ds
        expected_frames_ceil = -(-self.n_frames // ds)
        if embedding.shape[0] not in (expected_frames, expected_frames_ceil):
            raise ValueError(
                f"Embedding timepoints ({embedding.shape[0]}) must match expected frames "
                f"({expected_frames} = {self.n_frames} / ds={ds})"
            )

        self.embeddings[data_type][method_name] = {
            'data': embedding,
            'metadata': metadata or {},
            'timestamp': np.datetime64('now'),
            'shape': embedding.shape
        }

    def get_embedding(self, method_name, data_type='calcium'):
        """
        Retrieve stored embedding.

        Parameters
        ----------
        method_name : str
            Name of the DR method
        data_type : str
            Type of data used ('calcium' or 'spikes')

        Returns
        -------
        dict
            Dictionary containing 'data', 'metadata', 'timestamp' and 'shape'

        Raises
        ------
        ValueError
            On an unknown data_type.
        KeyError
            If no embedding was stored under ``method_name``.
        """
        if data_type not in ['calcium', 'spikes']:
            raise ValueError("data_type must be 'calcium' or 'spikes'")

        if method_name not in self.embeddings[data_type]:
            raise KeyError(f"No embedding found for method '{method_name}' with data_type '{data_type}'")

        return self.embeddings[data_type][method_name]