Coverage for src/driada/experiment/synthetic/experiment_generators.py: 77.17%
276 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1"""
2High-level experiment generators for synthetic neural data.
4This module contains functions that generate complete synthetic experiments
5by combining various types of neural data (manifold-based, feature-selective, mixed).
6"""
8import numpy as np
9import tqdm
10from .core import validate_peak_rate, generate_pseudo_calcium_signal, generate_pseudo_calcium_multisignal
11from .time_series import (
12 generate_binary_time_series, generate_fbm_time_series,
13 select_signal_roi, delete_one_islands, apply_poisson_to_binary_series
14)
15from .manifold_circular import generate_circular_manifold_data
16from .manifold_spatial_2d import generate_2d_manifold_data
17from .manifold_spatial_3d import generate_3d_manifold_data
18from .mixed_selectivity import (
19 generate_multiselectivity_patterns,
20 generate_synthetic_data_mixed_selectivity
21)
22from ..exp_base import Experiment
23from ...information.info_base import TimeSeries, MultiTimeSeries
def generate_synthetic_data(nfeats, nneurons, ftype='c', duration=600, seed=42, sampling_rate=20.0,
                            rate_0=0.1, rate_1=1.0, skip_prob=0.0, hurst=0.5, ampl_range=(0.5, 2), decay_time=2,
                            avg_islands=10, avg_duration=5, noise_std=0.1, verbose=True, pregenerated_features=None,
                            apply_random_neuron_shifts=False):
    """
    Generate synthetic neural data with feature-selective neurons.

    Every neuron is assigned one feature drawn at random (with replacement)
    from the available features. The feature is turned into a binary
    "active/inactive" series, some active periods are randomly dropped,
    events are drawn from a Poisson process and finally converted into a
    pseudo-calcium trace.

    Parameters
    ----------
    nfeats : int
        Number of features.
    nneurons : int
        Number of neurons.
    ftype : str
        Feature type: 'c' for continuous, 'd' for discrete.
    duration : float
        Duration in seconds.
    seed : int
        Random seed. It is forwarded to the FBM generator (offset by the
        feature index) and to the ROI selection; feature assignment and
        per-neuron shifts use the global numpy random state.
    sampling_rate : float
        Sampling rate in Hz.
    rate_0 : float
        Baseline firing rate.
    rate_1 : float
        Active firing rate.
    skip_prob : float
        Probability of skipping islands.
    hurst : float
        Hurst parameter for FBM.
    ampl_range : tuple
        Amplitude range for calcium events.
    decay_time : float
        Calcium decay time.
    avg_islands : int
        Average number of islands for discrete features.
    avg_duration : int
        Average duration of islands.
    noise_std : float
        Noise standard deviation.
    verbose : bool
        Print progress messages and show progress bars. When False the
        function is silent.
    pregenerated_features : list, optional
        Use provided features instead of generating new ones.
    apply_random_neuron_shifts : bool
        Apply random shifts to break correlations.

    Returns
    -------
    features : ndarray
        Feature time series (nfeats x length). An empty (0 x length) array
        when there are no features, and an empty 1-D array when
        ``nneurons == 0``.
    signals : ndarray
        Neural signals (nneurons x length).
    ground_truth : ndarray
        Ground truth matrix (nfeats x nneurons).

    Raises
    ------
    ValueError
        If the number of pregenerated features differs from ``nfeats`` or
        if ``ftype`` is neither 'c' nor 'd'.
    """
    gt = np.zeros((nfeats, nneurons))
    length = int(duration * sampling_rate)

    # Handle edge case of 0 neurons
    if nneurons == 0:
        return np.array([]), np.array([]).reshape(0, length), gt

    # Use pregenerated features if provided, otherwise generate new ones
    if pregenerated_features is not None:
        if verbose:
            print('Using pregenerated features...')
        all_feats = pregenerated_features
        if len(all_feats) != nfeats:
            raise ValueError(f'Number of pregenerated features ({len(all_feats)}) does not match nfeats ({nfeats})')
    else:
        if verbose:
            print('Generating features...')
        all_feats = []
        for i in tqdm.tqdm(np.arange(nfeats), disable=not verbose):
            if ftype == 'c':
                # Generate the series with unique seed for each feature
                feature_seed = seed + i if seed is not None else None
                fbm_series = generate_fbm_time_series(length, hurst, seed=feature_seed)
                all_feats.append(fbm_series)

            elif ftype == 'd':
                # Generate binary series
                binary_series = generate_binary_time_series(length, avg_islands, avg_duration * sampling_rate)
                all_feats.append(binary_series)

            else:
                raise ValueError('unknown feature flag')

    if verbose:
        print('Generating signals...')
    if nfeats > 0:
        fois = np.random.choice(np.arange(nfeats), size=nneurons)
        gt[fois, np.arange(nneurons)] = 1  # add info about ground truth feature-signal connections
    else:
        # If no features, neurons won't be selective to any feature
        fois = np.full(nneurons, -1)  # Use -1 to indicate no feature selection
    all_signals = []

    for j in tqdm.tqdm(np.arange(nneurons), disable=not verbose):
        foi = fois[j]

        # Handle case where there are no features
        if foi == -1 or nfeats == 0:
            # Generate random baseline activity
            binary_series = generate_binary_time_series(length, avg_islands // 2, avg_duration * sampling_rate // 2)
        elif ftype == 'c':
            csignal = all_feats[foi].copy()  # Make a copy to avoid modifying the original

            # Apply random per-neuron shift to break correlations
            if apply_random_neuron_shifts:
                # Apply a unique random shift for this neuron
                neuron_shift = np.random.randint(0, length)
                csignal = np.roll(csignal, neuron_shift)
                if verbose and j < 3:  # Print for first 3 neurons only
                    print(f' Neuron {j}: Applied shift={neuron_shift} to continuous feature {foi}')

            # The same seed is deliberately reused for every neuron here
            loc, lower_border, upper_border = select_signal_roi(csignal, seed=seed)
            # Generate binary series from a continuous one: active while the
            # signal stays inside the selected [lower_border, upper_border] band
            binary_series = np.zeros(length)
            binary_series[np.where((csignal >= lower_border) & (csignal <= upper_border))] = 1

        elif ftype == 'd':
            binary_series = all_feats[foi].copy()  # Make a copy

            # Apply random per-neuron shift to break correlations
            if apply_random_neuron_shifts:
                # Apply a unique random shift for this neuron
                neuron_shift = np.random.randint(0, length)
                binary_series = np.roll(binary_series, neuron_shift)
                if verbose and j < 3:  # Print for first 3 neurons only
                    print(f' Neuron {j}: Applied shift={neuron_shift} to discrete feature {foi}')

        else:
            raise ValueError('unknown feature flag')

        # randomly skip some on periods
        mod_binary_series = delete_one_islands(binary_series, skip_prob)

        # Apply Poisson process (rates are converted from per-second to per-frame)
        poisson_series = apply_poisson_to_binary_series(mod_binary_series,
                                                        rate_0 / sampling_rate,
                                                        rate_1 / sampling_rate)

        # Generate pseudo-calcium
        pseudo_calcium_signal = generate_pseudo_calcium_signal(duration=duration,
                                                               events=poisson_series,
                                                               sampling_rate=sampling_rate,
                                                               amplitude_range=ampl_range,
                                                               decay_time=decay_time,
                                                               noise_std=noise_std)

        all_signals.append(pseudo_calcium_signal)

    # np.vstack raises on an empty sequence, so return a well-shaped empty
    # feature matrix when the neurons are not tied to any feature.
    if len(all_feats) > 0:
        features = np.vstack(all_feats)
    else:
        features = np.empty((0, length))

    return features, np.vstack(all_signals), gt
def generate_synthetic_exp(n_dfeats=20, n_cfeats=20, nneurons=500, seed=0, fps=20, with_spikes=False, duration=1200):
    """
    Build a synthetic experiment whose neurons respond to discrete and continuous features.

    The population is split into a "discrete" and a "continuous" half (the
    discrete half receives the extra neuron for odd counts). If one feature
    type is absent, the whole population goes to the other type.

    Parameters
    ----------
    n_dfeats : int, optional
        Number of discrete features. Default: 20.
    n_cfeats : int, optional
        Number of continuous features. Default: 20.
    nneurons : int, optional
        Total number of neurons. Default: 500.
    seed : int, optional
        Random seed for reproducibility. Default: 0.
    fps : float, optional
        Frames per second. Default: 20.
    with_spikes : bool, optional
        If True, reconstruct spikes from calcium using wavelet method. Default: False.
    duration : int, optional
        Duration of the experiment in seconds. Default: 1200.

    Returns
    -------
    exp : Experiment
        Synthetic experiment object with calcium signals and optionally spike data.
    """
    # Seed the global numpy RNG once, up front
    if seed is not None:
        np.random.seed(seed)

    # Decide how many neurons follow each feature type
    if n_dfeats == 0:
        n_neurons_discrete, n_neurons_continuous = 0, nneurons
    elif n_cfeats == 0:
        n_neurons_discrete, n_neurons_continuous = nneurons, 0
    else:
        n_neurons_discrete, n_neurons_continuous = (nneurons + 1) // 2, nneurons // 2

    # Settings common to both generator calls
    shared_kwargs = dict(
        duration=duration,
        hurst=0.3,
        seed=seed,
        rate_0=0.1,
        rate_1=1.0,
        skip_prob=0.1,
        noise_std=0.1,
        sampling_rate=fps,
    )

    dfeats, calcium1, _ = generate_synthetic_data(n_dfeats, n_neurons_discrete, ftype='d', **shared_kwargs)
    cfeats, calcium2, _ = generate_synthetic_data(n_cfeats, n_neurons_continuous, ftype='c', **shared_kwargs)

    # Wrap every feature row into a named TimeSeries: discrete first, then continuous
    dynamic_features = {}
    for idx in range(len(dfeats)):
        dynamic_features[f'd_feat_{idx}'] = TimeSeries(dfeats[idx, :], discrete=True)
    for idx in range(len(cfeats)):
        dynamic_features[f'c_feat_{idx}'] = TimeSeries(cfeats[idx, :], discrete=False)

    # Stack calcium traces, skipping a group that has no neurons
    if n_neurons_discrete == 0:
        all_calcium = calcium2
    elif n_neurons_continuous == 0:
        all_calcium = calcium1
    else:
        all_calcium = np.vstack([calcium1, calcium2])

    # Spike reconstruction is the only thing that differs between the two modes
    spike_method = 'wavelet' if with_spikes else None

    return Experiment('Synthetic',
                      all_calcium,
                      None,
                      {},
                      {'fps': fps},
                      dynamic_features,
                      reconstruct_spikes=spike_method)
def generate_mixed_population_exp(n_neurons=100, manifold_fraction=0.6,
                                  manifold_type='2d_spatial', manifold_params=None,
                                  n_discrete_features=3, n_continuous_features=3,
                                  feature_params=None, correlation_mode='independent',
                                  correlation_strength=0.3, duration=600, fps=20.0,
                                  seed=None, verbose=True, return_info=False):
    """
    Generate synthetic experiment with mixed population of manifold and feature-selective cells.

    This function creates a neural population combining spatial cells (place cells, head direction)
    with feature-selective cells responding to behavioral variables. The mixing ratio and
    correlations between spatial and behavioral activities can be configured.

    Parameters
    ----------
    n_neurons : int
        Total number of neurons in the population.
    manifold_fraction : float
        Fraction of neurons that are manifold cells (0.0-1.0).
        Remaining neurons will be feature-selective.
    manifold_type : str
        Type of manifold: 'circular', '2d_spatial', '3d_spatial'.
    manifold_params : dict, optional
        Parameters for manifold generation. Missing keys fall back to the
        defaults, so a partial dictionary is accepted.
    n_discrete_features : int
        Number of discrete behavioral features.
    n_continuous_features : int
        Number of continuous behavioral features.
    feature_params : dict, optional
        Parameters for feature generation. Missing keys fall back to the
        defaults, so a partial dictionary is accepted. The optional key
        'selectivity_prob' scales the number of feature-selective cells and
        'multi_select_prob' > 0 switches to mixed-selectivity generation.
    correlation_mode : str
        How to correlate spatial and behavioral activities:
        - 'independent': No correlation between spatial and behavioral
        - 'spatial_correlated': Behavioral features modulated by spatial position
        - 'feature_correlated': Spatial activity modulated by behavioral features
    correlation_strength : float
        Strength of correlation (0.0-1.0) when correlation_mode is not 'independent'.
    duration : float
        Duration of experiment in seconds.
    fps : float
        Sampling rate in Hz.
    seed : int, optional
        Random seed for reproducibility.
    verbose : bool
        Print progress messages.
    return_info : bool
        If True, return (exp, info) tuple. If False (default), return only exp.

    Returns
    -------
    exp : Experiment
        Experiment object with mixed population.
    info : dict (only if return_info=True)
        Dictionary containing:
        - 'population_composition': Details about neuron allocation
        - 'manifold_info': Information about manifold cells
        - 'feature_selectivity': Information about feature-selective cells
        - 'spatial_data': Spatial trajectory data
        - 'behavioral_features': Behavioral feature data
        - 'correlation_applied': Correlation mode used
        - 'correlation_strength': Strength used (0.0 for 'independent')
        - 'parameters': Effective manifold/feature parameters and feature counts

    Raises
    ------
    ValueError
        If manifold_fraction or correlation_strength is outside [0, 1], if
        manifold_type or correlation_mode is not recognised, or if the
        requested configuration produces no neurons at all.

    Examples
    --------
    >>> # Generate population with 60% place cells, 40% feature-selective
    >>> exp, info = generate_mixed_population_exp(
    ...     n_neurons=50,
    ...     manifold_fraction=0.6,
    ...     manifold_type='2d_spatial',
    ...     correlation_mode='spatial_correlated',
    ...     return_info=True
    ... )

    >>> # Check population composition
    >>> print(f"Manifold cells: {info['population_composition']['n_manifold']}")
    >>> print(f"Feature-selective: {info['population_composition']['n_feature_selective']}")

    Notes
    -----
    The function integrates existing manifold and feature generators to create
    realistic mixed populations. Spatial correlations can model scenarios where
    behavioral variables depend on location (e.g., speed varying with position)
    or where spatial coding is modulated by behavioral state.
    """
    if seed is not None:
        np.random.seed(seed)

    # Validate parameters
    if not 0.0 <= manifold_fraction <= 1.0:
        raise ValueError(f"manifold_fraction must be between 0.0 and 1.0, got {manifold_fraction}")

    if manifold_type not in ['circular', '2d_spatial', '3d_spatial']:
        raise ValueError(f"manifold_type must be 'circular', '2d_spatial', or '3d_spatial', got {manifold_type}")

    if correlation_mode not in ['independent', 'spatial_correlated', 'feature_correlated']:
        raise ValueError(f"Invalid correlation_mode: {correlation_mode}")

    if not 0.0 <= correlation_strength <= 1.0:
        raise ValueError(f"correlation_strength must be between 0.0 and 1.0, got {correlation_strength}")

    selectivity_prob = feature_params.get('selectivity_prob', 1.0) if feature_params else 1.0
    # Calculate population allocation
    n_manifold = int(n_neurons * manifold_fraction)
    n_feature_selective = int((n_neurons - n_manifold) * selectivity_prob)

    if verbose:
        print(f'Generating mixed population: {n_neurons} total neurons')
        print(f' Manifold cells ({manifold_type}): {n_manifold}')
        print(f' Expected feature-selective cells: {n_feature_selective}')
        print(f' Correlation mode: {correlation_mode}')

    # Set default parameters. User dictionaries are merged over the defaults
    # (into fresh dicts) so that a partial dictionary does not raise KeyError
    # on the required keys and the caller's dictionaries are never mutated.
    default_manifold_params = {
        'field_sigma': 0.1,
        'baseline_rate': 0.1,
        'peak_rate': 1.0,  # Realistic for calcium imaging
        'noise_std': 0.05,
        'decay_time': 2.0,
        'calcium_noise_std': 0.1
    }
    manifold_params = {**default_manifold_params, **(manifold_params or {})}

    default_feature_params = {
        'rate_0': 0.1,
        'rate_1': 1.0,
        'skip_prob': 0.1,
        'hurst': 0.3,
        'ampl_range': (0.5, 2.0),
        'decay_time': 2.0,
        'noise_std': 0.1
    }
    feature_params = {**default_feature_params, **(feature_params or {})}

    # Initialize containers
    all_calcium_signals = []
    dynamic_features = {}
    manifold_info = {}
    spatial_data = None
    feature_selectivity = None

    # Generate manifold cells
    if n_manifold > 0:
        if verbose:
            print(f' Generating {n_manifold} {manifold_type} manifold cells...')

        manifold_seed = seed if seed is None else seed + 1000

        if manifold_type == 'circular':
            calcium_manifold, head_direction, preferred_dirs, firing_rates = \
                generate_circular_manifold_data(
                    n_manifold, duration, fps,
                    kappa=manifold_params.get('kappa', 4.0),
                    step_std=manifold_params.get('step_std', 0.1),
                    baseline_rate=manifold_params['baseline_rate'],
                    peak_rate=manifold_params['peak_rate'],
                    noise_std=manifold_params['noise_std'],
                    decay_time=manifold_params['decay_time'],
                    calcium_noise_std=manifold_params['calcium_noise_std'],
                    seed=manifold_seed,
                    verbose=verbose
                )

            # Add circular features
            dynamic_features['head_direction'] = TimeSeries(head_direction, discrete=False)
            dynamic_features['circular_angle'] = MultiTimeSeries([
                TimeSeries(np.cos(head_direction), discrete=False),
                TimeSeries(np.sin(head_direction), discrete=False)
            ])

            spatial_data = head_direction
            manifold_info = {
                'manifold_type': 'circular',
                'head_direction': head_direction,
                'preferred_directions': preferred_dirs,
                'firing_rates': firing_rates
            }

        elif manifold_type == '2d_spatial':
            calcium_manifold, positions, centers, firing_rates = \
                generate_2d_manifold_data(
                    n_manifold, duration, fps,
                    field_sigma=manifold_params['field_sigma'],
                    step_size=manifold_params.get('step_size', 0.02),
                    momentum=manifold_params.get('momentum', 0.8),
                    baseline_rate=manifold_params['baseline_rate'],
                    peak_rate=manifold_params['peak_rate'],
                    noise_std=manifold_params['noise_std'],
                    decay_time=manifold_params['decay_time'],
                    calcium_noise_std=manifold_params['calcium_noise_std'],
                    grid_arrangement=manifold_params.get('grid_arrangement', True),
                    seed=manifold_seed,
                    verbose=verbose
                )

            # Add spatial features
            dynamic_features['x_position'] = TimeSeries(positions[0, :], discrete=False)
            dynamic_features['y_position'] = TimeSeries(positions[1, :], discrete=False)
            dynamic_features['position_2d'] = MultiTimeSeries([
                TimeSeries(positions[0, :], discrete=False),
                TimeSeries(positions[1, :], discrete=False)
            ])

            spatial_data = positions
            manifold_info = {
                'manifold_type': '2d_spatial',
                'positions': positions,
                'place_field_centers': centers,
                'firing_rates': firing_rates
            }

        elif manifold_type == '3d_spatial':
            calcium_manifold, positions, centers, firing_rates = \
                generate_3d_manifold_data(
                    n_manifold, duration, fps,
                    field_sigma=manifold_params['field_sigma'],
                    step_size=manifold_params.get('step_size', 0.02),
                    momentum=manifold_params.get('momentum', 0.8),
                    baseline_rate=manifold_params['baseline_rate'],
                    peak_rate=manifold_params['peak_rate'],
                    noise_std=manifold_params['noise_std'],
                    decay_time=manifold_params['decay_time'],
                    calcium_noise_std=manifold_params['calcium_noise_std'],
                    grid_arrangement=manifold_params.get('grid_arrangement', True),
                    seed=manifold_seed,
                    verbose=verbose
                )

            # Add 3D spatial features
            dynamic_features['x_position'] = TimeSeries(positions[0, :], discrete=False)
            dynamic_features['y_position'] = TimeSeries(positions[1, :], discrete=False)
            dynamic_features['z_position'] = TimeSeries(positions[2, :], discrete=False)
            dynamic_features['position_3d'] = MultiTimeSeries([
                TimeSeries(positions[0, :], discrete=False),
                TimeSeries(positions[1, :], discrete=False),
                TimeSeries(positions[2, :], discrete=False)
            ])

            spatial_data = positions
            manifold_info = {
                'manifold_type': '3d_spatial',
                'positions': positions,
                'place_field_centers': centers,
                'firing_rates': firing_rates
            }

        all_calcium_signals.append(calcium_manifold)

    # Generate behavioral features
    behavioral_features_data = {}

    if n_discrete_features > 0 or n_continuous_features > 0:
        if verbose:
            print(f' Generating behavioral features: {n_discrete_features} discrete, {n_continuous_features} continuous')

        length = int(duration * fps)
        feature_seed = seed if seed is None else seed + 2000

        # Generate discrete features
        for i in range(n_discrete_features):
            binary_series = generate_binary_time_series(
                length,
                avg_islands=feature_params.get('avg_islands', 10),
                avg_duration=int(feature_params.get('avg_duration', 5) * fps)
            )

            feat_name = f'd_feat_{i}'
            behavioral_features_data[feat_name] = binary_series
            dynamic_features[feat_name] = TimeSeries(binary_series, discrete=True)
            if feature_seed is not None:
                feature_seed += 1

        # Generate continuous features
        for i in range(n_continuous_features):
            fbm_series = generate_fbm_time_series(
                length,
                hurst=feature_params['hurst'],
                seed=feature_seed
            )

            feat_name = f'c_feat_{i}'
            behavioral_features_data[feat_name] = fbm_series
            dynamic_features[feat_name] = TimeSeries(fbm_series, discrete=False)
            if feature_seed is not None:
                feature_seed += 1

        # Apply correlation if requested
        if correlation_mode == 'spatial_correlated' and spatial_data is not None:
            if verbose:
                print(f' Applying spatial correlation (strength={correlation_strength})')

            # Modulate behavioral features based on spatial position
            for feat_name, feat_data in behavioral_features_data.items():
                if 'c_feat' in feat_name:  # Only continuous features
                    # Use average position as spatial signal
                    if spatial_data.ndim == 1:  # Circular case
                        spatial_signal = np.sin(spatial_data)  # Project to [-1, 1]
                    else:  # 2D/3D spatial case
                        spatial_signal = np.mean(spatial_data, axis=0)  # Average position

                    # Normalize spatial signal
                    spatial_signal = (spatial_signal - np.mean(spatial_signal)) / np.std(spatial_signal)

                    # Apply correlation: blend the feature with the spatial
                    # signal rescaled to the feature's own standard deviation
                    correlated_feat = (1 - correlation_strength) * feat_data + \
                        correlation_strength * spatial_signal * np.std(feat_data)

                    # Only existing keys are reassigned, so the dict being
                    # iterated does not change size
                    behavioral_features_data[feat_name] = correlated_feat
                    dynamic_features[feat_name] = TimeSeries(correlated_feat, discrete=False)

        elif correlation_mode == 'independent':
            # Ensure true independence by regenerating features with different seeds
            if verbose:
                print(f' Ensuring feature independence by regenerating behavioral features...')

            # Use completely different seeds for independent features
            independent_seed = seed + 10000 if seed is not None else None

            # Regenerate discrete features with new seeds
            for i in range(n_discrete_features):
                if independent_seed is not None:
                    np.random.seed(independent_seed + i * 100)

                # Generate new binary series with different temporal pattern
                binary_series = generate_binary_time_series(
                    length,
                    avg_islands=feature_params.get('avg_islands', 10) + np.random.randint(-3, 4),  # Vary parameters
                    avg_duration=int(feature_params.get('avg_duration', 5) * fps * np.random.uniform(0.5, 1.5))
                )

                feat_name = f'd_feat_{i}'
                behavioral_features_data[feat_name] = binary_series
                dynamic_features[feat_name] = TimeSeries(binary_series, discrete=True)

            # Regenerate continuous features with new seeds
            for i in range(n_continuous_features):
                if independent_seed is not None:
                    np.random.seed(independent_seed + 1000 + i * 100)

                # Use random low Hurst parameter (0.2-0.4) to break temporal autocorrelation and ensure independence
                # Anti-persistent behavior (H < 0.5) breaks accidental spatial correlations
                # Varying H across features prevents systematic correlations
                low_hurst = np.random.uniform(0.2, 0.4)

                # Apply random circular shift to break correlation with spatial trajectory
                # Random shift between 1/4 and 3/4 of the series length
                roll_shift = np.random.randint(length // 4, 3 * length // 4)

                if verbose:
                    print(f' Feature c_feat_{i}: Using Hurst={low_hurst:.3f}, roll_shift={roll_shift} for independence')

                fbm_series = generate_fbm_time_series(
                    length,
                    hurst=low_hurst,
                    seed=(independent_seed + 1000 + i * 100) if independent_seed is not None else None,
                    roll_shift=roll_shift
                )

                feat_name = f'c_feat_{i}'
                behavioral_features_data[feat_name] = fbm_series
                dynamic_features[feat_name] = TimeSeries(fbm_series, discrete=False)

    # Generate feature-selective cells
    if n_feature_selective > 0:
        if verbose:
            print(f' Generating {n_feature_selective} feature-selective cells...')

        feature_seed = seed if seed is None else seed + 3000

        # Prepare features for synthetic data generation
        discrete_feats = [behavioral_features_data[f'd_feat_{i}']
                          for i in range(n_discrete_features)]
        continuous_feats = [behavioral_features_data[f'c_feat_{i}']
                            for i in range(n_continuous_features)]

        all_feats = discrete_feats + continuous_feats

        if len(all_feats) == 0:
            # No features - generate baseline neurons
            calcium_features = np.random.normal(0, feature_params['noise_std'],
                                                (n_feature_selective, int(duration * fps)))
            gt_features = np.zeros((0, n_feature_selective))
        else:
            # Check if mixed selectivity is requested
            use_mixed_selectivity = feature_params.get('multi_select_prob', 0) > 0

            if use_mixed_selectivity:
                # Use mixed selectivity generation
                selectivity_seed = None if feature_seed is None else feature_seed + 500

                # Generate selectivity patterns
                selectivity_matrix = generate_multiselectivity_patterns(
                    n_feature_selective,
                    n_discrete_features + n_continuous_features,
                    mode='random',
                    selectivity_prob=feature_params.get('selectivity_prob', 0.8),
                    multi_select_prob=feature_params.get('multi_select_prob', 0.4),
                    weights_mode='random',
                    seed=selectivity_seed
                )

                # Create features dictionary
                features_dict = {}
                for i in range(n_discrete_features):
                    features_dict[f'd_feat_{i}'] = behavioral_features_data[f'd_feat_{i}']
                for i in range(n_continuous_features):
                    features_dict[f'c_feat_{i}'] = behavioral_features_data[f'c_feat_{i}']

                # Generate mixed selective signals
                calcium_features, gt_features = generate_synthetic_data_mixed_selectivity(
                    features_dict, n_feature_selective, selectivity_matrix,
                    duration=duration,
                    seed=feature_seed,
                    sampling_rate=fps,
                    rate_0=feature_params['rate_0'],
                    rate_1=feature_params['rate_1'],
                    skip_prob=feature_params['skip_prob'],
                    ampl_range=feature_params['ampl_range'],
                    decay_time=feature_params['decay_time'],
                    noise_std=feature_params['noise_std'],
                    verbose=False
                )
            else:
                # Original code for single selectivity
                # Generate neurons for discrete features
                all_calcium_parts = []
                all_gt_parts = []

                if n_discrete_features > 0:
                    # Generate neurons selective to discrete features
                    discrete_seed = None if feature_seed is None else feature_seed + 10
                    # Pass pregenerated discrete features
                    discrete_feat_list = [behavioral_features_data[f'd_feat_{i}']
                                          for i in range(n_discrete_features)]
                    feats_d, calcium_d, gt_d = generate_synthetic_data(
                        n_discrete_features, n_feature_selective // 2 if n_continuous_features > 0 else n_feature_selective,
                        ftype='d',
                        duration=duration,
                        seed=discrete_seed,
                        sampling_rate=fps,
                        rate_0=feature_params['rate_0'],
                        rate_1=feature_params['rate_1'],
                        skip_prob=feature_params['skip_prob'],
                        ampl_range=feature_params['ampl_range'],
                        decay_time=feature_params['decay_time'],
                        noise_std=feature_params['noise_std'],
                        verbose=verbose,
                        pregenerated_features=discrete_feat_list,
                        apply_random_neuron_shifts=(correlation_mode == 'independent')
                    )
                    all_calcium_parts.append(calcium_d)
                    # Adjust gt_d indices to account for all features
                    gt_d_adjusted = np.zeros((n_discrete_features + n_continuous_features, gt_d.shape[1]))
                    gt_d_adjusted[:n_discrete_features, :] = gt_d
                    all_gt_parts.append(gt_d_adjusted)

                if n_continuous_features > 0:
                    # Generate neurons selective to continuous features;
                    # they take whatever the discrete group did not use
                    remaining_neurons = n_feature_selective - (len(all_calcium_parts[0]) if all_calcium_parts else 0)
                    continuous_seed = None if feature_seed is None else feature_seed + 100
                    # Pass pregenerated continuous features
                    continuous_feat_list = [behavioral_features_data[f'c_feat_{i}']
                                            for i in range(n_continuous_features)]
                    feats_c, calcium_c, gt_c = generate_synthetic_data(
                        n_continuous_features, remaining_neurons,
                        ftype='c',
                        duration=duration,
                        seed=continuous_seed,
                        sampling_rate=fps,
                        rate_0=feature_params['rate_0'],
                        rate_1=feature_params['rate_1'],
                        skip_prob=feature_params['skip_prob'],
                        hurst=feature_params['hurst'],
                        ampl_range=feature_params['ampl_range'],
                        decay_time=feature_params['decay_time'],
                        noise_std=feature_params['noise_std'],
                        verbose=verbose,
                        pregenerated_features=continuous_feat_list,
                        apply_random_neuron_shifts=(correlation_mode == 'independent')
                    )
                    all_calcium_parts.append(calcium_c)
                    # Adjust gt_c indices to account for discrete features
                    gt_c_adjusted = np.zeros((n_discrete_features + n_continuous_features, gt_c.shape[1]))
                    gt_c_adjusted[n_discrete_features:, :] = gt_c
                    all_gt_parts.append(gt_c_adjusted)

                # Combine calcium signals and ground truth
                if len(all_calcium_parts) == 1:
                    calcium_features = all_calcium_parts[0]
                    gt_features = all_gt_parts[0]
                else:
                    calcium_features = np.vstack(all_calcium_parts)
                    # Combine ground truth matrices
                    gt_features = np.zeros((n_discrete_features + n_continuous_features, calcium_features.shape[0]))
                    neuron_idx = 0
                    for gt_part in all_gt_parts:
                        n_neurons_part = gt_part.shape[1] if len(gt_part.shape) > 1 else 0
                        if n_neurons_part > 0:
                            gt_features[:, neuron_idx:neuron_idx + n_neurons_part] = gt_part
                            neuron_idx += n_neurons_part

        # Apply feature correlation if requested
        if correlation_mode == 'feature_correlated' and spatial_data is not None and n_manifold > 0:
            if verbose:
                print(f' Applying feature correlation to manifold cells (strength={correlation_strength})')

            # Modulate manifold cells based on behavioral features
            if len(all_feats) > 0:
                # Use first continuous feature as modulation signal
                modulation_signal = None
                for feat_name, feat_data in behavioral_features_data.items():
                    if 'c_feat' in feat_name:
                        modulation_signal = feat_data
                        break

                if modulation_signal is not None:
                    # Normalize modulation signal
                    mod_norm = (modulation_signal - np.mean(modulation_signal)) / np.std(modulation_signal)

                    # Apply to manifold calcium signals. The rows are written
                    # in place, so the array already stored in
                    # all_calcium_signals is updated as well.
                    for i in range(n_manifold):
                        baseline = np.mean(calcium_manifold[i])
                        modulated = calcium_manifold[i] + correlation_strength * mod_norm * baseline * 0.2
                        calcium_manifold[i] = np.maximum(0, modulated)  # Ensure non-negative

        all_calcium_signals.append(calcium_features)
        feature_selectivity = gt_features

    # Combine all calcium signals
    if len(all_calcium_signals) == 0:
        # np.vstack would fail with an opaque message on an empty list
        raise ValueError(
            f"No neurons to generate: n_neurons={n_neurons}, manifold_fraction={manifold_fraction} "
            f"produce an empty population"
        )
    if len(all_calcium_signals) == 1:
        combined_calcium = all_calcium_signals[0]
    else:
        combined_calcium = np.vstack(all_calcium_signals)

    # Create static features
    static_features = {
        'fps': fps,
        't_rise_sec': 0.5,
        't_off_sec': manifold_params.get('decay_time', 2.0)
    }

    # Create experiment
    exp = Experiment(
        'MixedPopulation',
        combined_calcium,
        None,  # No spike data
        {},  # No identificators
        static_features,
        dynamic_features,
        reconstruct_spikes=None
    )

    # Prepare comprehensive info dictionary
    info = {
        'population_composition': {
            'n_manifold': n_manifold,
            'n_feature_selective': n_feature_selective,
            'manifold_type': manifold_type,
            'manifold_indices': list(range(n_manifold)),
            # Feature rows follow the manifold rows; use the number of rows
            # actually generated, which is smaller than n_neurons - n_manifold
            # when selectivity_prob < 1.
            'feature_indices': list(range(n_manifold, n_manifold + n_feature_selective)),
            'manifold_fraction': manifold_fraction
        },
        'manifold_info': manifold_info,
        'feature_selectivity': feature_selectivity,
        'spatial_data': spatial_data,
        'behavioral_features': behavioral_features_data,
        'correlation_applied': correlation_mode,
        'correlation_strength': correlation_strength if correlation_mode != 'independent' else 0.0,
        'parameters': {
            'manifold_params': manifold_params,
            'feature_params': feature_params,
            'n_discrete_features': n_discrete_features,
            'n_continuous_features': n_continuous_features
        }
    }

    if verbose:
        print(f' Mixed population generated successfully!')
        print(f' Total calcium traces: {combined_calcium.shape}')
        print(f' Total features: {len(dynamic_features)}')

    if return_info:
        return exp, info
    else:
        return exp