Coverage for src/driada/experiment/synthetic/manifold_spatial_3d.py: 9.62%
104 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
"""
3D spatial manifold generation for place cells.

This module contains functions for generating synthetic neural data on 3D spatial
manifolds, typically used to model place cells in 3D environments.
"""
8import numpy as np
9from .core import validate_peak_rate, generate_pseudo_calcium_signal
10from .utils import get_effective_decay_time
11from ..exp_base import Experiment
12from ...information.info_base import TimeSeries, MultiTimeSeries
def generate_3d_random_walk(length, bounds=(0, 1), step_size=0.02, momentum=0.8, seed=None):
    """
    Generate a 3D random walk trajectory within bounded region.

    The walker carries a velocity that is an exponential moving average of
    Gaussian kicks. When a coordinate leaves ``[bounds[0], bounds[1]]`` it is
    clamped onto that wall and the corresponding velocity component is
    reversed, so the walk bounces off the walls.

    Parameters
    ----------
    length : int
        Number of time points. ``0`` yields an empty ``(3, 0)`` array.
    bounds : tuple
        (min, max) bounds for x, y, z coordinates.
    step_size : float
        Scale of the Gaussian kick mixed into the velocity at each step.
    momentum : float
        Momentum factor (0-1) for smoother trajectories; larger values keep
        more of the previous velocity.
    seed : int, optional
        If given, NumPy's global random state is reseeded with it.

    Returns
    -------
    positions : ndarray
        Shape (3, length) with x, y, z coordinates.
    """
    if seed is not None:
        np.random.seed(seed)

    lower, upper = bounds[0], bounds[1]
    positions = np.zeros((3, length))
    if length == 0:
        # No samples requested: there is no column 0 to initialise.
        return positions

    velocity = np.zeros(3)

    # Start at a uniformly random point inside the box.
    positions[:, 0] = np.random.uniform(lower, upper, 3)

    for t in range(1, length):
        # Random walk with momentum: blend the old velocity with a new kick.
        velocity = momentum * velocity + (1 - momentum) * np.random.randn(3) * step_size
        new_pos = positions[:, t - 1] + velocity

        # Bounce off walls: clamp to the violated bound and reverse that
        # velocity component. ``above`` excludes ``below`` to mirror if/elif.
        below = new_pos < lower
        above = ~below & (new_pos > upper)
        hit = below | above
        new_pos[below] = lower
        new_pos[above] = upper
        velocity[hit] = -velocity[hit]

        positions[:, t] = new_pos

    return positions
def gaussian_place_field_3d(positions, center, sigma=0.1):
    """
    Evaluate an isotropic 3D Gaussian tuning curve along a trajectory.

    Parameters
    ----------
    positions : ndarray
        Array indexed as ``positions[axis, t]``; only rows 0, 1 and 2
        (x, y, z) are read.
    center : array-like
        Place field center; only ``center[0]``, ``center[1]``, ``center[2]``
        are read.
    sigma : float
        Width (standard deviation) of the place field.

    Returns
    -------
    response : ndarray
        ``exp(-|p - center|^2 / (2 sigma^2))`` for every time point; equals 1
        exactly at the center.
    """
    # Per-axis offsets from the field center.
    off_x, off_y, off_z = (positions[axis, :] - center[axis] for axis in range(3))
    squared_distance = off_x**2 + off_y**2 + off_z**2
    denominator = 2 * sigma**2
    return np.exp(-squared_distance / denominator)
