Coverage for src/driada/experiment/wavelet_event_detection.py: 44.23%

156 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import os 

2from os.path import join, splitext 

3import tqdm 

4import matplotlib.pyplot as plt 

5 

6# Fix scipy compatibility issue for ssqueezepy 

7import scipy.integrate 

8if not hasattr(scipy.integrate, 'trapz'): 

9 scipy.integrate.trapz = scipy.integrate.trapezoid 

10 

11from ssqueezepy import cwt 

12from ssqueezepy.wavelets import Wavelet, time_resolution 

13 

14from scipy.ndimage import gaussian_filter1d 

15from scipy.signal import argrelmax 

16from numba import njit 

17 

18from .wavelet_ridge import * 

19 

# Default parameters of the wavelet-based event detection pipeline.
# The values here mirror the fallbacks used by extract_wvt_events.
WVT_EVENT_DETECTION_PARAMS = {
    'fps': 20,  # sampling rate of the traces, frames per second (passed to cwt as fs)
    'sigma': 8,  # sigma of the Gaussian smoothing applied to a trace before the CWT, frames
    'beta': 2,  # Generalized Morse Wavelet parameter, FIXED
    'gamma': 3,  # Generalized Morse Wavelet parameter, FIXED
    'eps': 10,  # spacing between consecutive events, frames (used as `order` of argrelmax)
    'manual_scales': np.logspace(2.5,5.5,50, base=2),  # 50 log-spaced CWT scales from 2**2.5 to 2**5.5

    # ridge filtering params (see passing_criterion)
    'scale_length_thr': 40,  # min ridge.length (number of scales where the ridge is present), higher = fewer events. max=len(manual_scales)
    'max_scale_thr': 7,  # min ridge.max_scale (index of a scale with max ridge intensity), higher = fewer events. < 5 = noise, > 20 = huge events
    # NOTE(review): the original hint "< 5 = noise, > 20 = huge events" on the next key
    # looks copy-pasted from max_scale_thr and does not fit a 0.05 threshold -- confirm.
    'max_ampl_thr': 0.05,  # min ridge.max_ampl (max ridge intensity), higher = fewer events
    'max_dur_thr': 200,  # max allowed ridge.duration, higher = more events (but probably strange ones)
}
MIN_EVENT_DUR = 0.5  # sec; shorter events are widened to this duration in events_to_ts_array
MAX_EVENT_DUR = 2.5  # sec; longer events are truncated to this duration in events_to_ts_array

36 

37 

def wvt_viz(x, Wx):
    """Draw a signal and the magnitude of its wavelet transform.

    A 12x12 figure with two vertically stacked axes is created. The upper
    axes show ``x`` as a blue line with the x-axis limited to
    ``[0, len(x)]``; the lower axes show ``np.abs(Wx)`` as an image using
    the 'turbo' colormap and automatic aspect ratio.

    The figure is neither shown, saved nor returned by this function.
    """
    _, (signal_ax, transform_ax) = plt.subplots(nrows=2, ncols=1, figsize=(12, 12))

    # Upper panel: the raw signal.
    signal_ax.set_xlim(0, len(x))
    signal_ax.plot(x, c='b')

    # Lower panel: absolute value of the transform coefficients.
    magnitude = np.abs(Wx)
    transform_ax.imshow(magnitude, aspect='auto', cmap='turbo')

43 

44 

