Coverage for src/driada/experiment/exp_build.py: 13.21%
106 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import copy
2import os.path
3import numpy as np
4import pickle
6from .exp_base import Experiment
7from ..information.info_base import TimeSeries
8from ..utils.naming import construct_session_name
9from ..utils.output import show_output
10from .neuron import DEFAULT_FPS, DEFAULT_T_OFF, DEFAULT_T_RISE
11from ..gdrive.download import download_gdrive_data, initialize_iabs_router
def load_exp_from_aligned_data(data_source,
                               exp_params,
                               data,
                               force_continuous=None,
                               bad_frames=None,
                               static_features=None,
                               verbose=True,
                               reconstruct_spikes='wavelet'):
    """Build an ``Experiment`` from a mapping of aligned recording arrays.

    Parameters
    ----------
    data_source : str
        Data source identifier, forwarded to ``construct_session_name``.
    exp_params : dict
        Experiment parameters, forwarded to ``construct_session_name`` and
        ``Experiment``.
    data : mapping
        Aligned data. Must contain a ``calcium`` entry (key matched
        case-insensitively) and may contain a ``spikes`` entry. Every other
        entry is treated as a dynamic (behaviour) feature. ``data`` itself is
        deep-copied and never modified.
    force_continuous : list of str, optional
        Feature names to mark as continuous. When non-empty, every other
        feature is marked discrete. When empty/None, discreteness is left to
        ``TimeSeries`` auto-detection.
    bad_frames : iterable of int, optional
        Frame indices flagged ``True`` in ``bad_frames_mask``. Indices outside
        ``range(calcium.shape[1])`` are ignored.
    static_features : dict, optional
        Static experiment features. Missing ``t_rise_sec``, ``t_off_sec`` and
        ``fps`` entries are filled from module defaults. The caller's dict is
        not modified.
    verbose : bool
        Print progress and feature-classification information.
    reconstruct_spikes : str
        Forwarded unchanged to ``Experiment``.

    Returns
    -------
    Experiment

    Raises
    ------
    ValueError
        If ``data`` has no ``calcium`` entry.
    """
    # None replaces the former shared mutable ``[]`` defaults.
    if force_continuous is None:
        force_continuous = []
    if bad_frames is None:
        bad_frames = []

    expname = construct_session_name(data_source, exp_params)
    # Work on a copy so popping calcium/spikes never touches the caller's data.
    adata = copy.deepcopy(data)
    # Case-insensitive lookup: lower-cased key -> original key.
    key_mapping = {key.lower(): key for key in adata.keys()}

    if verbose:
        print(f'Building experiment {expname}...')

    if 'calcium' not in key_mapping:
        raise ValueError('No calcium data found!')
    calcium = adata.pop(key_mapping['calcium'])

    spikes = None
    if 'spikes' in key_mapping:
        spikes = adata.pop(key_mapping['spikes'])

    # Everything left after removing calcium/spikes is a dynamic feature.
    dyn_features = adata.copy()

    def is_garbage(vals):
        # Constant (a single unique value) or entirely NaN; empty input also
        # counts as garbage (np.all over an empty array is True).
        return len(set(vals)) == 1 or bool(np.all(np.isnan(vals)))

    if len(force_continuous) != 0:
        filt_dyn_features = {f: TimeSeries(vals, discrete=f not in force_continuous)
                             for f, vals in dyn_features.items() if not is_garbage(vals)}
    else:
        filt_dyn_features = {f: TimeSeries(vals)
                             for f, vals in dyn_features.items() if not is_garbage(vals)}

    if verbose:
        print('behaviour variables:')
        print()
        for f, ts in filt_dyn_features.items():
            print(f"'{f}'", 'discrete' if ts.discrete else 'continuous')

    # Features dropped by is_garbage (constant or empty).
    constfeats = set(dyn_features.keys()) - set(filt_dyn_features.keys())
    if len(constfeats) != 0 and verbose:
        print(f'features {constfeats} dropped as constant or empty')

    auto_continuous = [fn for fn, ts in filt_dyn_features.items() if not ts.discrete]
    if verbose:
        print(f'features {auto_continuous} automatically determined as continuous')
        print()

    if len(force_continuous) != 0:
        if set(auto_continuous) != (set(force_continuous) & set(dyn_features.keys())):
            # Printed regardless of `verbose`: this is a warning, not progress output.
            print('Warning: auto determined continuous features do not coincide with force_continuous list! Automatic labelling will be overridden')
            # NOTE(review): nesting of this loop was reconstructed from a
            # whitespace-stripped listing; confirm it belongs under the warning.
            # Non-forced features with >2 unique values are binarised (nonzero -> 1).
            for fn, ts in list(filt_dyn_features.items()):
                if len(set(ts.data)) > 2 and fn not in force_continuous:
                    filt_dyn_features[fn] = TimeSeries(ts.data.astype(bool).astype(int), discrete=True)
                    if verbose:
                        print(f'feature {fn} converted to integer')

    signature = f'Exp {expname}'

    # set default static experiment features if not provided
    default_static_features = {'t_rise_sec': DEFAULT_T_RISE,
                               't_off_sec': DEFAULT_T_OFF,
                               'fps': DEFAULT_FPS}

    # Copy first so filling in defaults does not mutate the caller's dict.
    static_features = {} if static_features is None else dict(static_features)
    for sf, default_value in default_static_features.items():
        static_features.setdefault(sf, default_value)

    # Hash-set membership keeps mask construction linear in the frame count;
    # indices outside range(n_frames) are naturally ignored.
    n_frames = calcium.shape[1]
    bad_frame_set = set(bad_frames)
    bad_frames_mask = np.array([i in bad_frame_set for i in range(n_frames)], dtype=bool)

    exp = Experiment(signature,
                     calcium,
                     spikes,
                     exp_params,
                     static_features,
                     filt_dyn_features,
                     reconstruct_spikes=reconstruct_spikes,
                     bad_frames_mask=bad_frames_mask)

    return exp
