Coverage for src/driada/experiment/exp_build.py: 13.21%

106 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import copy 

2import os.path 

3import numpy as np 

4import pickle 

5 

6from .exp_base import Experiment 

7from ..information.info_base import TimeSeries 

8from ..utils.naming import construct_session_name 

9from ..utils.output import show_output 

10from .neuron import DEFAULT_FPS, DEFAULT_T_OFF, DEFAULT_T_RISE 

11from ..gdrive.download import download_gdrive_data, initialize_iabs_router 

12 

13 

def load_exp_from_aligned_data(data_source,
                               exp_params,
                               data,
                               force_continuous=None,
                               bad_frames=None,
                               static_features=None,
                               verbose=True,
                               reconstruct_spikes='wavelet'):
    """Build an ``Experiment`` from a mapping of aligned recordings.

    Parameters
    ----------
    data_source
        Passed to ``construct_session_name`` together with ``exp_params``
        to derive the experiment name.
    exp_params
        Experiment parameters; forwarded unchanged to ``Experiment``.
    data : dict
        Mapping of names to arrays. It is deep-copied, so the caller's object
        is never modified. A ``'calcium'`` entry (key matched
        case-insensitively) is required; a ``'spikes'`` entry is optional.
        Every remaining entry is treated as a dynamic (behavioural) feature.
    force_continuous : sequence of str, optional
        When non-empty, the listed features are created as continuous
        ``TimeSeries`` and all other features as discrete. When empty or
        ``None``, the discrete/continuous decision is left to ``TimeSeries``.
    bad_frames : iterable of int, optional
        Frame indices (along ``calcium.shape[1]``) to flag in the
        ``bad_frames_mask`` handed to ``Experiment``.
    static_features : dict, optional
        Static experiment features. Missing ``'t_rise_sec'``, ``'t_off_sec'``
        and ``'fps'`` entries are filled with the module defaults. The
        caller's dict is copied, not modified.
    verbose : bool
        Print progress information.
    reconstruct_spikes
        Forwarded to ``Experiment``.

    Returns
    -------
    Experiment

    Raises
    ------
    ValueError
        If ``data`` has no ``'calcium'`` entry.
    """
    # None stands in for "empty" so that no mutable default is shared
    # between calls.
    force_continuous = [] if force_continuous is None else force_continuous
    bad_frames = [] if bad_frames is None else bad_frames

    expname = construct_session_name(data_source, exp_params)
    adata = copy.deepcopy(data)
    key_mapping = {key.lower(): key for key in adata.keys()}

    if verbose:
        print(f'Building experiment {expname}...')

    if 'calcium' not in key_mapping:
        raise ValueError('No calcium data found!')
    calcium = adata.pop(key_mapping['calcium'])

    spikes = None
    if 'spikes' in key_mapping:
        spikes = adata.pop(key_mapping['spikes'])

    # Whatever is left after removing calcium/spikes is a dynamic feature.
    dyn_features = adata.copy()

    def is_garbage(vals):
        # A feature carries no information if it has a single distinct value
        # or consists only of NaNs.
        return len(set(vals)) == 1 or np.sum(np.isnan(vals)).astype(int) == len(vals)

    if len(force_continuous) != 0:
        filt_dyn_features = {
            f: TimeSeries(vals, discrete=f not in force_continuous)
            for f, vals in dyn_features.items()
            if not is_garbage(vals)
        }
    else:
        filt_dyn_features = {
            f: TimeSeries(vals)
            for f, vals in dyn_features.items()
            if not is_garbage(vals)
        }

    if verbose:
        print('behaviour variables:')
        print()
        for f, ts in filt_dyn_features.items():
            print(f"'{f}'", 'discrete' if ts.discrete else 'continuous')

    # check for constant features
    constfeats = set(dyn_features.keys()) - set(filt_dyn_features.keys())

    if len(constfeats) != 0 and verbose:
        print(f'features {constfeats} dropped as constant or empty')

    auto_continuous = [fn for fn, ts in filt_dyn_features.items() if not ts.discrete]
    if verbose:
        print(f'features {auto_continuous} automatically determined as continuous')
        print()

    if len(force_continuous) != 0:
        if set(auto_continuous) != (set(force_continuous) & set(dyn_features.keys())):
            print('Warning: auto determined continuous features do not coincide with force_continuous list! Automatic labelling will be overridden')
            # NOTE(review): the nesting of this loop was reconstructed from a
            # flattened listing; confirm it belongs inside the warning branch.
            # Only existing keys are reassigned, so the dict size is stable
            # while iterating.
            for fn, ts in filt_dyn_features.items():
                if len(set(ts.data)) > 2 and fn not in force_continuous:
                    # Collapse a multi-valued non-continuous feature to 0/1.
                    filt_dyn_features[fn] = TimeSeries(ts.data.astype(bool).astype(int), discrete=True)
                    if verbose:
                        print(f'feature {fn} converted to integer')

    signature = f'Exp {expname}'

    # set default static experiment features if not provided
    default_static_features = {'t_rise_sec': DEFAULT_T_RISE,
                               't_off_sec': DEFAULT_T_OFF,
                               'fps': DEFAULT_FPS}

    # Work on a copy so the defaults are not written into the caller's dict.
    static_features = {} if static_features is None else dict(static_features)
    for sf, default_value in default_static_features.items():
        static_features.setdefault(sf, default_value)

    # One boolean per calcium column; True marks a frame listed in bad_frames.
    # list(...) keeps sets and other iterables working with np.isin.
    bad_frames_mask = np.isin(np.arange(calcium.shape[1]), list(bad_frames))

    exp = Experiment(signature,
                     calcium,
                     spikes,
                     exp_params,
                     static_features,
                     filt_dyn_features,
                     reconstruct_spikes=reconstruct_spikes,
                     bad_frames_mask=bad_frames_mask)

    return exp

