Coverage for src/driada/experiment/synthetic/manifold_spatial_3d.py: 9.62%

104 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

3D spatial manifold generation for place cells.

This module contains functions for generating synthetic neural data on 3D spatial
manifolds, typically used to model place cells in 3D environments.

6""" 

7 

8import numpy as np 

9from .core import validate_peak_rate, generate_pseudo_calcium_signal 

10from .utils import get_effective_decay_time 

11from ..exp_base import Experiment 

12from ...information.info_base import TimeSeries, MultiTimeSeries 

13 

14 

def generate_3d_random_walk(length, bounds=(0, 1), step_size=0.02, momentum=0.8, seed=None):
    """
    Generate a 3D random walk trajectory within a bounded cubic region.

    The walk starts at a uniformly random point inside the cube and evolves
    with a momentum-smoothed Gaussian velocity. When a coordinate leaves the
    region it is clamped to the wall and the corresponding velocity component
    is reversed (a simple "bounce").

    Parameters
    ----------
    length : int
        Number of time points. Must be non-negative; ``0`` yields an empty
        trajectory of shape ``(3, 0)``.
    bounds : tuple
        (min, max) bounds shared by the x, y and z coordinates.
    step_size : float
        Scale of the Gaussian velocity perturbation added at each step.
    momentum : float
        Momentum factor (0-1); larger values give smoother trajectories.
    seed : int, optional
        If given, seeds NumPy's global random number generator.

    Returns
    -------
    positions : ndarray
        Shape (3, length) with x, y, z coordinates.

    Raises
    ------
    ValueError
        If ``length`` is negative.
    """
    if seed is not None:
        np.random.seed(seed)

    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")

    lower, upper = bounds[0], bounds[1]
    positions = np.zeros((3, length))
    if length == 0:
        # Nothing to simulate; indexing column 0 below would raise IndexError.
        return positions

    velocity = np.zeros(3)

    # Initialize at a random position inside the cube
    positions[:, 0] = np.random.uniform(lower, upper, 3)

    for t in range(1, length):
        # Random walk with momentum: blend previous velocity with fresh noise
        velocity = momentum * velocity + (1 - momentum) * np.random.randn(3) * step_size

        new_pos = positions[:, t - 1] + velocity

        # Bounce off walls: clamp offending coordinates and flip their velocity.
        # ``above`` excludes ``below`` so each dimension behaves like if/elif.
        below = new_pos < lower
        above = (new_pos > upper) & ~below
        new_pos[below] = lower
        new_pos[above] = upper
        velocity[below | above] *= -1

        positions[:, t] = new_pos

    return positions

65 

66 

def gaussian_place_field_3d(positions, center, sigma=0.1):
    """
    Evaluate an isotropic 3D Gaussian tuning curve along a trajectory.

    The response is ``exp(-||p - center||^2 / (2 * sigma^2))``. It equals 1 at
    the field center and decays with squared Euclidean distance.

    Parameters
    ----------
    positions : ndarray
        Shape (3, n_timepoints); rows are x, y, z coordinates.
    center : ndarray
        Shape (3,) with the place field center coordinates.
    sigma : float
        Width (standard deviation) of the place field.

    Returns
    -------
    response : ndarray
        Shape (n_timepoints,) firing-rate modulation in (0, 1].
    """
    squared_distance = (
        (positions[0, :] - center[0]) ** 2
        + (positions[1, :] - center[1]) ** 2
        + (positions[2, :] - center[2]) ** 2
    )
    two_sigma_sq = 2 * sigma**2
    return np.exp(-squared_distance / two_sigma_sq)

95 

96 

def generate_3d_manifold_neurons(n_neurons, positions, field_sigma=0.1,
                                 baseline_rate=0.1, peak_rate=1.0,
                                 noise_std=0.05, grid_arrangement=True,
                                 bounds=(0, 1), seed=None):
    """
    Generate population of place cells with 3D Gaussian place fields.

    Parameters
    ----------
    n_neurons : int
        Number of neurons. Must be non-negative.
    positions : ndarray
        Shape (3, n_timepoints) with x, y, z positions.
    field_sigma : float
        Width of place fields. Default is 0.1.
    baseline_rate : float
        Baseline firing rate. Default is 0.1 Hz.
    peak_rate : float
        Peak firing rate at place field center. Default is 1.0 Hz.
        Checked with ``validate_peak_rate`` before anything else.
    noise_std : float
        Standard deviation of additive Gaussian noise on firing rates.
        Noisy rates are rectified at 0.
    grid_arrangement : bool
        If True, place centers on a regular 3D grid spanning
        ``[bounds[0] + 0.1, bounds[1] - 0.1]`` per axis. The grid is filled
        in x-major, z-fastest order, then small Gaussian jitter (std 0.02) is
        added and centers are clipped to ``bounds``. Otherwise centers are
        drawn uniformly within ``bounds``.
    bounds : tuple
        (min, max) bounds for place field centers.
    seed : int, optional
        If given, seeds NumPy's global random number generator.

    Returns
    -------
    firing_rates : ndarray
        Shape (n_neurons, n_timepoints) with firing rates.
    place_field_centers : ndarray
        Shape (n_neurons, 3) with place field centers (also when
        ``n_neurons == 0``).

    Raises
    ------
    ValueError
        If ``n_neurons`` is negative. ``validate_peak_rate`` may raise for
        an invalid ``peak_rate``.
    """
    # Validate firing rate
    validate_peak_rate(peak_rate, context="generate_3d_manifold_neurons")

    if n_neurons < 0:
        raise ValueError(f"n_neurons must be non-negative, got {n_neurons}")

    if seed is not None:
        np.random.seed(seed)

    n_timepoints = positions.shape[1]

    # Generate place field centers
    if grid_arrangement:
        # Smallest cube side with at least n_neurons grid points
        n_per_side = int(np.ceil(n_neurons**(1/3)))
        axis_centers = np.linspace(bounds[0] + 0.1, bounds[1] - 0.1, n_per_side)

        # 'ij' indexing + C-order reshape enumerates (x, y, z) with z varying
        # fastest: the same order as nested x/y/z loops. reshape(-1, 3) also
        # keeps the (0, 3) shape when n_neurons == 0.
        grid = np.stack(
            np.meshgrid(axis_centers, axis_centers, axis_centers, indexing='ij'),
            axis=-1,
        ).reshape(-1, 3)
        place_field_centers = grid[:n_neurons].copy()

        # Add small jitter so the grid is not perfectly regular
        jitter = np.random.normal(0, 0.02, place_field_centers.shape)
        place_field_centers += jitter
        place_field_centers = np.clip(place_field_centers, bounds[0], bounds[1])
    else:
        # Random placement
        place_field_centers = np.random.uniform(bounds[0], bounds[1], (n_neurons, 3))

    # Generate firing rates
    firing_rates = np.zeros((n_neurons, n_timepoints))
    rate_span = peak_rate - baseline_rate

    for i in range(n_neurons):
        place_response = gaussian_place_field_3d(positions, place_field_centers[i], field_sigma)

        # Scale the [0, 1] Gaussian response into [baseline_rate, peak_rate]
        firing_rate = baseline_rate + rate_span * place_response

        # Add noise, then rectify so rates stay non-negative
        noise = np.random.normal(0, noise_std, n_timepoints)
        firing_rates[i, :] = np.maximum(0, firing_rate + noise)

    return firing_rates, place_field_centers

187 

188 

def generate_3d_manifold_data(n_neurons, duration=600, sampling_rate=20.0,
                              field_sigma=0.1, step_size=0.02, momentum=0.8,
                              baseline_rate=0.1, peak_rate=1.0,
                              noise_std=0.05, grid_arrangement=True,
                              decay_time=2.0, calcium_noise_std=0.1,
                              bounds=(0, 1), seed=None, verbose=True):
    """
    Generate synthetic data with neurons on 3D spatial manifold (place cells).

    Pipeline: a 3D random-walk trajectory, then place-cell firing rates, then
    per-frame Bernoulli spike events with probability
    ``rate / sampling_rate`` (clipped to [0, 1]), then pseudo-calcium traces
    via ``generate_pseudo_calcium_signal``.

    Parameters
    ----------
    n_neurons : int
        Number of neurons.
    duration : float
        Duration in seconds.
    sampling_rate : float
        Sampling rate in Hz. ``int(duration * sampling_rate)`` frames are
        generated.
    field_sigma : float
        Width of place fields.
    step_size : float
        Step size for random walk.
    momentum : float
        Momentum for smoother trajectories.
    baseline_rate : float
        Baseline firing rate. Default is 0.1 Hz.
    peak_rate : float
        Peak firing rate. Default is 1.0 Hz.
    noise_std : float
        Firing rate noise.
    grid_arrangement : bool
        Arrange place fields in grid.
    decay_time : float
        Calcium decay time.
    calcium_noise_std : float
        Calcium signal noise.
    bounds : tuple
        Spatial bounds.
    seed : int, optional
        Random seed. The global NumPy RNG is seeded with ``seed``. The
        trajectory is generated with ``seed`` and the neurons with
        ``seed + 1``. Spike events are drawn from the global RNG afterwards.
        ``seed=0`` is a valid seed.
    verbose : bool
        Print progress.

    Returns
    -------
    calcium_signals : ndarray
        Calcium signals (n_neurons x n_timepoints). Assumes
        ``generate_pseudo_calcium_signal`` returns ``n_timepoints`` samples.
    positions : ndarray
        Position trajectory (3 x n_timepoints).
    place_field_centers : ndarray
        Place field centers (n_neurons x 3).
    firing_rates : ndarray
        Underlying firing rates (n_neurons x n_timepoints).
    """
    if seed is not None:
        np.random.seed(seed)

    n_timepoints = int(duration * sampling_rate)

    if verbose:
        print(f'Generating 3D spatial manifold data: {n_neurons} neurons, {duration}s')

    # Generate spatial trajectory
    if verbose:
        print(' Generating 3D random walk trajectory...')
    positions = generate_3d_random_walk(n_timepoints, bounds, step_size, momentum, seed)

    # Generate neural responses. Use an offset seed so neuron randomness is
    # decoupled from the trajectory; `is not None` keeps seed=0 reproducible
    # the same way as every other seed.
    if verbose:
        print(' Generating neural responses with 3D place fields...')
    firing_rates, place_field_centers = generate_3d_manifold_neurons(
        n_neurons, positions, field_sigma,
        baseline_rate, peak_rate, noise_std,
        grid_arrangement, bounds,
        seed=(seed + 1) if seed is not None else None
    )

    # Convert to calcium signals
    if verbose:
        print(' Converting to calcium signals...')
    calcium_signals = np.zeros((n_neurons, n_timepoints))

    for i in range(n_neurons):
        # Per-frame Bernoulli approximation of Poisson spiking
        prob_spike = np.clip(firing_rates[i, :] / sampling_rate, 0, 1)
        events = np.random.binomial(1, prob_spike)

        # Convert to calcium
        calcium_signals[i, :] = generate_pseudo_calcium_signal(
            events=events,
            duration=duration,
            sampling_rate=sampling_rate,
            amplitude_range=(0.5, 2.0),
            decay_time=decay_time,
            noise_std=calcium_noise_std
        )

    if verbose:
        print(' Done!')

    return calcium_signals, positions, place_field_centers, firing_rates

291 

292 

def generate_3d_manifold_exp(n_neurons=125, duration=600, fps=20.0,
                             field_sigma=0.1, step_size=0.02, momentum=0.8,
                             baseline_rate=0.1, peak_rate=1.0,
                             noise_std=0.05, grid_arrangement=True,
                             decay_time=2.0, calcium_noise_std=0.1,
                             bounds=(0, 1), seed=None, verbose=True,
                             return_info=False):
    """
    Generate complete experiment with 3D spatial manifold (place cells).

    Parameters
    ----------
    n_neurons : int
        Number of neurons.
    duration : float
        Duration in seconds.
    fps : float
        Sampling rate.
    field_sigma : float
        Place field width.
    step_size : float
        Random walk step size.
    momentum : float
        Trajectory smoothness.
    baseline_rate : float
        Baseline firing rate. Default is 0.1 Hz.
    peak_rate : float
        Peak firing rate. Default is 1.0 Hz.
    noise_std : float
        Firing rate noise.
    grid_arrangement : bool
        Grid arrangement of place fields.
    decay_time : float
        Calcium decay time.
    calcium_noise_std : float
        Calcium noise.
    bounds : tuple
        Spatial bounds.
    seed : int, optional
        Random seed.
    verbose : bool
        Print progress.
    return_info : bool
        If True, return (exp, info) tuple with additional information.

    Returns
    -------
    exp : Experiment
        DRIADA Experiment object with 3D spatial manifold data. Dynamic
        features are 'position_3d' (multi-dimensional) and 'x', 'y', 'z'.
        Underlying firing rates are attached as ``exp.firing_rates``.
    info : dict, optional
        Only returned if return_info=True. Contains:
        - manifold_type: '3d_spatial'
        - n_neurons: Number of neurons
        - positions: 3D trajectory (3, n_frames)
        - place_field_centers: 3D place field centers (n_neurons, 3)
        - firing_rates: Raw firing rates (n_neurons, n_frames)
        - parameters: Dictionary of the generation parameters used
          (including duration, fps and seed)
    """
    # Calculate effective decay time for shuffle mask
    effective_decay_time = get_effective_decay_time(decay_time, duration, verbose)

    # Generate data
    calcium, positions, place_field_centers, firing_rates = generate_3d_manifold_data(
        n_neurons=n_neurons,
        duration=duration,
        sampling_rate=fps,
        field_sigma=field_sigma,
        step_size=step_size,
        momentum=momentum,
        baseline_rate=baseline_rate,
        peak_rate=peak_rate,
        noise_std=noise_std,
        grid_arrangement=grid_arrangement,
        decay_time=decay_time,
        calcium_noise_std=calcium_noise_std,
        bounds=bounds,
        seed=seed,
        verbose=verbose
    )

    # Create static features
    static_features = {
        'fps': fps,
        't_rise_sec': 0.04,
        't_off_sec': effective_decay_time,  # Use effective decay time for shuffle mask
        'manifold_type': '3d_spatial',
        'field_sigma': field_sigma,
        'baseline_rate': baseline_rate,
        'peak_rate': peak_rate,
        'grid_arrangement': grid_arrangement,
        'place_field_centers': place_field_centers,
    }

    # Create dynamic features: joint 3D position plus separate x, y, z series.
    # Separate TimeSeries instances are built for the joint feature so it does
    # not share objects with the per-axis features.
    position_ts = MultiTimeSeries(
        [TimeSeries(positions[axis, :], discrete=False) for axis in range(3)]
    )
    x_ts, y_ts, z_ts = (
        TimeSeries(data=positions[axis, :], discrete=False) for axis in range(3)
    )

    dynamic_features = {
        'position_3d': position_ts,
        'x': x_ts,
        'y': y_ts,
        'z': z_ts
    }

    # Create experiment
    exp = Experiment(
        signature='3d_spatial_manifold_exp',
        calcium=calcium,
        spikes=None,
        static_features=static_features,
        dynamic_features=dynamic_features,
        exp_identificators={
            'manifold': '3d_spatial',
            'n_neurons': n_neurons,
            'duration': duration
        }
    )

    # Store firing rates
    exp.firing_rates = firing_rates

    if not return_info:
        return exp

    info = {
        'manifold_type': '3d_spatial',
        'n_neurons': n_neurons,
        'positions': positions,
        'place_field_centers': place_field_centers,
        'firing_rates': firing_rates,
        'parameters': {
            'duration': duration,
            'fps': fps,
            'field_sigma': field_sigma,
            'step_size': step_size,
            'momentum': momentum,
            'baseline_rate': baseline_rate,
            'peak_rate': peak_rate,
            'noise_std': noise_std,
            'grid_arrangement': grid_arrangement,
            'decay_time': decay_time,
            'calcium_noise_std': calcium_noise_std,
            'bounds': bounds,
            'seed': seed,
        }
    }
    return exp, info