Coverage for src/driada/experiment/spike_reconstruction.py: 100.00%

48 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2Spike reconstruction module for DRIADA. 

3 

4This module provides functions for reconstructing spike trains from calcium 

5imaging data using various methods. 

6""" 

7 

from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks

from ..information.info_base import TimeSeries, MultiTimeSeries
from .wavelet_event_detection import (
    WVT_EVENT_DETECTION_PARAMS,
    extract_wvt_events,
    events_to_ts_array,
    ridges_to_containers,
)

20 

21 

def reconstruct_spikes(
    calcium: MultiTimeSeries,
    method: Union[
        str, Callable[..., Tuple[MultiTimeSeries, Dict[str, Any]]]
    ] = 'wavelet',
    fps: float = 20.0,
    params: Optional[Dict[str, Any]] = None
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
    """
    Reconstruct spike trains from calcium signals.

    Parameters
    ----------
    calcium : MultiTimeSeries
        Calcium imaging data with each component being a neuron.
    method : str or callable
        Reconstruction method: 'wavelet', 'threshold', or a callable.
        A callable is invoked as ``method(calcium, fps, params)`` and its
        return value is returned unchanged, so it should produce a
        ``(spikes, metadata)`` tuple like the built-in methods.
    fps : float
        Sampling rate in frames per second.
    params : dict, optional
        Method-specific parameters. ``None`` (or an empty dict) is passed
        on to the chosen method as an empty dict.

    Returns
    -------
    spikes : MultiTimeSeries
        Reconstructed spike trains (discrete).
    metadata : dict
        Reconstruction metadata produced by the chosen method.

    Raises
    ------
    ValueError
        If `method` is neither callable nor one of 'wavelet' / 'threshold'.
    """
    params = params or {}

    # Callables are checked first so a custom method always takes priority.
    if callable(method):
        return method(calcium, fps, params)

    if method == 'wavelet':
        return wavelet_reconstruction(calcium, fps, params)

    if method == 'threshold':
        return threshold_reconstruction(calcium, fps, params)

    raise ValueError(
        f"Unknown method '{method}'. Use 'wavelet', 'threshold', "
        "or provide a callable."
    )

66 

67 

def wavelet_reconstruction(
    calcium: MultiTimeSeries,
    fps: float,
    params: Dict[str, Any]
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
    """
    Wavelet-based spike reconstruction.

    Detection parameters start from a copy of
    ``WVT_EVENT_DETECTION_PARAMS``, have ``'fps'`` set to `fps`, and are
    then overridden by `params`. A ``'fps'`` key in `params` therefore
    replaces the `fps` argument. That single effective sampling rate is
    used both for event extraction and for converting events to an array.

    Parameters
    ----------
    calcium : MultiTimeSeries
        Calcium signals. ``calcium.data`` is read as a 2-D array whose second
        axis length is passed to ``events_to_ts_array`` as the series length.
    fps : float
        Sampling rate in frames per second.
    params : dict
        Wavelet parameters that override the module defaults. ``None`` is
        treated as an empty dict.

    Returns
    -------
    spikes : MultiTimeSeries
        One discrete TimeSeries per row of the array built by
        ``events_to_ts_array``.
    metadata : dict
        Keys: 'method' ('wavelet'), 'parameters' (the effective wavelet
        kwargs), 'start_events', 'end_events' (as returned by
        ``extract_wvt_events``) and 'ridges' (each element of the extracted
        ridges converted with ``ridges_to_containers``).
    """
    calcium_data = np.asarray(calcium.data)

    # Defaults <- explicit fps argument <- user overrides.
    wvt_kwargs = WVT_EVENT_DETECTION_PARAMS.copy()
    wvt_kwargs['fps'] = fps
    wvt_kwargs.update(params or {})
    # Use one sampling rate everywhere so that a 'fps' override in `params`
    # affects detection, event->array conversion and reported metadata alike.
    effective_fps = wvt_kwargs['fps']

    st_ev_inds, end_ev_inds, all_ridges = extract_wvt_events(
        calcium_data, wvt_kwargs
    )

    spikes_data = events_to_ts_array(
        calcium_data.shape[1], st_ev_inds, end_ev_inds, effective_fps
    )

    spike_ts_list = [
        TimeSeries(spikes_data[i, :], discrete=True)
        for i in range(spikes_data.shape[0])
    ]
    spikes = MultiTimeSeries(spike_ts_list)

    metadata = {
        'method': 'wavelet',
        'parameters': wvt_kwargs,
        'start_events': st_ev_inds,
        'end_events': end_ev_inds,
        'ridges': [ridges_to_containers(ridges) for ridges in all_ridges]
    }

    return spikes, metadata

127 

128 

def threshold_reconstruction(
    calcium: MultiTimeSeries,
    fps: float,
    params: Dict[str, Any]
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
    """
    Simple threshold-based spike reconstruction.

    For each neuron the trace is Gaussian-smoothed and differentiated
    (first difference, zero-padded at the start so it keeps the trace
    length). Local maxima of that derivative are marked as spikes when they
    exceed ``mean(diff) + threshold_std * std(diff)``, with at least
    `min_spike_interval` seconds between consecutive spikes.

    Parameters
    ----------
    calcium : MultiTimeSeries
        Calcium signals. ``calcium.data`` must be a 2-D array of shape
        (n_neurons, n_frames).
    fps : float
        Sampling rate in frames per second.
    params : dict
        Parameters including (``None`` is treated as an empty dict):
        - threshold_std : float, number of STDs of the derivative above its
          mean required for detection (default: 2.5)
        - smooth_sigma : float, gaussian smoothing sigma in frames (default: 2)
        - min_spike_interval : float, minimum interval between spikes in
          seconds (default: 0.1). Converted to whole frames by truncation
          and never less than 1 frame.

    Returns
    -------
    spikes : MultiTimeSeries
        Binary spike trains (1 at detected spike frames, 0 elsewhere), one
        discrete TimeSeries per neuron.
    metadata : dict
        Keys: 'method' ('threshold'), 'parameters' (the values used, plus
        'fps') and 'spike_times' (per-neuron arrays of spike frame indices).
    """
    params = params or {}
    threshold_std = params.get('threshold_std', 2.5)
    smooth_sigma = params.get('smooth_sigma', 2)
    min_spike_interval = params.get('min_spike_interval', 0.1)
    # scipy's find_peaks rejects distance < 1; a short interval or a low frame
    # rate (e.g. 0.1 s at 5 fps) would otherwise truncate to 0 and raise.
    min_spike_frames = max(1, int(min_spike_interval * fps))

    calcium_data = np.asarray(calcium.data)
    n_neurons, _ = calcium_data.shape
    spikes_data = np.zeros_like(calcium_data)

    # Integer traces would be smoothed in integer precision, and np.diff on
    # unsigned ints wraps around on decreases (spurious huge "rises").
    if np.issubdtype(calcium_data.dtype, np.floating):
        signal = calcium_data
    else:
        signal = calcium_data.astype(float)

    all_spike_times = []

    for i in range(n_neurons):
        smoothed = gaussian_filter1d(signal[i, :], sigma=smooth_sigma)

        # Rate of calcium increase; the leading 0 keeps the trace length and
        # aligns diff[k] with the rise into frame k.
        diff = np.concatenate([[0], np.diff(smoothed)])

        threshold = np.mean(diff) + threshold_std * np.std(diff)

        peaks, _ = find_peaks(
            diff,
            height=threshold,
            distance=min_spike_frames
        )

        spikes_data[i, peaks] = 1
        all_spike_times.append(peaks)

    spike_ts_list = [
        TimeSeries(spikes_data[i, :], discrete=True)
        for i in range(n_neurons)
    ]
    spikes = MultiTimeSeries(spike_ts_list)

    metadata = {
        'method': 'threshold',
        'parameters': {
            'threshold_std': threshold_std,
            'smooth_sigma': smooth_sigma,
            'min_spike_interval': min_spike_interval,
            'fps': fps
        },
        'spike_times': all_spike_times
    }

    return spikes, metadata