Coverage for src/driada/experiment/neuron.py: 37.89%

161 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import numpy as np 

2from numba import njit 

3from scipy.stats import median_abs_deviation 

4from scipy.optimize import minimize 

5from ..information.info_base import TimeSeries 

6from .wavelet_event_detection import * 

7 

# Default time constants of the calcium transient kernel, in seconds.
# Neuron.__init__ multiplies them by fps to express them in frames.
DEFAULT_T_RISE = 0.25  # rise time constant, sec
DEFAULT_T_OFF = 2.0  # decay time constant, sec

DEFAULT_FPS = 20.0  # frames per sec
DEFAULT_MIN_BEHAVIOUR_TIME = 0.25  # sec

# MIN_CA_SHIFT * t_off (t_off in frames) is the minimal random signal shift
# for a given cell: Neuron.__init__ masks out that many frames at both ends
# of the calcium shuffle mask.
MIN_CA_SHIFT = 5

15 

16#TODO: add numba decorators where possible 

class Neuron:
    """
    Calcium trace, spike train and derived statistics of a single neuron.

    Time constants are passed in seconds and stored in frames
    (value * ``fps``).

    Parameters
    ----------
    cell_id
        Identifier of the cell; stored as is and used in diagnostic messages.
    ca : array-like
        1D calcium trace. It is copied, negative values are clipped to zero
        and a tiny random jitter is added (see ``calcium_preprocessing``).
    sp : array-like or None
        Spike train whose non-zero entries mark events; cast to ``int``.
        ``None`` if no spike data is available.
        NOTE(review): presumably expected to have the same length as ``ca``.
    default_t_rise : float, optional
        Rise time constant in seconds; ``None`` selects ``DEFAULT_T_RISE``.
    default_t_off : float, optional
        Decay time constant in seconds; ``None`` selects ``DEFAULT_T_OFF``.
    fps : float, optional
        Frames per second; ``None`` selects ``DEFAULT_FPS``.
    fit_individual_t_off : bool, optional
        If True, the decay time used for the shuffle mask is fitted from the
        data (requires ``sp``); otherwise ``default_t_off`` is used.

    Attributes
    ----------
    cell_id
        The identifier passed to the constructor.
    ca : TimeSeries
        Preprocessed calcium trace (continuous). Its ``shuffle_mask`` is
        False for the first and last ``int(t_off * MIN_CA_SHIFT)`` frames.
    sp : TimeSeries or None
        Spike train (discrete), or None.
    n_frames : int
        Length of ``ca.data``.
    sp_count : int
        Number of non-zero entries in the spike train (0 without spikes).
    default_t_rise, default_t_off : float
        Kernel time constants in frames.
    t_off, noise_ampl, mad, snr
        Lazily computed values; ``None`` until the matching ``get_*``
        method has been called.
    """

    @staticmethod
    @njit()
    def spike_form(t, t_rise, t_off):
        """
        Return the calcium kernel ``(1 - exp(-t/t_rise)) * exp(-t/t_off)``
        evaluated at ``t`` and scaled so that its maximum equals 1.
        """
        form = (1 - np.exp(-t / t_rise)) * np.exp(-t / t_off)
        return form / np.max(form)

    @staticmethod
    def get_restored_calcium(sp, t_rise, t_off):
        """
        Convolve a spike train with the calcium kernel.

        Parameters
        ----------
        sp : array-like
            Spike train.
        t_rise, t_off : float
            Kernel time constants, in the units of the kernel grid.

        Returns
        -------
        numpy.ndarray
            Full convolution, ``len(sp) + 999`` samples long.
        """
        # NOTE(review): this grid has a step of 1000/999 rather than exactly
        # one sample; confirm whether np.arange(1000) was intended.
        x = np.linspace(0, 1000, num=1000)
        spform = Neuron.spike_form(x, t_rise=t_rise, t_off=t_off)
        return np.convolve(sp, spform)

    @staticmethod
    def ca_mse_error(t_off, ca, spk, t_rise):
        """
        Root-mean-square difference between ``ca`` and the calcium trace
        restored from ``spk`` with the given kernel constants.

        ``t_off`` comes first so the function can be used directly as an
        objective for ``scipy.optimize.minimize``.
        """
        # TODO: fix for new spike format
        re_ca = Neuron.get_restored_calcium(spk, t_rise, t_off)
        residual = ca - re_ca[:len(ca)]
        return np.sqrt(np.sum(np.abs(residual) ** 2) / len(ca))

    @staticmethod
    def calcium_preprocessing(ca):
        """
        Return a cleaned copy of a 1D calcium trace.

        Negative values are set to zero and uniform random jitter in
        ``[0, 1e-8)`` is added to every sample. The input is never modified;
        non-floating input is converted to float so the jitter can be added.
        """
        ca = np.array(ca)  # np.array copies, so the caller's data stays intact
        if not np.issubdtype(ca.dtype, np.floating):
            ca = ca.astype(float)
        ca[ca < 0] = 0
        ca += np.random.random(size=len(ca)) * 1e-8
        return ca

    def __init__(self, cell_id, ca, sp,
                 default_t_rise=DEFAULT_T_RISE, default_t_off=DEFAULT_T_OFF, fps=DEFAULT_FPS,
                 fit_individual_t_off=False):
        # An explicit None means "use the module-level default".
        if default_t_rise is None:
            default_t_rise = DEFAULT_T_RISE
        if default_t_off is None:
            default_t_off = DEFAULT_T_OFF
        if fps is None:
            fps = DEFAULT_FPS

        self.cell_id = cell_id
        self.ca = TimeSeries(Neuron.calcium_preprocessing(ca), discrete=False)
        if sp is None:
            self.sp = None
        else:
            self.sp = TimeSeries(np.asarray(sp).astype(int), discrete=True)
        self.n_frames = len(self.ca.data)

        self.sp_count = np.sum(self.sp.data.astype(bool).astype(int)) if self.sp is not None else 0

        # Lazily computed characteristics, filled in by the get_* methods.
        self.t_off = None
        self.noise_ampl = None
        self.mad = None
        self.snr = None

        # Convert the time constants from seconds to frames.
        self.default_t_off = default_t_off * fps
        self.default_t_rise = default_t_rise * fps

        if fit_individual_t_off:
            t_off = self.get_t_off()
        else:
            t_off = self.default_t_off

        # add shuffle mask according to computed characteristic calcium decay time
        self.ca.shuffle_mask = np.ones(self.n_frames).astype(bool)
        min_shift = int(t_off * MIN_CA_SHIFT)
        self.ca.shuffle_mask[:min_shift] = False
        self.ca.shuffle_mask[self.n_frames - min_shift:] = False

    def reconstruct_spikes(self, **kwargs):
        """Placeholder: always raises ``AttributeError`` (not implemented)."""
        raise AttributeError('Spike reconstruction not implemented')

    def get_mad(self):
        """
        Return the median absolute deviation of the calcium trace (cached).

        When spikes are available the signal-to-noise ratio is computed and
        cached along the way; if that fails, or there is no spike data, only
        the MAD is computed.
        """
        if self.mad is None:
            if self.sp is None:
                # SNR needs spikes, the MAD does not.
                self.mad = median_abs_deviation(self.ca.data)
            else:
                try:
                    self.snr, self.mad = self._calc_snr()
                except ValueError:
                    self.mad = median_abs_deviation(self.ca.data)
        return self.mad

    def get_snr(self):
        """
        Return the signal-to-noise ratio (cached); requires spike data.

        Raises
        ------
        ValueError
            If there are no spikes or the ratio is NaN.
        """
        if self.snr is None:
            self.snr, self.mad = self._calc_snr()
        return self.snr

    def _calc_snr(self):
        """
        Compute ``(snr, mad)``: the mean calcium value at spike frames
        divided by the MAD of the calcium trace, and the MAD itself.
        """
        spk_inds = np.nonzero(self.sp.data)[0]
        mad = median_abs_deviation(self.ca.data)
        if len(spk_inds) == 0:
            raise ValueError('No spikes found!')

        sn = np.mean(self.ca.data[spk_inds]) / mad
        if np.isnan(sn):
            raise ValueError('Error in snr calculation')

        return sn, mad

    def get_t_off(self):
        """Return the fitted decay time in frames (fitted once, then cached)."""
        if self.t_off is None:
            self.t_off, self.noise_ampl = self._fit_t_off()

        return self.t_off

    def get_noise_ampl(self):
        """Return the residual error of the decay-time fit (cached)."""
        if self.noise_ampl is None:
            self.t_off, self.noise_ampl = self._fit_t_off()

        return self.noise_ampl

    def _fit_t_off(self):
        """
        Fit the kernel decay time by minimising ``ca_mse_error``.

        The optimisation starts from ``default_t_off`` and keeps the rise
        time fixed at ``default_t_rise``. The result is capped at
        ``5 * default_t_off``; a message is printed when the cap applies.

        Returns
        -------
        tuple
            ``(t_off, noise_amplitude)`` where ``noise_amplitude`` is the
            objective value at the optimum.
        """
        # TODO: fit for an arbitrary kernel form.
        # TODO: add nonlinear summation fit if needed

        res = minimize(Neuron.ca_mse_error,
                       np.array([self.default_t_off]),
                       args=(self.ca.data, self.sp.data, self.default_t_rise))
        opt_t_off = res.x[0]
        noise_amplitude = res.fun

        max_t_off = self.default_t_off * 5
        if opt_t_off > max_t_off:
            print(f'Calculated t_off={int(opt_t_off)} for neuron {self.cell_id} is suspiciously high, check signal quality. t_off has been automatically lowered to {max_t_off}')
            opt_t_off = max_t_off

        return opt_t_off, noise_amplitude

    def get_shuffled_calcium(self, method='roll_based', no_ts=True, **kwargs):
        """
        Return a shuffled, preprocessed copy of the calcium trace.

        Parameters
        ----------
        method : str
            Suffix of a ``_shuffle_calcium_data_<method>`` method:
            ``'roll_based'``, ``'chunks_based'`` or ``'waveform_based'``.
        no_ts : bool
            If True return a plain array, otherwise wrap it in ``TimeSeries``.
        **kwargs
            Forwarded to the selected shuffling method.

        Raises
        ------
        UserWarning
            If ``method`` does not name a known shuffling method.
        """
        try:
            fn = getattr(self, f'_shuffle_calcium_data_{method}')
        except AttributeError:
            # The clause must name the exception class: an instance here
            # turns the lookup failure into a TypeError.
            raise UserWarning('Unknown calcium data shuffling method') from None

        # Preprocess exactly once so the jitter is not added twice.
        sh_ca = Neuron.calcium_preprocessing(fn(**kwargs))
        if not no_ts:
            sh_ca = TimeSeries(sh_ca, discrete=False)

        return sh_ca

    def _shuffle_calcium_data_waveform_based(self, **kwargs):
        """
        Shuffle by re-placing spikes: the restored calcium of the real spike
        train is subtracted from the trace and the restored calcium of an
        ISI-shuffled spike train is added back to the remaining background.

        Raises
        ------
        AttributeError
            If there is no spike data.
        """
        if self.sp is None:
            raise AttributeError('Unable to shuffle calcium by waveforms without spikes data')

        opt_t_off = self.get_t_off()
        n = len(self.ca.data)

        # Use the same rise time as the t_off fit instead of a fixed 5 frames
        # (identical at the default fps and rise time).
        conv = Neuron.get_restored_calcium(self.sp.data, self.default_t_rise, opt_t_off)
        background = self.ca.data - conv[:n]

        pspk = self._shuffle_spikes_data_isi_based()
        psconv = Neuron.get_restored_calcium(pspk, self.default_t_rise, opt_t_off)

        return psconv[:n] + background

    def _shuffle_calcium_data_chunks_based(self, **kwargs):
        """
        Shuffle by cutting the trace into ``n`` consecutive chunks
        (keyword ``n``, default 100) and concatenating them in random order.

        Chunks are as equal as possible (``numpy.array_split``), so the
        result always has ``n_frames`` samples.

        Raises
        ------
        ValueError
            If ``n`` is smaller than 1.
        """
        n = int(kwargs.get('n', 100))
        if n < 1:
            raise ValueError('Number of chunks n must be a positive integer')

        chunks = np.array_split(self.ca.data, n)
        order = np.random.permutation(len(chunks))
        return np.concatenate([chunks[i] for i in order])

    def _shuffle_calcium_data_roll_based(self, **kwargs):
        """
        Shuffle by circularly rolling the trace.

        The shift is taken from keyword ``shift`` if given. Otherwise it is
        drawn uniformly from ``[3*t_off, n_frames - 3*t_off)``, where
        ``t_off`` is the fitted decay time, or ``default_t_off`` when there
        is no spike data to fit it from.

        Raises
        ------
        ValueError
            If the trace is too short for a random shift.
        """
        if 'shift' in kwargs:
            shift = kwargs['shift']
        else:
            # The fit is only needed, and only possible, in this branch.
            t_off = self.get_t_off() if self.sp is not None else self.default_t_off
            low = int(3 * t_off)
            high = int(self.n_frames - 3 * t_off)
            if high <= low:
                raise ValueError('Calcium trace is too short for a random roll shift')
            shift = np.random.randint(low, high)

        return np.roll(self.ca.data, shift)

    def get_shuffled_spikes(self, method='isi_based', no_ts=True, **kwargs):
        """
        Return a shuffled copy of the spike train.

        Parameters
        ----------
        method : str
            Suffix of a ``_shuffle_spikes_data_<method>`` method.
        no_ts : bool
            If True return a plain array, otherwise wrap it in ``TimeSeries``.
        **kwargs
            Forwarded to the selected shuffling method.

        Raises
        ------
        AttributeError
            If there is no spike data.
        UserWarning
            If ``method`` does not name a known shuffling method.
        """
        if self.sp is None:
            raise AttributeError('Unable to shuffle spikes without spikes data')

        try:
            fn = getattr(self, f'_shuffle_spikes_data_{method}')
        except AttributeError:
            raise UserWarning('Unknown spikes data shuffling method') from None

        sh_data = fn(**kwargs)
        if no_ts:
            return sh_data
        return TimeSeries(sh_data, discrete=True)

    def _shuffle_spikes_data_isi_based(self, **kwargs):
        """
        Shuffle spikes while preserving the set of inter-spike intervals.

        The intervals between consecutive events are permuted, the first
        event is placed at a random frame that keeps the last event inside
        the recording, and the event values are permuted among the new
        positions. Without events a copy of the spike train is returned.
        ``**kwargs`` is accepted for dispatcher compatibility and ignored.
        """
        spikes = self.sp.data
        event_inds = np.flatnonzero(spikes)

        if event_inds.size == 0:  # nothing to shuffle
            return np.array(spikes)

        # Indices are ascending, so the ends give the span of the events.
        span = event_inds[-1] - event_inds[0]
        first_random_pos = np.random.choice(self.n_frames - span)

        shuffled_intervals = np.random.permutation(np.diff(event_inds))
        pseudo_event_inds = np.cumsum(np.insert(shuffled_intervals, 0, first_random_pos))

        pseudo_spikes = np.zeros(self.n_frames)
        pseudo_spikes[pseudo_event_inds] = np.random.permutation(spikes[event_inds])

        return pseudo_spikes