def get_cwt_ridges(sig, wavelet=None, fps=20, scmin=150, scmax=250, all_wvt_times=None, wvt_scales=None):
    """Compute the CWT of a signal and trace ridges of local maxima across scales.

    Scales with indices ``scmax - 1`` down to ``scmin`` are visited in that
    order. At each scale the local maxima (``argrelmax`` with ``order=10``) of
    the real part of the transform are located. A live ridge is extended by
    the strongest maximum lying strictly inside a window of half-width
    ``all_wvt_times[i - 1]`` around its tip, and is terminated if there is
    none. Maxima not used to extend a ridge start new ridges.

    Parameters
    ----------
    sig : 1D array-like
        Signal to analyse.
    wavelet :
        Wavelet passed to ``cwt`` and, when ``all_wvt_times`` is None, to
        ``time_resolution``.
    fps : number
        Sampling rate passed to ``cwt`` as ``fs``.
    scmin, scmax : int
        Half-open range of scale indices (rows of the transform) to use.
        NOTE(review): the defaults presume the transform has at least 250
        scales -- confirm against callers that pass a short ``wvt_scales``.
    all_wvt_times : sequence, optional
        Window half-width for every visited scale, in traversal order
        (first entry corresponds to index ``scmax - 1``). Computed with
        ``time_resolution`` when omitted.
    wvt_scales : array-like, optional
        Scales passed to ``cwt``; ``'log-piecewise'`` is used when omitted.

    Returns
    -------
    list
        All ``Ridge`` objects created; ``terminate()`` has been called on
        every one of them.
    """
    scales = wvt_scales if wvt_scales is not None else 'log-piecewise'
    W, wvt_scales = cwt(sig, wavelet=wavelet, fs=fps, scales=scales)

    wvtdata = np.real(W)
    scale_inds = np.arange(scmin, scmax)[::-1]
    if all_wvt_times is None:
        all_wvt_times = [
            time_resolution(wavelet, scale=wvt_scales[sc], nondim=False, min_decay=200)
            for sc in scale_inds
        ]

    # peaks[i, t] holds the transform value at local maxima of the i-th visited scale
    peaks = np.zeros((len(scale_inds), len(sig)))

    all_ridges = []
    for i, si in enumerate(scale_inds):
        wvt_time = all_wvt_times[i]
        scale = wvt_scales[si]
        max_inds = argrelmax(wvtdata[si, :], order=10)[0]
        peaks[i, max_inds] = wvtdata[si, max_inds]

        if not all_ridges:
            # nothing to extend yet: every maximum seeds a ridge
            all_ridges = [Ridge(mi, peaks[i, mi], scale, wvt_time) for mi in max_inds]
            continue

        # 1. extend old ridges
        prev_wvt_time = all_wvt_times[i - 1]
        live_ridges = [ridge for ridge in all_ridges if not ridge.terminated]
        # a set gives O(1) membership tests when new ridges are generated below
        maxima_used_for_prolongation = set()

        for ridge in live_ridges:
            # 1.1 get ridge tip from previous scale
            last_max_index = ridge.tip()
            # 1.2 time window around the tip (original note: based on 68% of wavelet energy)
            wlb, wrb = last_max_index - prev_wvt_time, last_max_index + prev_wvt_time
            # 1.3 candidate maxima of the current scale strictly inside the window
            candidates = max_inds[(max_inds > wlb) & (max_inds < wrb)]
            # 1.4 extending ridges
            if candidates.size == 0:
                # gaps lead to ridge termination
                ridge.terminate()
                continue

            # extend with the strongest maximum (the first one on ties);
            # the remaining candidates will later form new ridges
            best_cand = candidates[np.argmax(peaks[i, candidates])]
            ridge.extend(best_cand, peaks[i, best_cand], scale, wvt_time)
            maxima_used_for_prolongation.add(best_cand)

        # 2. generate new ridges
        all_ridges.extend(
            Ridge(mi, peaks[i, mi], scale, wvt_time)
            for mi in max_inds
            if mi not in maxima_used_for_prolongation
        )

    for r in all_ridges:
        r.terminate()

    return all_ridges

110 

111 

