Coverage for src/driada/experiment/wavelet_event_detection.py: 44.23%
156 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import os
2from os.path import join, splitext
3import tqdm
4import matplotlib.pyplot as plt
6# Fix scipy compatibility issue for ssqueezepy
7import scipy.integrate
8if not hasattr(scipy.integrate, 'trapz'):
9 scipy.integrate.trapz = scipy.integrate.trapezoid
11from ssqueezepy import cwt
12from ssqueezepy.wavelets import Wavelet, time_resolution
14from scipy.ndimage import gaussian_filter1d
15from scipy.signal import argrelmax
16from numba import njit
18from .wavelet_ridge import *
# Default parameters for wavelet-based event detection. extract_wvt_events reads
# exactly these keys from its ``wvt_kwargs`` argument.
WVT_EVENT_DETECTION_PARAMS = {
    'fps': 20,  # sampling rate, frames per second (passed to cwt as ``fs``)
    'sigma': 8,  # Gaussian smoothing width applied to the normalized trace, frames
    'beta': 2,  # Generalized Morse Wavelet parameter, FIXED
    'gamma': 3,  # Generalized Morse Wavelet parameter, FIXED
    'eps': 10,  # argrelmax order: a peak must be the strict maximum within +/- eps frames
    'manual_scales': np.logspace(2.5, 5.5, 50, base=2),  # 50 log-spaced CWT scales, 2**2.5 .. 2**5.5

    # ridge filtering params (compared against Ridge attributes in passing_criterion)
    'scale_length_thr': 40,  # min ridge.length; higher = fewer events. Presumably counted in scales, so max = len(manual_scales)
    'max_scale_thr': 7,  # min ridge.max_scale (scale of max ridge intensity; TODO confirm value vs index); higher = fewer events. < 5 = noise, > 20 = huge events
    'max_ampl_thr': 0.05,  # min ridge.max_ampl, presumably the largest real CWT coefficient along the ridge of the [0, 1]-normalized trace; higher = fewer events
    'max_dur_thr': 200,  # max ridge.duration; higher = more events (but possibly spurious ones). TODO confirm units
}
MIN_EVENT_DUR = 0.5  # sec; events_to_ts_array widens shorter events to about this length
MAX_EVENT_DUR = 2.5  # sec; events_to_ts_array truncates longer events to this length
def wvt_viz(x, Wx):
    """Plot a signal above the magnitude of its wavelet transform.

    Parameters
    ----------
    x : array-like
        1-D signal, drawn as a blue line in the top panel.
    Wx : array-like
        Wavelet coefficients; ``abs(Wx)`` is drawn as an image in the bottom
        panel.

    Returns
    -------
    fig, axs
        The created figure and its two axes, so callers can save, show or
        close the figure (previously nothing was returned).
    """
    fig, axs = plt.subplots(2, 1, figsize=(12, 12))
    axs[0].set_xlim(0, len(x))
    axs[0].plot(x, c='b')
    axs[1].imshow(np.abs(Wx), aspect='auto', cmap='turbo')
    return fig, axs