def load_experiment(data_source,
                    exp_params,
                    force_rebuild=False,
                    force_reload=False,
                    via_pydrive=True,
                    gauth=None,
                    root='DRIADA data',
                    exp_path=None,
                    data_path=None,
                    force_continuous=None,
                    bad_frames=None,
                    static_features=None,
                    reconstruct_spikes='wavelet',
                    save_to_pickle=False,
                    verbose=True):
    """Load an experiment from a cached pickle, or build it from aligned data.

    If a pickle exists at ``exp_path`` and neither ``force_rebuild`` nor
    ``force_reload`` is set, that pickle is loaded. Otherwise, for
    ``data_source == 'IABS'``, aligned ``.npz`` data is read from ``data_path``.
    The data is first downloaded into ``root`` when missing or when
    ``force_reload`` is set. An ``Experiment`` is then built with
    ``load_exp_from_aligned_data``.

    Parameters
    ----------
    data_source : str
        Only ``'IABS'`` is supported for building from raw data.
    exp_params : dict
        Experiment parameters, forwarded to ``construct_session_name``.
    force_rebuild : bool
        Ignore an existing pickle and rebuild from aligned data.
    force_reload : bool
        Ignore an existing pickle and re-download the aligned data.
    via_pydrive, gauth
        Forwarded to ``download_gdrive_data``.
    root : str
        Data root folder, created if missing.
    exp_path : str, optional
        Pickle path. Defaults to ``<root>/<expname>/Exp <expname>.pickle``.
    data_path : str, optional
        Aligned-data path. Defaults to
        ``<root>/<expname>/Aligned data/<expname> syn data.npz``.
    force_continuous, bad_frames, static_features, reconstruct_spikes
        Forwarded to ``load_exp_from_aligned_data``.
    save_to_pickle : bool
        Save a freshly built experiment to ``exp_path``.
    verbose : bool
        Print progress messages.

    Returns
    -------
    tuple
        ``(experiment, load_log)``. ``load_log`` is None unless a download
        was attempted.

    Raises
    ------
    ValueError
        If ``root`` exists but is not a folder, or if the data source is
        unsupported and no usable pickle exists.
    FileNotFoundError
        If downloading the aligned data fails.
    """
    # os.makedirs(exist_ok=True) raises FileExistsError when `root` is an
    # existing regular file, so validate first to raise the intended error.
    if os.path.exists(root) and not os.path.isdir(root):
        raise ValueError('Root must be a folder!')
    os.makedirs(root, exist_ok=True)

    expname = None
    if exp_path is None:
        expname = construct_session_name(data_source, exp_params)
        exp_path = os.path.join(root, expname, f'Exp {expname}.pickle')

    if os.path.exists(exp_path) and not force_rebuild and not force_reload:
        Exp = load_exp_from_pickle(exp_path, verbose=verbose)
        return Exp, None

    if data_source != 'IABS':
        raise ValueError('External data sources are not yet supported')

    # expname is needed below even when the caller supplied exp_path.
    if expname is None:
        expname = construct_session_name(data_source, exp_params)

    if data_path is None:
        data_path = os.path.join(root,
                                 expname,
                                 'Aligned data',
                                 f'{expname} syn data.npz')
    if verbose:
        print(f'Path to data: {data_path}')

    data_exists = os.path.exists(data_path)
    if verbose:
        if data_exists:
            print('Aligned data for experiment construction found successfully')
        else:
            print('Failed to locate aligned data for experiment construction')

    if force_reload or not data_exists:
        if verbose:
            print('Loading data from cloud storage...')
        data_router, _ = initialize_iabs_router(root=root)
        success, load_log = download_gdrive_data(data_router,
                                                 expname,
                                                 data_pieces=['Aligned data'],
                                                 via_pydrive=via_pydrive,
                                                 tdir=root,
                                                 gauth=gauth)

        if not success:
            print('=========== BEGINNING OF LOADING LOG ============')
            show_output(load_log)
            print('=========== END OF LOADING LOG ============')
            raise FileNotFoundError(f'Cannot download {expname}, see loading log above')
    else:
        load_log = None

    # NpzFile keeps the archive open; the context manager closes it once the
    # arrays have been materialised into a plain dict.
    with np.load(data_path) as npz:
        aligned_data = dict(npz)

    Exp = load_exp_from_aligned_data(data_source,
                                     exp_params,
                                     aligned_data,
                                     force_continuous=[] if force_continuous is None else force_continuous,
                                     static_features=static_features,
                                     verbose=verbose,
                                     bad_frames=[] if bad_frames is None else bad_frames,
                                     reconstruct_spikes=reconstruct_spikes)

    if save_to_pickle:
        # A custom exp_path may point into a folder that does not exist yet.
        os.makedirs(os.path.dirname(exp_path) or '.', exist_ok=True)
        save_exp_to_pickle(Exp, exp_path, verbose=verbose)
    return Exp, load_log
def save_exp_to_pickle(exp, path, verbose=True):
    """Serialise ``exp`` to ``path`` with the standard pickle protocol.

    Parameters
    ----------
    exp : object
        Object to store. Its ``signature`` attribute is read for the message.
    path : str
        Destination file, overwritten if present.
    verbose : bool
        Print a confirmation line after writing.
    """
    with open(path, mode="wb") as out_file:
        pickle.dump(exp, out_file)
    if not verbose:
        return
    print(f'Experiment {exp.signature} saved to {path}\n')
def load_exp_from_pickle(path, verbose=True):
    """Load and return an object previously stored with ``save_exp_to_pickle``.

    Parameters
    ----------
    path : str
        Pickle file to read.
    verbose : bool
        Print a confirmation line using the loaded object's ``signature``.

    Returns
    -------
    object
        The unpickled object.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # from trusted sources.
    with open(path, mode="rb") as in_file:
        loaded = pickle.load(in_file)
    if verbose:
        print(f'Experiment {loaded.signature} loaded from {path}\n')
    return loaded