def generate_3d_manifold_neurons(n_neurons, positions, field_sigma=0.1,
                                 baseline_rate=0.1, peak_rate=1.0,
                                 noise_std=0.05, grid_arrangement=True,
                                 bounds=(0, 1), seed=None):
    """
    Generate population of place cells with 3D Gaussian place fields.

    Parameters
    ----------
    n_neurons : int
        Number of neurons (must be non-negative).
    positions : ndarray
        Shape (3, n_timepoints) with x, y, z positions.
    field_sigma : float
        Width of place fields. Default is 0.1.
    baseline_rate : float
        Baseline firing rate. Default is 0.1 Hz.
    peak_rate : float
        Peak firing rate at place field center. Default is 1.0 Hz.
    noise_std : float
        Standard deviation of additive Gaussian noise on the firing rates.
    grid_arrangement : bool
        If True, arrange place fields on a jittered 3D grid (the smallest
        k x k x k grid holding ``n_neurons`` points, filled x-major, keeping
        the first ``n_neurons`` nodes). Otherwise place them uniformly at
        random within ``bounds``.
    bounds : tuple
        (min, max) bounds for place field centers.
    seed : int, optional
        If given, NumPy's global random state is reseeded with it.

    Returns
    -------
    firing_rates : ndarray
        Shape (n_neurons, n_timepoints) with non-negative firing rates.
    place_field_centers : ndarray
        Shape (n_neurons, 3) with place field centers (also for
        ``n_neurons == 0``).

    Raises
    ------
    ValueError
        If ``n_neurons`` is negative.
    """
    # Validate firing rate
    validate_peak_rate(peak_rate, context="generate_3d_manifold_neurons")

    if n_neurons < 0:
        raise ValueError(f"n_neurons must be non-negative, got {n_neurons}")

    if seed is not None:
        np.random.seed(seed)

    n_timepoints = positions.shape[1]

    if grid_arrangement:
        # Smallest integer k with k**3 >= n_neurons. The float cube root is
        # inexact (27 ** (1/3) == 3.0000000000000004, whose ceil is 4), so
        # correct the rounded estimate with exact integer comparisons.
        n_per_side = max(1, int(round(n_neurons ** (1.0 / 3.0))))
        while n_per_side ** 3 < n_neurons:
            n_per_side += 1
        while n_per_side > 1 and (n_per_side - 1) ** 3 >= n_neurons:
            n_per_side -= 1

        # NOTE(review): the 0.1 margin is absolute, not relative to the
        # bounds span, so it assumes the span is comfortably above 0.2.
        axis = np.linspace(bounds[0] + 0.1, bounds[1] - 0.1, n_per_side)

        # 'ij' indexing + C-order ravel enumerates nodes x-major, then y,
        # then z: the same order as nested x/y/z loops.
        grid_x, grid_y, grid_z = np.meshgrid(axis, axis, axis, indexing='ij')
        grid = np.column_stack([grid_x.ravel(), grid_y.ravel(), grid_z.ravel()])
        place_field_centers = grid[:n_neurons].copy()

        # Add small jitter so fields are not perfectly regular.
        jitter = np.random.normal(0, 0.02, place_field_centers.shape)
        place_field_centers += jitter
        place_field_centers = np.clip(place_field_centers, bounds[0], bounds[1])
    else:
        # Random placement
        place_field_centers = np.random.uniform(bounds[0], bounds[1], (n_neurons, 3))

    firing_rates = np.zeros((n_neurons, n_timepoints))
    rate_span = peak_rate - baseline_rate

    # One neuron at a time so the noise draws consume the global RNG in the
    # same order for a given seed.
    for i in range(n_neurons):
        place_response = gaussian_place_field_3d(positions, place_field_centers[i], field_sigma)

        # Map the [0, 1] Gaussian response onto [baseline_rate, peak_rate].
        firing_rate = baseline_rate + rate_span * place_response

        # Add noise, then rectify: firing rates cannot be negative.
        noise = np.random.normal(0, noise_std, n_timepoints)
        firing_rates[i, :] = np.maximum(0, firing_rate + noise)

    return firing_rates, place_field_centers