# TODO: add support for numba >0.59.0 or "numba_acceleration" flag
@njit()
def get_cwt_ridges_fast(wvtdata, peaks, wvt_times, wvt_scales):
    """Trace ridges across the rows of a precomputed peak map (numba-compiled).

    Rows are visited in order ``0 .. wvtdata.shape[0] - 1``. For each row the
    positions of the non-zero entries of ``peaks`` are treated as maxima.
    A live ridge is extended by the strongest maximum lying strictly inside
    a window of half-width ``wvt_times[si - 1]`` around its tip, or terminated
    if the window contains no maximum. Maxima that did not extend any ridge
    start new ridges.

    Parameters
    ----------
    wvtdata : 2D array
        Only its first dimension (number of scales) is read here.
    peaks : 2D array
        Same row layout as ``wvtdata``; non-zero entries mark maxima and hold
        their values. A maximum whose value is exactly 0 is not seen because
        positions are recovered with ``np.nonzero``.
    wvt_times : sequence
        Window half-width for every row.
    wvt_scales : sequence
        Scale value for every row, stored in the ridges.

    Returns
    -------
    list
        All ``Ridge`` objects created, each with ``terminate()`` called.

    Notes
    -----
    ``Ridge`` is not defined in this part of the file; presumably it is a
    numba-compatible class coming from ``wavelet_ridge`` -- TODO confirm.
    ``all_ridges`` is only assigned inside the loop, so an input with zero
    rows leaves it undefined.
    """
    # determine peak positions for all scales

    # `start` marks the first row, where there are no ridges to extend yet
    start = True
    for si in range(wvtdata.shape[0]):
        wvt_time = wvt_times[si]
        # column indices of the maxima of this row
        max_inds = np.nonzero(peaks[si,:])[0]

        if start:
            # every maximum of the first row seeds a ridge
            all_ridges = [Ridge(mi, peaks[si, mi], wvt_scales[si], wvt_time) for mi in max_inds]
            start = False
        else:
            # 1. extend old ridges
            prev_wvt_time = wvt_times[si-1]
            live_ridges = [ridge for ridge in all_ridges if not ridge.terminated]
            maxima_used_for_prolongation = []

            for ridge in live_ridges:
                # 1.1 get ridge tip from previous scale
                last_max_index = ridge.tip()
                # 1.2 compute time window based on 68% of wavelet energy
                wlb, wrb = last_max_index - prev_wvt_time, last_max_index + prev_wvt_time
                # 1.3 get list of candidate maxima of the current scale falling into the window
                candidates = [mi for mi in max_inds if (mi > wlb) and (mi < wrb)]
                # 1.4 extending ridges
                if len(candidates) == 0:
                    # gaps lead to ridge termination
                    ridge.terminate()
                elif len(candidates) == 1:
                    # extend ridge
                    cand = candidates[0]
                    ridge.extend(cand, peaks[si, cand], wvt_scales[si], wvt_time)
                    maxima_used_for_prolongation.append(cand)
                else:
                    # extend ridge with the best maximum, others will later form new ridges
                    best_cand = candidates[np.argmax(peaks[si, np.array(candidates)])]
                    ridge.extend(best_cand, peaks[si, best_cand], wvt_scales[si], wvt_time)
                    maxima_used_for_prolongation.append(best_cand)

            # 2. generate new ridges
            new_ridges = [Ridge(mi, peaks[si, mi], wvt_scales[si], wvt_time) for mi in max_inds if mi not in maxima_used_for_prolongation]
            # Use += instead of extend() to fix Numba 0.60+ type inference issue
            all_ridges += new_ridges

    # mark every ridge as finished, including those still live after the last row
    for r in all_ridges:
        r.terminate()

    return all_ridges

165 

166 

def passing_criterion(ridge, scale_length_thr=40, max_scale_thr=10, max_ampl_thr=0.05, max_dur_thr=100):
    """Tell whether a ridge satisfies all four event thresholds.

    A ridge passes when its ``length``, ``max_scale`` and ``max_ampl`` are
    each at least the corresponding threshold and its ``duration`` does not
    exceed ``max_dur_thr``.
    """
    long_enough = ridge.length >= scale_length_thr
    scale_ok = ridge.max_scale >= max_scale_thr
    strong_enough = ridge.max_ampl >= max_ampl_thr
    short_enough = ridge.duration <= max_dur_thr
    return long_enough and scale_ok and strong_enough and short_enough

170 

171 

def get_events_from_ridges(all_ridges, scale_length_thr=40, max_scale_thr=10, max_ampl_thr=0.05, max_dur_thr=100):
    """Select ridges that pass the thresholds and return their end points.

    Every ridge is checked with ``passing_criterion`` using the given
    thresholds. For each accepted ridge the first and the last element of
    ``ridge.indices`` are collected.

    Returns
    -------
    (list, list)
        ``st_evinds`` (first indices) and ``end_evinds`` (last indices) of the
        accepted ridges, in the order the ridges appear in ``all_ridges``.
    """
    thresholds = dict(
        scale_length_thr=scale_length_thr,
        max_scale_thr=max_scale_thr,
        max_ampl_thr=max_ampl_thr,
        max_dur_thr=max_dur_thr,
    )

    st_evinds = []
    end_evinds = []
    for ridge in all_ridges:
        if not passing_criterion(ridge, **thresholds):
            continue
        st_evinds.append(ridge.indices[0])
        end_evinds.append(ridge.indices[-1])

    return st_evinds, end_evinds

182 

183 