def get_cwt_ridges(sig, wavelet=None, fps=20, scmin=150, scmax=250, all_wvt_times=None, wvt_scales=None,
                   order=10):
    """Trace ridges of CWT local maxima across scales (pure-Python version).

    The signal is transformed with ``cwt``. Scale indices ``scmax-1`` down to
    ``scmin`` are processed in that order. For each one, local maxima over time
    of the real part of the coefficients are found with ``argrelmax``. Each live
    ridge is extended with the strongest maximum whose time lies strictly
    within +/- (time resolution of the previously processed scale) of the
    ridge tip, and a ridge with no such maximum is terminated. Maxima not used
    for an extension start new ridges.

    Parameters
    ----------
    sig : array-like
        1-D signal.
    wavelet : optional
        Wavelet passed to ``cwt`` and ``time_resolution``.
    fps : float
        Sampling rate, passed to ``cwt`` as ``fs``.
    scmin, scmax : int
        Half-open range of scale (row) indices of the CWT to process.
    all_wvt_times : sequence, optional
        Time resolution for each processed scale, in processing order (i.e.
        for scale indices ``scmax-1, ..., scmin``). Computed with
        ``time_resolution`` when omitted.
    wvt_scales : array-like, optional
        Scales for ``cwt``; ``'log-piecewise'`` is used when omitted.
    order : int
        Half-width of the ``argrelmax`` neighbourhood: a sample must be the
        strict maximum within +/- ``order`` samples. Defaults to 10, the value
        that used to be hard-coded.

    Returns
    -------
    list of Ridge
        All ridges found; ``terminate()`` is called on each before returning.
    """
    scales = wvt_scales if wvt_scales is not None else 'log-piecewise'
    W, wvt_scales = cwt(sig, wavelet=wavelet, fs=fps, scales=scales)

    wvtdata = np.real(W)
    # Largest scale index first.
    scale_inds = np.arange(scmin, scmax)[::-1]
    if all_wvt_times is None:
        all_wvt_times = [time_resolution(wavelet, scale=wvt_scales[sc], nondim=False, min_decay=200)
                         for sc in scale_inds]

    # peaks[i, t] is the coefficient value at each local maximum of processed scale i, 0 elsewhere.
    peaks = np.zeros((len(scale_inds), len(sig)))

    all_ridges = []
    for i, si in enumerate(scale_inds):
        wvt_time = all_wvt_times[i]
        scale = wvt_scales[si]
        max_inds = argrelmax(wvtdata[si, :], order=order)[0]
        peaks[i, max_inds] = wvtdata[si, max_inds]

        if not all_ridges:
            # Nothing to extend yet: every maximum starts a ridge.
            all_ridges = [Ridge(mi, peaks[i, mi], scale, wvt_time) for mi in max_inds]
            continue

        # 1. Extend live ridges. The window half-width is the time resolution of the
        #    previous scale (author's note: ~68% of the wavelet energy).
        prev_wvt_time = all_wvt_times[i - 1]
        used = set()
        for ridge in [r for r in all_ridges if not r.terminated]:
            tip = ridge.tip()
            wlb, wrb = tip - prev_wvt_time, tip + prev_wvt_time
            candidates = [mi for mi in max_inds if wlb < mi < wrb]
            if not candidates:
                # A gap terminates the ridge.
                ridge.terminate()
                continue
            # Strongest candidate extends the ridge; the rest may seed new ridges below.
            best = candidates[int(np.argmax(peaks[i, np.array(candidates)]))]
            ridge.extend(best, peaks[i, best], scale, wvt_time)
            used.add(best)

        # 2. Maxima not used for prolongation start new ridges.
        all_ridges.extend(Ridge(mi, peaks[i, mi], scale, wvt_time) for mi in max_inds if mi not in used)

    for r in all_ridges:
        r.terminate()

    return all_ridges
# TODO: add support for numba >0.59.0 or "numba_acceleration" flag
@njit()
def get_cwt_ridges_fast(wvtdata, peaks, wvt_times, wvt_scales):
    """Link per-row local maxima of a CWT into ridges (numba-compiled).

    Numba counterpart of the ridge-tracking loop in ``get_cwt_ridges``. It
    takes precomputed maxima instead of calling ``cwt`` itself, and it walks
    rows in increasing index order, whereas ``get_cwt_ridges`` walks its scale
    indices from high to low.

    Parameters
    ----------
    wvtdata : 2-D array
        CWT coefficients. Only ``wvtdata.shape[0]`` (the number of rows) is
        used here.
    peaks : 2-D array
        Every nonzero entry ``peaks[si, t]`` is treated as a maximum at time
        ``t`` of row ``si``, with that value as its amplitude.
        NOTE(review): a genuine maximum whose value is exactly 0 would be missed.
    wvt_times : 1-D sequence
        Time resolution per row. Row ``si`` matches a ridge tip only with
        maxima strictly within +/- ``wvt_times[si-1]``.
    wvt_scales : 1-D sequence
        Scale value per row, stored on the ridges.

    Returns
    -------
    list of Ridge
        All ridges; ``terminate()`` is called on each before returning.
        NOTE(review): if ``wvtdata.shape[0] == 0`` the loop never binds
        ``all_ridges`` and the final loop fails.
    """
    # determine peak positions for all scales

    start = True
    for si in range(wvtdata.shape[0]):
        wvt_time = wvt_times[si]
        max_inds = np.nonzero(peaks[si,:])[0]

        if start:
            # First row: every maximum starts its own ridge.
            all_ridges = [Ridge(mi, peaks[si, mi], wvt_scales[si], wvt_time) for mi in max_inds]
            start = False
        else:
            # 1. extend old ridges
            prev_wvt_time = wvt_times[si-1]
            # Snapshot of live ridges, taken before this row's new ridges are appended.
            live_ridges = [ridge for ridge in all_ridges if not ridge.terminated]
            maxima_used_for_prolongation = []

            for ridge in live_ridges:
                # 1.1 get ridge tip from previous scale
                last_max_index = ridge.tip()
                # 1.2 compute time window based on 68% of wavelet energy
                wlb, wrb = last_max_index - prev_wvt_time, last_max_index + prev_wvt_time
                # 1.3 get list of candidate maxima of the current scale falling into the window
                candidates = [mi for mi in max_inds if (mi > wlb) and (mi < wrb)]
                # 1.4 extending ridges
                if len(candidates) == 0:
                    # gaps lead to ridge termination
                    ridge.terminate()
                elif len(candidates) == 1:
                    # extend ridge
                    cand = candidates[0]
                    ridge.extend(cand, peaks[si, cand], wvt_scales[si], wvt_time)
                    maxima_used_for_prolongation.append(cand)
                else:
                    # extend ridge with the best (largest-amplitude) maximum, others will later form new ridges
                    best_cand = candidates[np.argmax(peaks[si, np.array(candidates)])]
                    ridge.extend(best_cand, peaks[si, best_cand], wvt_scales[si], wvt_time)
                    maxima_used_for_prolongation.append(best_cand)

            # 2. generate new ridges from maxima that did not extend any ridge
            new_ridges = [Ridge(mi, peaks[si, mi], wvt_scales[si], wvt_time) for mi in max_inds if mi not in maxima_used_for_prolongation]
            # Use += instead of extend() to fix Numba 0.60+ type inference issue
            all_ridges += new_ridges

    for r in all_ridges:
        r.terminate()

    return all_ridges