def generate_3d_manifold_data(n_neurons, duration=600, sampling_rate=20.0,
                              field_sigma=0.1, step_size=0.02, momentum=0.8,
                              baseline_rate=0.1, peak_rate=1.0,
                              noise_std=0.05, grid_arrangement=True,
                              decay_time=2.0, calcium_noise_std=0.1,
                              bounds=(0, 1), seed=None, verbose=True):
    """
    Generate synthetic data with neurons on 3D spatial manifold (place cells).

    Pipeline: a bounded 3D random walk, then Gaussian place-field firing
    rates, then per-frame binary events, then pseudo-calcium traces.

    Parameters
    ----------
    n_neurons : int
        Number of neurons.
    duration : float
        Duration in seconds.
    sampling_rate : float
        Sampling rate in Hz; the number of frames is
        ``int(duration * sampling_rate)``.
    field_sigma : float
        Width of place fields.
    step_size : float
        Step size for random walk.
    momentum : float
        Momentum for smoother trajectories.
    baseline_rate : float
        Baseline firing rate. Default is 0.1 Hz.
    peak_rate : float
        Peak firing rate. Default is 1.0 Hz.
    noise_std : float
        Firing rate noise.
    grid_arrangement : bool
        Arrange place fields in grid.
    decay_time : float
        Calcium decay time.
    calcium_noise_std : float
        Calcium signal noise.
    bounds : tuple
        Spatial bounds.
    seed : int, optional
        Random seed. The trajectory uses ``seed`` and the place cells use
        ``seed + 1``; ``seed=0`` is treated like any other integer seed.
    verbose : bool
        Print progress.

    Returns
    -------
    calcium_signals : ndarray
        Calcium signals (n_neurons x n_timepoints).
    positions : ndarray
        Position trajectory (3 x n_timepoints).
    place_field_centers : ndarray
        Place field centers (n_neurons x 3).
    firing_rates : ndarray
        Underlying firing rates (n_neurons x n_timepoints).
    """
    if seed is not None:
        np.random.seed(seed)

    n_timepoints = int(duration * sampling_rate)

    if verbose:
        print(f'Generating 3D spatial manifold data: {n_neurons} neurons, {duration}s')

    # Generate spatial trajectory
    if verbose:
        print(' Generating 3D random walk trajectory...')
    positions = generate_3d_random_walk(n_timepoints, bounds, step_size, momentum, seed)

    # Generate neural responses. Compare against None (not truthiness) so that
    # seed=0 still yields a distinct, explicit seed for the neuron generator.
    if verbose:
        print(' Generating neural responses with 3D place fields...')
    neuron_seed = seed + 1 if seed is not None else None
    firing_rates, place_field_centers = generate_3d_manifold_neurons(
        n_neurons, positions, field_sigma,
        baseline_rate, peak_rate, noise_std,
        grid_arrangement, bounds,
        seed=neuron_seed
    )

    # Convert to calcium signals
    if verbose:
        print(' Converting to calcium signals...')
    calcium_signals = np.zeros((n_neurons, n_timepoints))

    for i in range(n_neurons):
        # Bernoulli approximation of Poisson spiking: at most one event per
        # frame, with probability rate / sampling_rate clipped to [0, 1].
        prob_spike = np.clip(firing_rates[i, :] / sampling_rate, 0, 1)
        events = np.random.binomial(1, prob_spike)

        # Convert the event train to a calcium-like trace.
        calcium_signals[i, :] = generate_pseudo_calcium_signal(
            events=events,
            duration=duration,
            sampling_rate=sampling_rate,
            amplitude_range=(0.5, 2.0),
            decay_time=decay_time,
            noise_std=calcium_noise_std
        )

    if verbose:
        print(' Done!')

    return calcium_signals, positions, place_field_centers, firing_rates