def events_from_trace(trace, wavelet, manual_scales, rel_wvt_times,
                      fps=20, sigma=8, eps=10,
                      scale_length_thr=40,
                      max_scale_thr=7,
                      max_ampl_thr=0.05,
                      max_dur_thr=200):
    """Detect events in a single trace via CWT ridge analysis.

    The trace is min-max normalised, smoothed with a Gaussian filter, and
    transformed with ``cwt`` on ``manual_scales``. Local maxima of the real
    part of the transform along time (``argrelmax`` with ``order=eps``) are
    linked into ridges by ``get_cwt_ridges_fast``, and the ridges are filtered
    by ``get_events_from_ridges``.

    Parameters
    ----------
    trace : 1D array-like
        Signal of one cell.
    wavelet :
        Wavelet passed to ``cwt``.
    manual_scales : array-like
        Scales passed to ``cwt`` and stored in the ridges.
    rel_wvt_times : sequence
        Window half-width for every scale, passed to ``get_cwt_ridges_fast``.
    fps : number
        Sampling rate passed to ``cwt`` as ``fs``.
    sigma : number
        Sigma of the Gaussian smoothing applied before the transform.
    eps : int
        ``order`` of ``argrelmax``: a maximum must dominate ``eps`` samples
        on each side.
    scale_length_thr, max_scale_thr, max_ampl_thr, max_dur_thr :
        Ridge filtering thresholds forwarded to ``get_events_from_ridges``.

    Returns
    -------
    (all_ridges, st_evinds, end_evinds)
        Every ridge found, and the start / end indices of the ridges that
        passed the filter.
    """
    trace = np.asarray(trace)
    # array reductions instead of the Python builtins, each computed once
    lowest, highest = trace.min(), trace.max()
    span = highest - lowest
    if span == 0:
        # A flat trace cannot be rescaled (0/0 would fill the pipeline with
        # NaN); an all-zero signal has no maxima and therefore yields no events.
        trace = np.zeros(trace.shape, dtype=float)
    else:
        trace = (trace - lowest) / span
    sig = gaussian_filter1d(trace, sigma=sigma)

    W, wvt_scales = cwt(sig, wavelet=wavelet, fs=fps, scales=manual_scales)
    wvtdata = np.real(W)

    # keep the transform value at local maxima along time, zero elsewhere
    all_max_inds = argrelmax(wvtdata, axis=1, order=eps)
    peaks = np.zeros(wvtdata.shape)
    peaks[all_max_inds] = wvtdata[all_max_inds]

    all_ridges = get_cwt_ridges_fast(wvtdata, peaks, rel_wvt_times, manual_scales)

    st_evinds, end_evinds = get_events_from_ridges(all_ridges,
                                                   scale_length_thr=scale_length_thr,
                                                   max_scale_thr=max_scale_thr,
                                                   max_ampl_thr=max_ampl_thr,
                                                   max_dur_thr=max_dur_thr)

    return all_ridges, st_evinds, end_evinds

210 

211 

def extract_wvt_events(traces, wvt_kwargs=None):
    """Run wavelet event detection on every trace.

    Parameters
    ----------
    traces : sized iterable of 1D arrays
        One trace per cell; ``len(traces)`` is used for the progress bar.
    wvt_kwargs : dict, optional
        Overrides for the entries of ``WVT_EVENT_DETECTION_PARAMS``
        (``fps``, ``beta``, ``gamma``, ``sigma``, ``eps``, ``manual_scales``,
        ``scale_length_thr``, ``max_scale_thr``, ``max_ampl_thr``,
        ``max_dur_thr``). Missing keys fall back to the module defaults;
        unknown keys are ignored. ``None`` means "use all defaults".

    Returns
    -------
    (st_ev_inds, end_ev_inds, all_ridges)
        Three lists with one entry per trace: event start indices, event end
        indices and the ridges found, as returned by ``events_from_trace``.
    """
    # Single source of truth for the defaults: the module-level parameter dict.
    params = {**WVT_EVENT_DETECTION_PARAMS, **(wvt_kwargs or {})}
    manual_scales = params['manual_scales']

    # NOTE(review): N=8196 is not a power of two (8192) -- possibly a typo;
    # left unchanged because it may affect the computed time resolutions.
    wavelet = Wavelet(('gmw', {'gamma': params['gamma'], 'beta': params['beta'], 'centered_scale': True}), N=8196)

    # The time resolution per scale depends only on the wavelet, so it is
    # computed once and shared by all traces.
    rel_wvt_times = [time_resolution(wavelet, scale=sc, nondim=False, min_decay=200) for sc in manual_scales]

    st_ev_inds = []
    end_ev_inds = []
    all_ridges = []
    for trace in tqdm.tqdm(traces, total=len(traces)):
        ridges, st_ev, end_ev = events_from_trace(trace,
                                                  wavelet,
                                                  manual_scales,
                                                  rel_wvt_times,
                                                  fps=params['fps'],
                                                  sigma=params['sigma'],
                                                  eps=params['eps'],
                                                  scale_length_thr=params['scale_length_thr'],
                                                  max_scale_thr=params['max_scale_thr'],
                                                  max_ampl_thr=params['max_ampl_thr'],
                                                  max_dur_thr=params['max_dur_thr'])

        st_ev_inds.append(st_ev)
        end_ev_inds.append(end_ev)
        all_ridges.append(ridges)

    return st_ev_inds, end_ev_inds, all_ridges

