Coverage for src/driada/experiment/spike_reconstruction.py: 100.00%
48 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
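"""
Spike reconstruction module for DRIADA.

This module provides functions for reconstructing spike trains from calcium
imaging data, either with wavelet-based event detection or with a simple
threshold on the derivative of the smoothed calcium signal. A custom
callable can also be plugged in via ``reconstruct_spikes``.
"""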
from typing import Tuple, Dict, Any, Optional, Callable, Union

import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks

from ..information.info_base import TimeSeries, MultiTimeSeries
from .wavelet_event_detection import (
    WVT_EVENT_DETECTION_PARAMS,
    extract_wvt_events,
    events_to_ts_array,
    ridges_to_containers,
)
def reconstruct_spikes(
    calcium: MultiTimeSeries,
    method: Union[str, Callable[..., Tuple[MultiTimeSeries, Dict[str, Any]]]] = 'wavelet',
    fps: float = 20.0,
    params: Optional[Dict[str, Any]] = None,
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
    """
    Reconstruct spike trains from calcium signals.

    Dispatches to one of the built-in reconstruction methods, or to a
    user-supplied callable.

    Parameters
    ----------
    calcium : MultiTimeSeries
        Calcium imaging data with each component being a neuron.
    method : str or callable
        Reconstruction method: ``'wavelet'``, ``'threshold'``, or a callable
        invoked as ``method(calcium, fps, params)`` that must return
        ``(spikes, metadata)``.
    fps : float
        Sampling rate in frames per second. Must be positive.
    params : dict, optional
        Method-specific parameters. ``None`` is treated as an empty dict.

    Returns
    -------
    spikes : MultiTimeSeries
        Reconstructed spike trains (discrete).
    metadata : dict
        Reconstruction metadata.

    Raises
    ------
    ValueError
        If ``fps`` is not a positive number, or if ``method`` is a string
        other than ``'wavelet'`` or ``'threshold'``.
    """
    # ``not fps > 0`` also rejects NaN, which a plain ``fps <= 0`` would let through.
    if not fps > 0:
        raise ValueError(f"fps must be a positive number, got {fps!r}.")

    params = params or {}

    # Check for a callable first so custom methods never hit the string comparisons.
    if callable(method):
        return method(calcium, fps, params)
    if method == 'wavelet':
        return wavelet_reconstruction(calcium, fps, params)
    if method == 'threshold':
        return threshold_reconstruction(calcium, fps, params)

    raise ValueError(
        f"Unknown method '{method}'. Use 'wavelet', 'threshold', "
        "or provide a callable."
    )
def wavelet_reconstruction(
    calcium: MultiTimeSeries,
    fps: float,
    params: Dict[str, Any],
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
    """
    Wavelet-based spike reconstruction.

    Events are detected with ``extract_wvt_events``. The detection settings
    start from ``WVT_EVENT_DETECTION_PARAMS``, and ``params`` overrides them.
    The detected events are then converted to spike arrays with
    ``events_to_ts_array``.

    Parameters
    ----------
    calcium : MultiTimeSeries
        Calcium signals. ``calcium.data`` is used as a 2-D array whose second
        axis is the frame axis.
    fps : float
        Sampling rate. This value is authoritative: it is used for both event
        detection and spike-array conversion. It overrides any ``'fps'`` key
        present in ``params``.
    params : dict
        Overrides for the default wavelet detection parameters.

    Returns
    -------
    spikes : MultiTimeSeries
        Discrete spike trains, one component per row of the spike array.
    metadata : dict
        Reconstruction metadata with keys:

        - ``'method'``: ``'wavelet'``
        - ``'parameters'``: the effective detection kwargs
        - ``'start_events'`` / ``'end_events'``: event start and end indices
          as returned by ``extract_wvt_events``
        - ``'ridges'``: each element of the ridges returned by
          ``extract_wvt_events``, passed through ``ridges_to_containers``
          (presumably one entry per neuron; TODO confirm)
    """
    calcium_data = np.asarray(calcium.data)

    # Apply defaults, then user overrides, then pin fps to the argument.
    # Pinning fps last keeps detection and conversion (below) on the same
    # sampling rate. Previously a 'fps' key in params changed detection only.
    wvt_kwargs = WVT_EVENT_DETECTION_PARAMS.copy()
    wvt_kwargs.update(params)
    wvt_kwargs['fps'] = fps

    st_ev_inds, end_ev_inds, all_ridges = extract_wvt_events(calcium_data, wvt_kwargs)

    n_frames = calcium_data.shape[1]
    spikes_data = events_to_ts_array(n_frames, st_ev_inds, end_ev_inds, fps)

    spikes = MultiTimeSeries(
        [TimeSeries(row, discrete=True) for row in spikes_data]
    )

    metadata = {
        'method': 'wavelet',
        'parameters': wvt_kwargs,
        'start_events': st_ev_inds,
        'end_events': end_ev_inds,
        'ridges': [ridges_to_containers(ridges) for ridges in all_ridges],
    }

    return spikes, metadata
def threshold_reconstruction(
    calcium: MultiTimeSeries,
    fps: float,
    params: Dict[str, Any],
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
    """
    Simple threshold-based spike reconstruction.

    This method detects spikes when the derivative of the calcium signal
    exceeds a threshold, similar to classical spike detection methods. For
    each neuron, the steps are:

    1. Smooth the trace with a Gaussian filter.
    2. Take the frame-to-frame difference.
    3. Mark a spike at every peak of that difference that exceeds
       ``mean + threshold_std * std`` of the difference.

    Parameters
    ----------
    calcium : MultiTimeSeries
        Calcium signals. ``calcium.data`` must be 2-D (neurons x frames).
    fps : float
        Sampling rate.
    params : dict
        Parameters including:

        - threshold_std : float, number of STDs above mean for detection (default: 2.5)
        - smooth_sigma : float, gaussian smoothing sigma in frames (default: 2)
        - min_spike_interval : float, minimum interval between spikes in seconds (default: 0.1)

    Returns
    -------
    spikes : MultiTimeSeries
        Binary spike trains, with the same dtype as ``calcium.data``.
    metadata : dict
        Reconstruction metadata with keys:

        - ``'method'``: ``'threshold'``
        - ``'parameters'``: the effective parameters, including ``fps``
        - ``'spike_times'``: per-neuron arrays of spike frame indices

    Raises
    ------
    ValueError
        If ``calcium.data`` is not 2-D.
    """
    threshold_std = params.get('threshold_std', 2.5)
    smooth_sigma = params.get('smooth_sigma', 2)
    min_spike_interval = params.get('min_spike_interval', 0.1)
    # find_peaks requires distance >= 1. At low frame rates, or with a zero
    # interval, int() truncates to 0 and find_peaks would raise. A distance of
    # 1 adds no constraint, because distinct peaks are always >= 1 frame apart.
    min_spike_frames = max(1, int(min_spike_interval * fps))

    calcium_data = np.asarray(calcium.data)
    if calcium_data.ndim != 2:
        raise ValueError(
            f"Expected 2-D calcium data (neurons x frames), got shape {calcium_data.shape}"
        )
    n_neurons = calcium_data.shape[0]
    spikes_data = np.zeros_like(calcium_data)

    # gaussian_filter1d returns the input dtype by default. Integer traces
    # would therefore be re-rounded after smoothing, so analyse them in
    # float64. Float inputs are used as-is, keeping their results unchanged.
    if np.issubdtype(calcium_data.dtype, np.floating):
        analysis_data = calcium_data
    else:
        analysis_data = calcium_data.astype(np.float64)

    all_spike_times = []
    for i, trace in enumerate(analysis_data):
        smoothed = gaussian_filter1d(trace, sigma=smooth_sigma)

        # Rate of calcium increase. Prepending 0 keeps diff[t] aligned with
        # frame t, so peak indices are frame indices of the original trace.
        diff = np.concatenate([[0], np.diff(smoothed)])

        threshold = np.mean(diff) + threshold_std * np.std(diff)
        peaks, _ = find_peaks(diff, height=threshold, distance=min_spike_frames)

        spikes_data[i, peaks] = 1
        all_spike_times.append(peaks)

    spikes = MultiTimeSeries(
        [TimeSeries(spikes_data[i, :], discrete=True) for i in range(n_neurons)]
    )

    metadata = {
        'method': 'threshold',
        'parameters': {
            'threshold_std': threshold_std,
            'smooth_sigma': smooth_sigma,
            'min_spike_interval': min_spike_interval,
            'fps': fps,
        },
        'spike_times': all_spike_times,
    }

    return spikes, metadata