def passing_criterion(ridge, scale_length_thr=40, max_scale_thr=10, max_ampl_thr=0.05, max_dur_thr=100):
    """Tell whether ``ridge`` is strong enough to count as an event.

    The conditions below are checked in order, and evaluation stops at the
    first one that fails:

    * ``ridge.length >= scale_length_thr``
    * ``ridge.max_scale >= max_scale_thr``
    * ``ridge.max_ampl >= max_ampl_thr``
    * ``ridge.duration <= max_dur_thr``

    The result is the value of the last comparison evaluated. Note that these
    defaults (10, 100) differ from ``WVT_EVENT_DETECTION_PARAMS`` (7, 200).
    The ``events_from_trace`` -> ``get_events_from_ridges`` path always passes
    explicit thresholds.
    """
    return (
        ridge.length >= scale_length_thr
        and ridge.max_scale >= max_scale_thr
        and ridge.max_ampl >= max_ampl_thr
        and ridge.duration <= max_dur_thr
    )
def get_events_from_ridges(all_ridges, scale_length_thr=40, max_scale_thr=10, max_ampl_thr=0.05, max_dur_thr=100):
    """Return start and end indices of the ridges accepted by ``passing_criterion``.

    Parameters
    ----------
    all_ridges : iterable of Ridge
        Candidate ridges.
    scale_length_thr, max_scale_thr, max_ampl_thr, max_dur_thr
        Thresholds forwarded unchanged to ``passing_criterion``.

    Returns
    -------
    st_evinds, end_evinds : list, list
        ``ridge.indices[0]`` and ``ridge.indices[-1]`` of every accepted
        ridge, in input order.
    """
    thresholds = {
        'scale_length_thr': scale_length_thr,
        'max_scale_thr': max_scale_thr,
        'max_ampl_thr': max_ampl_thr,
        'max_dur_thr': max_dur_thr,
    }
    accepted = list(filter(lambda candidate: passing_criterion(candidate, **thresholds), all_ridges))

    starts = [ridge.indices[0] for ridge in accepted]
    ends = [ridge.indices[-1] for ridge in accepted]
    return starts, ends