def generate_3d_manifold_exp(n_neurons=125, duration=600, fps=20.0,
                             field_sigma=0.1, step_size=0.02, momentum=0.8,
                             baseline_rate=0.1, peak_rate=1.0,
                             noise_std=0.05, grid_arrangement=True,
                             decay_time=2.0, calcium_noise_std=0.1,
                             bounds=(0, 1), seed=None, verbose=True,
                             return_info=False):
    """
    Generate complete experiment with 3D spatial manifold (place cells).

    Parameters
    ----------
    n_neurons : int
        Number of neurons.
    duration : float
        Duration in seconds.
    fps : float
        Sampling rate.
    field_sigma : float
        Place field width.
    step_size : float
        Random walk step size.
    momentum : float
        Trajectory smoothness.
    baseline_rate : float
        Baseline firing rate. Default is 0.1 Hz.
    peak_rate : float
        Peak firing rate. Default is 1.0 Hz.
    noise_std : float
        Firing rate noise.
    grid_arrangement : bool
        Grid arrangement of place fields.
    decay_time : float
        Calcium decay time.
    calcium_noise_std : float
        Calcium noise.
    bounds : tuple
        Spatial bounds.
    seed : int, optional
        Random seed.
    verbose : bool
        Print progress.
    return_info : bool
        If True, return (exp, info) tuple with additional information.

    Returns
    -------
    exp : Experiment
        DRIADA Experiment object with 3D spatial manifold data. Its dynamic
        features are 'position_3d' (MultiTimeSeries) and 'x', 'y', 'z'; the
        raw rates are attached as ``exp.firing_rates``.
    info : dict, optional
        Only returned if return_info=True. Contains:
        - manifold_type: '3d_spatial'
        - n_neurons: Number of neurons
        - positions: 3D trajectory (3, n_frames)
        - place_field_centers: 3D place field centers (n_neurons, 3)
        - firing_rates: Raw firing rates (n_neurons, n_frames)
        - parameters: Dictionary of all parameters used
    """
    # Calculate effective decay time for shuffle mask
    effective_decay_time = get_effective_decay_time(decay_time, duration, verbose)

    # Generate data
    calcium, positions, place_field_centers, firing_rates = generate_3d_manifold_data(
        n_neurons=n_neurons,
        duration=duration,
        sampling_rate=fps,
        field_sigma=field_sigma,
        step_size=step_size,
        momentum=momentum,
        baseline_rate=baseline_rate,
        peak_rate=peak_rate,
        noise_std=noise_std,
        grid_arrangement=grid_arrangement,
        decay_time=decay_time,
        calcium_noise_std=calcium_noise_std,
        bounds=bounds,
        seed=seed,
        verbose=verbose
    )

    # Static features (metadata); place field centers are stored last.
    static_features = {
        'fps': fps,
        't_rise_sec': 0.04,
        't_off_sec': effective_decay_time,  # Use effective decay time for shuffle mask
        'manifold_type': '3d_spatial',
        'field_sigma': field_sigma,
        'baseline_rate': baseline_rate,
        'peak_rate': peak_rate,
        'grid_arrangement': grid_arrangement,
        'place_field_centers': place_field_centers,
    }

    # Dynamic features: the joint 3D trajectory plus one feature per axis.
    # positions has shape (3, n_frames), so row ``axis`` is one coordinate.
    position_ts = MultiTimeSeries(
        [TimeSeries(positions[axis, :], discrete=False) for axis in range(3)]
    )
    x_ts, y_ts, z_ts = (
        TimeSeries(data=positions[axis, :], discrete=False) for axis in range(3)
    )

    dynamic_features = {
        'position_3d': position_ts,
        'x': x_ts,
        'y': y_ts,
        'z': z_ts
    }

    # Create experiment
    exp = Experiment(
        signature='3d_spatial_manifold_exp',
        calcium=calcium,
        spikes=None,
        static_features=static_features,
        dynamic_features=dynamic_features,
        exp_identificators={
            'manifold': '3d_spatial',
            'n_neurons': n_neurons,
            'duration': duration
        }
    )

    # Store firing rates
    exp.firing_rates = firing_rates

    if not return_info:
        return exp

    info = {
        'manifold_type': '3d_spatial',
        'n_neurons': n_neurons,
        'positions': positions,
        'place_field_centers': place_field_centers,
        'firing_rates': firing_rates,
        'parameters': {
            'field_sigma': field_sigma,
            'step_size': step_size,
            'momentum': momentum,
            'baseline_rate': baseline_rate,
            'peak_rate': peak_rate,
            'noise_std': noise_std,
            'grid_arrangement': grid_arrangement,
            'decay_time': decay_time,
            'calcium_noise_std': calcium_noise_std,
            'bounds': bounds
        }
    }
    return exp, info