Coverage for src/driada/experiment/neuron.py: 37.89%
161 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import numpy as np
2from numba import njit
3from scipy.stats import median_abs_deviation
4from scipy.optimize import minimize
5from ..information.info_base import TimeSeries
6from .wavelet_event_detection import *
# Default calcium kernel time constants, in seconds (converted to frames via fps).
DEFAULT_T_RISE = 0.25  # s
DEFAULT_T_OFF = 2.0  # s

DEFAULT_FPS = 20.0  # frames per second
DEFAULT_MIN_BEHAVIOUR_TIME = 0.25  # s

# MIN_CA_SHIFT * t_off (t_off in frames) is the minimal random shift of a cell's calcium signal.
MIN_CA_SHIFT = 5
# TODO: add numba decorators where possible
class Neuron():
    """
    Calcium trace, optional spike train and derived statistics of one neuron.

    Kernel time constants are stored in frames: the ``default_t_rise`` and
    ``default_t_off`` constructor arguments are given in seconds and are
    multiplied by ``fps``.

    Attributes
    ----------
    cell_id : object
        Identifier passed by the caller, stored as-is.
    ca : TimeSeries
        Preprocessed calcium trace (``discrete=False``). Its ``shuffle_mask``
        is False for the first and last ``int(t_off * MIN_CA_SHIFT)`` frames.
    sp : TimeSeries or None
        Spike train cast to ``int`` (``discrete=True``), or None.
    n_frames : int
        Length of the calcium trace.
    sp_count : int
        Number of frames with a nonzero spike value (0 when ``sp`` is None).
    default_t_rise, default_t_off : float
        Default kernel rise / decay times, in frames.
    t_off, noise_ampl : float or None
        Fitted decay time (frames) and RMS residual of that fit; computed
        lazily by ``get_t_off`` / ``get_noise_ampl``.
    mad, snr : float or None
        Median absolute deviation of the calcium trace, and the mean calcium
        value at spike frames divided by it; computed lazily.
    """

    @staticmethod
    @njit()
    def spike_form(t, t_rise, t_off):
        """Return ``(1 - exp(-t/t_rise)) * exp(-t/t_off)`` scaled so its maximum over ``t`` is 1."""
        form = (1 - np.exp(-t / t_rise)) * np.exp(-t / t_off)
        return form / form.max()

    @staticmethod
    def get_restored_calcium(sp, t_rise, t_off):
        """Convolve a spike train with the ``spike_form`` kernel.

        Returns the full convolution, of length ``len(sp) + 999``; callers
        truncate it to the trace length.
        """
        # NOTE(review): the kernel is sampled at 1000 points on [0, 1000], i.e. a
        # step of 1000/999 rather than exactly one frame -- confirm this is intended.
        x = np.linspace(0, 1000, num=1000)
        spform = Neuron.spike_form(x, t_rise=t_rise, t_off=t_off)
        return np.convolve(sp, spform)

    @staticmethod
    def ca_mse_error(t_off, ca, spk, t_rise):
        """Root-mean-square error between ``ca`` and the calcium restored from ``spk``.

        Serves as the objective of ``scipy.optimize.minimize`` in ``_fit_t_off``,
        where ``t_off`` is the optimised variable (passed as a 1-element array).
        """
        # TODO: fix for new spike format
        re_ca = Neuron.get_restored_calcium(spk, t_rise, t_off)
        return np.sqrt(np.sum(np.abs(ca - re_ca[:len(ca)]) ** 2) / len(ca))

    @staticmethod
    def calcium_preprocessing(ca):
        """Return a float copy of ``ca`` with negatives set to 0 and tiny jitter added.

        The caller's array is not modified, and integer input is accepted.
        """
        ca = np.array(ca, dtype=float)  # copy: never mutate the caller's array
        ca[ca < 0] = 0
        # Uniform jitter in [0, 1e-8) -- presumably to break ties between equal values.
        ca += np.random.random(size=len(ca)) * 1e-8
        return ca

    def __init__(self, cell_id, ca, sp,
                 default_t_rise=DEFAULT_T_RISE, default_t_off=DEFAULT_T_OFF, fps=DEFAULT_FPS,
                 fit_individual_t_off=False):
        """
        Parameters
        ----------
        cell_id : object
            Identifier of the neuron, stored as-is.
        ca : array-like
            Raw calcium trace; stored after ``calcium_preprocessing``.
        sp : numpy.ndarray or None
            Spike train (cast to ``int``); presumably aligned frame-by-frame
            with ``ca`` -- not checked here.
        default_t_rise, default_t_off : float or None
            Kernel rise / decay times in seconds; None selects the module defaults.
        fps : float or None
            Frames per second used to convert seconds to frames; None selects
            ``DEFAULT_FPS``.
        fit_individual_t_off : bool
            If True, fit the decay time to this neuron (``get_t_off``) and use it
            for the shuffle mask; otherwise use the default decay time.
        """
        if default_t_rise is None:
            default_t_rise = DEFAULT_T_RISE
        if default_t_off is None:
            default_t_off = DEFAULT_T_OFF
        if fps is None:
            fps = DEFAULT_FPS

        self.cell_id = cell_id
        self.ca = TimeSeries(Neuron.calcium_preprocessing(ca), discrete=False)
        self.sp = None if sp is None else TimeSeries(sp.astype(int), discrete=True)
        self.n_frames = len(self.ca.data)

        self.sp_count = int(np.count_nonzero(self.sp.data)) if self.sp is not None else 0
        self.t_off = None
        self.noise_ampl = None
        self.mad = None
        self.snr = None

        # time constants are kept in frames from here on
        self.default_t_off = default_t_off * fps
        self.default_t_rise = default_t_rise * fps

        t_off = self.get_t_off() if fit_individual_t_off else self.default_t_off

        # add shuffle mask according to computed characteristic calcium decay time
        self.ca.shuffle_mask = np.ones(self.n_frames, dtype=bool)
        min_shift = int(t_off * MIN_CA_SHIFT)
        self.ca.shuffle_mask[:min_shift] = False
        self.ca.shuffle_mask[self.n_frames - min_shift:] = False

    def reconstruct_spikes(self, **kwargs):
        """Not implemented; always raises ``AttributeError``."""
        raise AttributeError('Spike reconstruction not implemented')

    def get_mad(self):
        """Return (and cache) the MAD of the calcium trace.

        Computed together with the SNR when possible; if the SNR cannot be
        computed (``ValueError``), the MAD alone is computed.
        """
        if self.mad is None:
            try:
                self.snr, self.mad = self._calc_snr()
            except ValueError:
                self.mad = median_abs_deviation(self.ca.data)
        return self.mad

    def get_snr(self):
        """Return (and cache) the SNR; raises ``ValueError`` if it cannot be computed."""
        if self.snr is None:
            self.snr, self.mad = self._calc_snr()
        return self.snr

    def _calc_snr(self):
        """Compute ``(snr, mad)``.

        ``snr`` is the mean calcium value at frames with a nonzero spike value,
        divided by the median absolute deviation of the whole trace.

        Raises
        ------
        ValueError
            If there is no spike data, no spikes, or the ratio is NaN.
        """
        if self.sp is None:
            # ValueError (not AttributeError) so get_mad can fall back to plain MAD
            raise ValueError('No spike data available!')
        spk_inds = np.nonzero(self.sp.data)[0]
        if len(spk_inds) == 0:
            raise ValueError('No spikes found!')
        mad = median_abs_deviation(self.ca.data)
        sn = np.mean(self.ca.data[spk_inds]) / mad
        if np.isnan(sn):
            raise ValueError('Error in snr calculation')
        return sn, mad

    def get_t_off(self):
        """Return (and cache) the fitted decay time in frames."""
        if self.t_off is None:
            self.t_off, self.noise_ampl = self._fit_t_off()
        return self.t_off

    def get_noise_ampl(self):
        """Return (and cache) the RMS residual of the decay-time fit."""
        if self.noise_ampl is None:
            self.t_off, self.noise_ampl = self._fit_t_off()
        return self.noise_ampl

    def _fit_t_off(self):
        """Fit the kernel decay time (frames) by minimising ``ca_mse_error``.

        Starts from ``default_t_off``. Returns ``(t_off, noise_amplitude)``, where
        ``noise_amplitude`` is the minimal RMS error. ``t_off`` is capped at
        ``5 * default_t_off``; a message is printed when the cap applies.
        """
        # TODO: fit for an arbitrary kernel form.
        # TODO: add nonlinear summation fit if needed
        if self.sp is None:
            # Treat a missing spike train like an all-zero one: the restored
            # calcium is identically zero, so the error does not depend on t_off.
            # Keep the default and report the RMS of the trace.
            return self.default_t_off, np.sqrt(np.mean(self.ca.data ** 2))

        res = minimize(Neuron.ca_mse_error, np.array([self.default_t_off]),
                       args=(self.ca.data, self.sp.data, self.default_t_rise))
        opt_t_off = res.x[0]
        noise_amplitude = res.fun

        max_t_off = self.default_t_off * 5
        if opt_t_off > max_t_off:
            print(f'Calculated t_off={int(opt_t_off)} for neuron {self.cell_id} is suspiciously high, check signal quality. t_off has been automatically lowered to {self.default_t_off*5}')

        return min(opt_t_off, max_t_off), noise_amplitude

    def get_shuffled_calcium(self, method='roll_based', no_ts=True, **kwargs):
        """Return a shuffled, preprocessed copy of the calcium trace.

        Parameters
        ----------
        method : str
            Selects ``_shuffle_calcium_data_<method>`` ('roll_based',
            'waveform_based' or 'chunks_based').
        no_ts : bool
            If True return a plain array, otherwise a ``TimeSeries``.
        **kwargs
            Forwarded to the shuffling method.

        Raises
        ------
        UserWarning
            If ``method`` is unknown.
        """
        try:
            fn = getattr(self, f'_shuffle_calcium_data_{method}')
        except AttributeError:  # must name the class, not an instance
            raise UserWarning('Unknown calcium data shuffling method') from None

        # preprocess exactly once, whichever container is returned
        sh_ca = Neuron.calcium_preprocessing(fn(**kwargs))
        if not no_ts:
            sh_ca = TimeSeries(sh_ca, discrete=False)
        return sh_ca

    def _shuffle_calcium_data_waveform_based(self, **kwargs):
        """Rebuild calcium from ISI-shuffled spikes on top of the original residual.

        The residual is the trace minus the calcium restored from the real
        spikes; both reconstructions use ``default_t_rise`` and the fitted decay
        time, matching ``_fit_t_off``.
        """
        if self.sp is None:
            raise AttributeError('Unable to shuffle spikes without spikes data')
        opt_t_off = self.get_t_off()
        n = len(self.ca.data)

        conv = Neuron.get_restored_calcium(self.sp.data, self.default_t_rise, opt_t_off)
        background = self.ca.data - conv[:n]

        pspk = self._shuffle_spikes_data_isi_based()
        psconv = Neuron.get_restored_calcium(pspk, self.default_t_rise, opt_t_off)
        return psconv[:n] + background

    def _shuffle_calcium_data_chunks_based(self, **kwargs):
        """Split the trace into ``n`` contiguous chunks and concatenate them in random order.

        ``n`` is read from kwargs (default 100). Chunk sizes differ by at most
        one frame, so the output always has ``n_frames`` samples.
        """
        n = kwargs.get('n', 100)
        chunks = np.array_split(self.ca.data, n)
        order = np.random.permutation(len(chunks))
        return np.concatenate([chunks[i] for i in order])

    def _shuffle_calcium_data_roll_based(self, **kwargs):
        """Circularly shift the trace.

        Uses the ``shift`` kwarg if given; otherwise the shift is drawn with
        ``np.random.randint(3*t_off, n_frames - 3*t_off)``. Calls ``get_t_off``,
        which may run the decay-time fit.
        """
        opt_t_off = self.get_t_off()
        if 'shift' in kwargs:
            shift = kwargs['shift']
        else:
            shift = np.random.randint(3 * opt_t_off, self.n_frames - 3 * opt_t_off)
        return np.roll(self.ca.data, shift)

    def get_shuffled_spikes(self, method='isi_based', no_ts=True, **kwargs):
        """Return a shuffled spike train (array, or ``TimeSeries`` if ``no_ts`` is False).

        Raises
        ------
        AttributeError
            If the neuron has no spike data.
        UserWarning
            If ``method`` is unknown.
        """
        if self.sp is None:
            raise AttributeError('Unable to shuffle spikes without spikes data')

        try:
            fn = getattr(self, f'_shuffle_spikes_data_{method}')
        except AttributeError:  # must name the class, not an instance
            raise UserWarning('Unknown spikes data shuffling method') from None

        sh_data = fn(**kwargs)
        if not no_ts:
            return TimeSeries(sh_data, discrete=True)
        return sh_data

    def _shuffle_spikes_data_isi_based(self, **kwargs):
        """Shuffle spikes by permuting inter-spike intervals and event values.

        The permuted interval sequence is placed at a random start chosen so that
        the last event still falls inside the recording. With no events, a copy of
        the original spike data is returned.
        """
        nfr = self.n_frames
        event_inds = np.where(self.sp.data != 0)[0]

        if len(event_inds) == 0:  # if no events were detected, there is nothing to shuffle
            return self.sp.data.copy()

        event_vals = self.sp.data[event_inds]  # fancy indexing -> independent copy
        span = event_inds[-1] - event_inds[0]  # event_inds is sorted
        first_random_pos = np.random.choice(nfr - span)

        interspike_intervals = np.diff(event_inds)
        perm = np.arange(len(interspike_intervals))
        np.random.shuffle(perm)
        disordered_interspike_intervals = interspike_intervals[perm]

        pseudo_event_inds = np.cumsum(np.insert(disordered_interspike_intervals,
                                                0, first_random_pos))

        np.random.shuffle(event_vals)  # in place on the copy; self.sp.data is untouched
        pseudo_spikes = np.zeros(nfr)
        pseudo_spikes[pseudo_event_inds] = event_vals
        return pseudo_spikes