def events_from_trace(trace, wavelet, manual_scales, rel_wvt_times,
                      fps=20, sigma=8, eps=10,
                      scale_length_thr=40,
                      max_scale_thr=7,
                      max_ampl_thr=0.05,
                      max_dur_thr=200):
    """Detect events in a single trace by tracking ridges of its CWT.

    Steps:

    1. Min-max normalize the trace to [0, 1].
    2. Smooth it with a Gaussian of width ``sigma`` frames.
    3. Take the CWT at ``manual_scales``.
    4. Along time, keep the real coefficient at every strict local maximum
       (``argrelmax`` with ``order=eps``) and zero it elsewhere.
    5. Link the maxima into ridges with ``get_cwt_ridges_fast``.
    6. Keep the ridges that pass the thresholds of ``get_events_from_ridges``.

    Parameters
    ----------
    trace : array-like
        1-D signal of one cell.
    wavelet
        Wavelet passed to ``cwt``.
    manual_scales : 1-D array
        CWT scales; also used as the per-row scale values of the ridges.
    rel_wvt_times : sequence of float
        Time resolution per scale, in the same order as ``manual_scales``.
    fps : float
        Sampling rate, passed to ``cwt`` as ``fs``.
    sigma : float
        Gaussian smoothing width, frames.
    eps : int
        ``argrelmax`` order (peak neighbourhood half-width, frames).
    scale_length_thr, max_scale_thr, max_ampl_thr, max_dur_thr
        Ridge thresholds forwarded to ``get_events_from_ridges``.

    Returns
    -------
    all_ridges : list of Ridge
        Every ridge found, before threshold filtering.
    st_evinds, end_evinds : list, list
        Start and end indices of the ridges that pass the thresholds.
    """
    trace = np.asarray(trace)
    tmin = trace.min()
    span = trace.max() - tmin
    if span == 0:
        # Flat trace: min-max scaling would compute 0/0 and fill the trace with NaN
        # (with RuntimeWarnings). An all-zero trace has no strict local maxima,
        # so it yields no events without the NaNs.
        trace = np.zeros(trace.shape, dtype=float)
    else:
        trace = (trace - tmin) / span
    sig = gaussian_filter1d(trace, sigma=sigma)

    W, _ = cwt(sig, wavelet=wavelet, fs=fps, scales=manual_scales)
    rev_wvtdata = np.real(W)

    # Coefficient value at each strict local maximum along time (per scale), 0 elsewhere.
    all_max_inds = argrelmax(rev_wvtdata, axis=1, order=eps)
    peaks = np.zeros(rev_wvtdata.shape)
    peaks[all_max_inds] = rev_wvtdata[all_max_inds]

    # Pass an array rather than a Python list: lists given to njit functions are
    # numba "reflected lists", which are deprecated.
    all_ridges = get_cwt_ridges_fast(rev_wvtdata, peaks, np.asarray(rel_wvt_times), manual_scales)

    st_evinds, end_evinds = get_events_from_ridges(all_ridges,
                                                   scale_length_thr=scale_length_thr,
                                                   max_scale_thr=max_scale_thr,
                                                   max_ampl_thr=max_ampl_thr,
                                                   max_dur_thr=max_dur_thr)

    return all_ridges, st_evinds, end_evinds
def extract_wvt_events(traces, wvt_kwargs=None):
    """Detect wavelet events in every trace.

    Parameters
    ----------
    traces : iterable of array-like
        One 1-D trace per cell.
    wvt_kwargs : mapping, optional
        Overrides for keys of ``WVT_EVENT_DETECTION_PARAMS``. Missing keys fall
        back to those defaults, unknown keys are ignored, and ``None`` means
        all defaults.

    Returns
    -------
    st_ev_inds, end_ev_inds : list of list
        For each trace, the start and end indices of its detected events.
    all_ridges : list of list
        For each trace, all ridges found before threshold filtering.
    """
    # Single source of truth for the defaults; user overrides are layered on top.
    params = dict(WVT_EVENT_DETECTION_PARAMS)
    if wvt_kwargs is not None:
        params.update(wvt_kwargs)
    manual_scales = params['manual_scales']

    # NOTE(review): N=8196 looks like a typo for 8192 (2**13). It is kept because
    # it determines the sampled wavelet and thus time_resolution; confirm.
    wavelet = Wavelet(('gmw', {'gamma': params['gamma'], 'beta': params['beta'], 'centered_scale': True}), N=8196)

    # Time resolution per scale; independent of the trace, so computed once.
    rel_wvt_times = [time_resolution(wavelet, scale=sc, nondim=False, min_decay=200) for sc in manual_scales]

    st_ev_inds, end_ev_inds, all_ridges = [], [], []
    # tqdm takes its total from len(traces) when available.
    for trace in tqdm.tqdm(traces):
        ridges, st_ev, end_ev = events_from_trace(trace,
                                                  wavelet,
                                                  manual_scales,
                                                  rel_wvt_times,
                                                  fps=params['fps'],
                                                  sigma=params['sigma'],
                                                  eps=params['eps'],
                                                  scale_length_thr=params['scale_length_thr'],
                                                  max_scale_thr=params['max_scale_thr'],
                                                  max_ampl_thr=params['max_ampl_thr'],
                                                  max_dur_thr=params['max_dur_thr'])

        st_ev_inds.append(st_ev)
        end_ev_inds.append(end_ev)
        all_ridges.append(ridges)

    return st_ev_inds, end_ev_inds, all_ridges