250 

251 

@njit
def events_to_ts_array_numba(length, ncells, st_ev_inds_flat, end_ev_inds_flat, event_counts, fps, min_event_dur, max_event_dur):
    """Numba-optimized version of events_to_ts_array.

    Parameters
    ----------
    length : int
        Number of frames in every output row.
    ncells : int
        Number of output rows.
    st_ev_inds_flat, end_ev_inds_flat : 1D arrays
        Start / end frame of every event, all cells concatenated in cell
        order. The two values of an event may come in either order.
    event_counts : 1D integer array
        Number of events of every cell; used to walk the flat arrays.
    fps : number
        Frames per second, converts the duration limits to frames.
    min_event_dur, max_event_dur : number
        Duration limits in seconds.

    Returns
    -------
    2D float array of shape (ncells, length)
        1 inside events, 0 elsewhere. Events within the duration limits are
        marked on ``[start, end)``; longer events are truncated to the
        maximal duration counted from their start; shorter events are
        replaced by a window of ``2 * (mindur // 2)`` frames centred on their
        middle, cut at the borders of the recording.
    """
    spikes = np.zeros((ncells, length))

    mindur = int(min_event_dur * fps)
    maxdur = int(max_event_dur * fps)
    half_mindur = mindur // 2

    event_idx = 0
    for i in range(ncells):
        for _ in range(event_counts[i]):
            start = int(st_ev_inds_flat[event_idx])
            end = int(end_ev_inds_flat[event_idx])
            start_, end_ = min(start, end), max(start, end)
            dur = end_ - start_
            if mindur <= dur <= maxdur:
                spikes[i, start_: end_] = 1
            elif dur > maxdur:
                spikes[i, start_: start_ + maxdur] = 1
            else:
                middle = (start_ + end_) // 2
                # Clamp at 0: a negative slice start would be counted from the
                # end of the row and the event near the beginning of the
                # recording would be silently lost.
                window_start = max(middle - half_mindur, 0)
                spikes[i, window_start: middle + half_mindur] = 1
            event_idx += 1

    return spikes

277 

278 

def events_to_ts_array(length, st_ev_inds, end_ev_inds, fps):
    """Convert event indices to time series array with spike trains.

    Parameters
    ----------
    length : int
        Number of frames in every output row.
    st_ev_inds, end_ev_inds : sequence of sequences
        Per-cell event start and end indices. The number of cells is taken
        from ``end_ev_inds``; the i-th entries of both arguments are expected
        to have the same number of events.
    fps : number
        Frames per second.

    Returns
    -------
    2D array of shape (ncells, length)
        Result of ``events_to_ts_array_numba`` with the module-level
        ``MIN_EVENT_DUR`` / ``MAX_EVENT_DUR`` limits.
    """
    ncells = len(end_ev_inds)

    # Flatten the jagged arrays for numba. Fixed integer dtypes keep the
    # compiled signature stable: empty lists would otherwise become float64
    # arrays, and an empty float ``event_counts`` cannot drive ``range()``.
    event_counts = np.array([len(st_ev_inds[i]) for i in range(ncells)], dtype=np.int64)
    st_parts = [np.asarray(st_ev_inds[i], dtype=np.int64) for i in range(ncells)]
    end_parts = [np.asarray(end_ev_inds[i], dtype=np.int64) for i in range(ncells)]

    # np.concatenate raises on an empty sequence, so zero cells is handled explicitly
    no_events = np.zeros(0, dtype=np.int64)
    st_ev_inds_flat = np.concatenate(st_parts) if st_parts else no_events
    end_ev_inds_flat = np.concatenate(end_parts) if end_parts else no_events

    # Call numba function
    return events_to_ts_array_numba(length, ncells, st_ev_inds_flat, end_ev_inds_flat,
                                    event_counts, fps, MIN_EVENT_DUR, MAX_EVENT_DUR)