101 

102 

def load_experiment(data_source,
                    exp_params,
                    force_rebuild=False,
                    force_reload=False,
                    via_pydrive=True,
                    gauth=None,
                    root='DRIADA data',
                    exp_path=None,
                    data_path=None,
                    force_continuous=None,
                    bad_frames=None,
                    static_features=None,
                    reconstruct_spikes='wavelet',
                    save_to_pickle=False,
                    verbose=True):
    """Load an experiment from a pickle, or build it from aligned data.

    If a pickle exists at ``exp_path`` and neither ``force_rebuild`` nor
    ``force_reload`` is set, the pickled experiment is returned. Otherwise
    the aligned ``.npz`` data is located (and downloaded when missing or when
    ``force_reload`` is set) and the experiment is built from it.

    Parameters
    ----------
    data_source : str
        Only ``'IABS'`` is supported for building from aligned data.
    exp_params
        Experiment parameters; together with ``data_source`` they define the
        session name.
    force_rebuild : bool
        Ignore an existing pickle and rebuild from aligned data.
    force_reload : bool
        Ignore an existing pickle and re-download the aligned data.
    via_pydrive, gauth
        Forwarded to ``download_gdrive_data``.
    root : str
        Working folder; created if it does not exist.
    exp_path : str, optional
        Pickle location. Defaults to ``<root>/<expname>/Exp <expname>.pickle``.
    data_path : str, optional
        Aligned data location. Defaults to
        ``<root>/<expname>/Aligned data/<expname> syn data.npz``.
    force_continuous, bad_frames, static_features, reconstruct_spikes
        Forwarded to ``load_exp_from_aligned_data``.
    save_to_pickle : bool
        Save the freshly built experiment to ``exp_path``.
    verbose : bool
        Print progress information.

    Returns
    -------
    tuple
        ``(experiment, load_log)``. ``load_log`` is ``None`` unless data was
        downloaded during this call.

    Raises
    ------
    ValueError
        If ``root`` exists but is not a folder, or if ``data_source`` is not
        supported.
    FileNotFoundError
        If the download of aligned data fails.
    """
    # None stands in for "empty" so that no mutable default is shared
    # between calls.
    force_continuous = [] if force_continuous is None else force_continuous
    bad_frames = [] if bad_frames is None else bad_frames

    # Check before makedirs: with exist_ok=True it raises FileExistsError on
    # a plain file, which would make this ValueError unreachable.
    if os.path.exists(root) and not os.path.isdir(root):
        raise ValueError('Root must be a folder!')
    os.makedirs(root, exist_ok=True)

    expname = None
    if exp_path is None:
        expname = construct_session_name(data_source, exp_params)
        exp_path = os.path.join(root, expname, f'Exp {expname}.pickle')

    if os.path.exists(exp_path) and not force_rebuild and not force_reload:
        exp = load_exp_from_pickle(exp_path, verbose=verbose)
        return exp, None

    if data_source != 'IABS':
        raise ValueError('External data sources are not yet supported')

    # The session name is needed below even when the caller supplied
    # exp_path (data path, download, error message).
    if expname is None:
        expname = construct_session_name(data_source, exp_params)

    if data_path is None:
        data_path = os.path.join(root,
                                 expname,
                                 'Aligned data',
                                 f'{expname} syn data.npz')
    if verbose:
        print(f'Path to data: {data_path}')

    data_exists = os.path.exists(data_path)
    if verbose:
        if data_exists:
            print('Aligned data for experiment construction found successfully')
        else:
            print('Failed to locate aligned data for experiment construction')

    load_log = None
    if force_reload or not data_exists:
        if verbose:
            print('Loading data from cloud storage...')
        data_router, _ = initialize_iabs_router(root=root)
        success, load_log = download_gdrive_data(data_router,
                                                 expname,
                                                 data_pieces=['Aligned data'],
                                                 via_pydrive=via_pydrive,
                                                 tdir=root,
                                                 gauth=gauth)

        if not success:
            print('=========== BEGINNING OF LOADING LOG ============')
            show_output(load_log)
            print('=========== END OF LOADING LOG ============')
            raise FileNotFoundError(f'Cannot download {expname}, see loading log above')

    # Materialise all arrays while the archive is open, then close it so the
    # file handle is not leaked.
    with np.load(data_path) as npz_file:
        aligned_data = dict(npz_file)

    exp = load_exp_from_aligned_data(data_source,
                                     exp_params,
                                     aligned_data,
                                     force_continuous=force_continuous,
                                     static_features=static_features,
                                     verbose=verbose,
                                     bad_frames=bad_frames,
                                     reconstruct_spikes=reconstruct_spikes)

    if save_to_pickle:
        # Make sure the target folder exists so a finished build is not lost
        # to a missing directory.
        pickle_dir = os.path.dirname(exp_path)
        if pickle_dir:
            os.makedirs(pickle_dir, exist_ok=True)
        save_exp_to_pickle(exp, exp_path, verbose=verbose)

    return exp, load_log

184 

185 

def save_exp_to_pickle(exp, path, verbose=True):
    """Serialise ``exp`` with pickle and write it to ``path``.

    Parameters
    ----------
    exp
        Object to pickle. Its ``signature`` attribute is read only for the
        confirmation message printed when ``verbose`` is true.
    path
        Destination file; opened in binary write mode and overwritten.
    verbose : bool
        Print a confirmation line after the file has been written.
    """
    with open(path, "wb") as sink:
        pickle.dump(exp, sink)

    if not verbose:
        return

    print(f'Experiment {exp.signature} saved to {path}\n')

191 

192 

193def load_exp_from_pickle(path, verbose=True): 

194 with open(path, "rb") as f: 

195 exp = pickle.load(f,) 

196 if verbose: 

197 print(f'Experiment {exp.signature} loaded from {path}\n') 

198 return exp