@njit
def events_to_ts_array_numba(length, ncells, st_ev_inds_flat, end_ev_inds_flat, event_counts, fps, min_event_dur, max_event_dur):
    """Rasterize flattened per-cell event intervals into a 0/1 array (numba-compiled).

    Kernel behind ``events_to_ts_array``.

    Parameters
    ----------
    length : int
        Number of frames (columns) of the output.
    ncells : int
        Number of cells (rows).
    st_ev_inds_flat, end_ev_inds_flat : 1-D arrays
        Event start/end frames of all cells, concatenated cell after cell.
        Each value is cast to int, and start/end may come in either order.
    event_counts : 1-D int array
        Number of events of each cell, used to walk the flat arrays. The
        kernel does not check that it matches the flat arrays' lengths.
    fps : float
        Frames per second, used to convert the duration limits to frames.
    min_event_dur, max_event_dur : float
        Duration limits, seconds.

    Returns
    -------
    2-D float array of shape (ncells, length)
        1 on frames covered by an event, 0 elsewhere. With
        ``d = end - start`` (end exclusive), an event is drawn as:

        * ``[start, end)`` if ``mindur <= d <= maxdur``;
        * ``[start, start + maxdur)`` if ``d > maxdur``;
        * otherwise a window of ``2 * (mindur // 2)`` frames centred on the
          event midpoint, clipped at frame 0.
    """
    spikes = np.zeros((ncells, length))

    mindur = int(min_event_dur * fps)
    maxdur = int(max_event_dur * fps)

    event_idx = 0
    for i in range(ncells):
        for _ in range(event_counts[i]):
            start = int(st_ev_inds_flat[event_idx])
            end = int(end_ev_inds_flat[event_idx])
            start_, end_ = min(start, end), max(start, end)
            dur = end_ - start_
            if mindur <= dur <= maxdur:
                spikes[i, start_:end_] = 1
            elif dur > maxdur:
                spikes[i, start_:start_ + maxdur] = 1
            else:
                middle = (start_ + end_) // 2
                # Clamp at 0: a negative slice start would wrap to the end of the
                # row and produce an empty slice, silently dropping events near frame 0.
                lo = max(middle - mindur // 2, 0)
                spikes[i, lo:middle + mindur // 2] = 1
            event_idx += 1

    return spikes
def events_to_ts_array(length, st_ev_inds, end_ev_inds, fps):
    """Convert per-cell event indices to a binary spike-train array.

    Parameters
    ----------
    length : int
        Number of frames (columns) of the output.
    st_ev_inds, end_ev_inds : sequence of sequences
        For each cell, the start and end frames of its events, paired
        element-wise (e.g. the first two outputs of ``extract_wvt_events``).
    fps : float
        Frame rate, used to convert ``MIN_EVENT_DUR`` / ``MAX_EVENT_DUR``
        (seconds) to frames.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(ncells, length)`` with 1 on frames covered by
        an event (see ``events_to_ts_array_numba`` for the duration rules).

    Raises
    ------
    ValueError
        If the start and end lists differ in their number of cells or, for
        some cell, in their number of events.
    """
    ncells = len(end_ev_inds)
    if len(st_ev_inds) != ncells:
        raise ValueError(f'st_ev_inds has {len(st_ev_inds)} cells but end_ev_inds has {ncells}')
    # The numba kernel walks both flat arrays with the same counts, so each
    # cell's lists must pair up exactly.
    for i, (st, end) in enumerate(zip(st_ev_inds, end_ev_inds)):
        if len(st) != len(end):
            raise ValueError(f'cell {i}: {len(st)} event starts but {len(end)} event ends')

    if ncells == 0:
        # np.concatenate cannot take an empty list of arrays.
        return np.zeros((0, length))

    # Flatten the jagged arrays for numba
    event_counts = np.array([len(st) for st in st_ev_inds], dtype=np.int64)
    st_ev_inds_flat = np.concatenate([st_ev_inds[i] for i in range(ncells)])
    end_ev_inds_flat = np.concatenate([end_ev_inds[i] for i in range(ncells)])

    return events_to_ts_array_numba(length, ncells, st_ev_inds_flat, end_ev_inds_flat,
                                    event_counts, fps, MIN_EVENT_DUR, MAX_EVENT_DUR)