Coverage for src/driada/experiment/synthetic/mixed_selectivity.py: 92.06%
126 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1"""
2Mixed selectivity generation for synthetic neural data.
4This module contains functions for generating synthetic neural data with mixed
5selectivity, where neurons can respond to multiple features simultaneously.
6"""
8import numpy as np
9import tqdm
10from .core import generate_pseudo_calcium_signal
11from .time_series import (
12 generate_binary_time_series, generate_fbm_time_series,
13 discretize_via_roi, delete_one_islands, apply_poisson_to_binary_series
14)
15from ..exp_base import Experiment
16from ...information.info_base import TimeSeries, aggregate_multiple_ts
def generate_multiselectivity_patterns(n_neurons, n_features, mode='random',
                                       selectivity_prob=0.3, multi_select_prob=0.4,
                                       weights_mode='random', seed=None):
    """
    Build a feature-by-neuron weight matrix describing (mixed) selectivity.

    Each neuron is independently either left non-selective (an all-zero
    column) or assigned one, two or three features whose weights sum to 1.

    Parameters
    ----------
    n_neurons : int
        Number of neurons (columns of the result).
    n_features : int
        Number of features (rows of the result).
    mode : str, optional
        Pattern generation mode: 'random', 'structured'. Default: 'random'.
        NOTE(review): this argument is accepted but never read by the body.
    selectivity_prob : float, optional
        Probability of a neuron being selective to any feature. Default: 0.3.
    multi_select_prob : float, optional
        Probability of a selective neuron having mixed selectivity. Default: 0.4.
    weights_mode : str, optional
        Weight generation mode: 'equal' (uniform weights), 'dominant'
        (Dirichlet with concentration 5 on the first chosen feature and 1 on
        the rest); any other value, including 'random', draws from a flat
        Dirichlet. Default: 'random'.
    seed : int, optional
        If given, re-seeds NumPy's global random state for reproducibility.

    Returns
    -------
    selectivity_matrix : ndarray
        Matrix of shape (n_features, n_neurons) with selectivity weights.
        Non-zero values indicate selectivity strength.
    """
    if seed is not None:
        np.random.seed(seed)

    pattern = np.zeros((n_features, n_neurons))

    for neuron in range(n_neurons):
        # Non-selective neurons keep their all-zero column.
        if np.random.rand() > selectivity_prob:
            continue

        # Mixed selectivity picks 2 features (70%) or 3 features (30%);
        # otherwise the neuron responds to a single feature.
        is_mixed = np.random.rand() < multi_select_prob
        n_picked = np.random.choice([2, 3], p=[0.7, 0.3]) if is_mixed else 1

        # Never ask for more features than exist.
        n_picked = min(n_picked, n_features)
        if n_picked == 0:
            continue
        chosen = np.random.choice(n_features, n_picked, replace=False)

        if weights_mode == 'equal':
            neuron_weights = np.ones(n_picked) / n_picked
        else:
            # 'dominant' biases the first chosen feature; everything else
            # falls back to a flat Dirichlet.
            if weights_mode == 'dominant':
                concentration = [5] + [1] * (n_picked - 1)
            else:
                concentration = np.ones(n_picked)
            neuron_weights = np.random.dirichlet(concentration)

        pattern[chosen, neuron] = neuron_weights

    return pattern
def generate_mixed_selective_signal(features, weights, duration, sampling_rate,
                                    rate_0=0.1, rate_1=1.0, skip_prob=0.1,
                                    ampl_range=(0.5, 2), decay_time=2, noise_std=0.1,
                                    seed=None):
    """
    Generate neural signal selective to multiple features.

    Each feature is reduced to a 0/1 activation, the activations are summed
    with the given weights, and the sum is thresholded at a random level drawn
    from U(0.3, 0.7). The resulting binary activation drives a Poisson event
    train which is converted into a pseudo-calcium trace.

    Parameters
    ----------
    features : list of array-like
        List of feature time series. Each must have
        ``int(duration * sampling_rate)`` samples.
    weights : array-like
        Weights for each feature contribution. Features with weight 0 are
        skipped.
    duration : float
        Signal duration in seconds.
    sampling_rate : float
        Sampling rate in Hz.
    rate_0, rate_1 : float, optional
        Event rates in Hz; divided by ``sampling_rate`` before being passed to
        ``apply_poisson_to_binary_series``.
    skip_prob : float, optional
        Passed to ``delete_one_islands`` to add stochasticity to the activation.
    ampl_range, decay_time, noise_std : optional
        Forwarded to ``generate_pseudo_calcium_signal``.
    seed : int, optional
        If given, re-seeds NumPy's global random state; it is also passed to
        ``discretize_via_roi`` and incremented after every used feature.

    Returns
    -------
    signal : array
        Generated calcium signal.
    """
    if seed is not None:
        np.random.seed(seed)

    length = int(duration * sampling_rate)
    combined_activation = np.zeros(length)

    # Combine feature activations
    for feat, weight in zip(features, weights):
        if weight == 0:
            continue

        # Accept plain sequences as well as ndarrays.
        feat = np.asarray(feat)

        # A feature counts as binary when all its values lie in {0, 1}; this
        # includes constant all-0 / all-1 series, which must not be sent
        # through ROI discretization.
        unique_vals = np.unique(feat)
        if unique_vals.size in (1, 2) and set(unique_vals).issubset({0, 1}):
            binary_activation = feat.astype(float)
        else:
            # Use ROI-based discretization for continuous features
            binary_activation = discretize_via_roi(feat, seed=seed)
            binary_activation = binary_activation.astype(float)

        # Weight the activation
        combined_activation += weight * binary_activation

        # Use a different seed for the next feature's discretization.
        if seed is not None:
            seed += 1

    # Threshold to get final binary activation
    threshold = np.random.uniform(0.3, 0.7)  # Flexible threshold
    final_activation = (combined_activation >= threshold).astype(int)

    # Add stochasticity
    mod_activation = delete_one_islands(final_activation, skip_prob)

    # Generate Poisson events (rates converted from Hz to per-sample)
    poisson_series = apply_poisson_to_binary_series(mod_activation,
                                                    rate_0 / sampling_rate,
                                                    rate_1 / sampling_rate)

    # Generate calcium signal
    calcium_signal = generate_pseudo_calcium_signal(duration=duration,
                                                    events=poisson_series,
                                                    sampling_rate=sampling_rate,
                                                    amplitude_range=ampl_range,
                                                    decay_time=decay_time,
                                                    noise_std=noise_std)

    return calcium_signal
def generate_synthetic_data_mixed_selectivity(features_dict, n_neurons, selectivity_matrix,
                                              duration=600, seed=42, sampling_rate=20.0,
                                              rate_0=0.1, rate_1=1.0, skip_prob=0.0,
                                              ampl_range=(0.5, 2), decay_time=2, noise_std=0.1,
                                              verbose=True):
    """
    Generate synthetic data with mixed selectivity support.

    Parameters
    ----------
    features_dict : dict
        Dictionary of feature_name: feature_array pairs. Row ``i`` of
        ``selectivity_matrix`` refers to the i-th entry in insertion order.
    n_neurons : int
        Number of neurons to generate.
    selectivity_matrix : ndarray
        Matrix of shape (n_features, n_neurons) with selectivity weights.
        Only strictly positive weights mark a feature as selected.
    duration : float, optional
        Signal duration in seconds. Default: 600.
    seed : int or None, optional
        Base seed; neuron ``j`` is generated with ``seed + j``. Default: 42.
    sampling_rate : float, optional
        Sampling rate in Hz. Default: 20.0.
    rate_0, rate_1, skip_prob, ampl_range, decay_time, noise_std : optional
        Forwarded to ``generate_mixed_selective_signal``; ``noise_std`` is
        also the standard deviation of the noise used for non-selective
        neurons.
    verbose : bool, optional
        If True, print a status message and show a progress bar.

    Returns
    -------
    all_signals : ndarray
        Neural signals of shape (n_neurons, n_timepoints).
    ground_truth : ndarray
        Ground truth selectivity matrix (same as input selectivity_matrix).
    """
    feature_arrays = list(features_dict.values())
    n_timepoints = int(duration * sampling_rate)

    if verbose:
        print('Generating mixed-selective neural signals...')

    # np.vstack cannot stack an empty list, so return an empty signal array.
    if n_neurons <= 0:
        return np.empty((0, n_timepoints)), selectivity_matrix

    all_signals = []

    # The progress bar is silenced together with the other output.
    for j in tqdm.tqdm(range(n_neurons), disable=not verbose):
        # Get selectivity pattern for this neuron
        weights = selectivity_matrix[:, j]
        selective_features = np.where(weights > 0)[0]

        if len(selective_features) == 0:
            # Non-selective neuron - just noise
            signal = np.random.normal(0, noise_std, n_timepoints)
        else:
            # Get features and weights
            selected_feat_arrays = [feature_arrays[i] for i in selective_features]
            selected_weights = weights[selective_features]

            # Generate mixed selective signal
            signal = generate_mixed_selective_signal(
                selected_feat_arrays, selected_weights,
                duration, sampling_rate,
                rate_0, rate_1, skip_prob,
                ampl_range, decay_time, noise_std,
                seed=seed + j if seed is not None else None
            )

        all_signals.append(signal)

    return np.vstack(all_signals), selectivity_matrix
def generate_synthetic_exp_with_mixed_selectivity(n_discrete_feats=4, n_continuous_feats=4,
                                                  n_neurons=50, n_multifeatures=2,
                                                  create_discrete_pairs=True,
                                                  selectivity_prob=0.8, multi_select_prob=0.5,
                                                  weights_mode='random', duration=1200,
                                                  seed=42, fps=20, verbose=True,
                                                  name_convention='str',
                                                  rate_0=0.1, rate_1=1.0, skip_prob=0.1,
                                                  ampl_range=(0.5, 2), decay_time=2, noise_std=0.1):
    """
    Generate synthetic experiment with mixed selectivity and multifeatures.

    Parameters
    ----------
    n_discrete_feats : int
        Number of discrete features to generate.
    n_continuous_feats : int
        Number of continuous features to generate.
    n_neurons : int
        Number of neurons to generate.
    n_multifeatures : int
        Number of multifeature combinations to create.
    create_discrete_pairs : bool
        If True, create discretized versions of continuous features.
    selectivity_prob : float
        Probability of a neuron being selective.
    multi_select_prob : float
        Probability of mixed selectivity for selective neurons.
    weights_mode : str
        Weight generation mode: 'random', 'dominant', 'equal'.
    duration : float
        Experiment duration in seconds.
    seed : int or None
        Random seed. The generation stages use fixed offsets from it. With
        None, no seeding is performed and every stage receives ``seed=None``.
    fps : float
        Sampling rate.
    verbose : bool
        Print progress messages.
    name_convention : str, optional
        Naming convention for multifeatures. Options:
        - 'str' (default): string keys of the form 'multi0', 'multi1', ...
        - any other value: the tuple of component names is used as the key
          [DEPRECATED]
    rate_0 : float, optional
        Baseline spike rate in Hz. Default: 0.1.
    rate_1 : float, optional
        Active spike rate in Hz. Default: 1.0.
    skip_prob : float, optional
        Probability of skipping spikes. Default: 0.1.
    ampl_range : tuple, optional
        Range of spike amplitudes. Default: (0.5, 2).
    decay_time : float, optional
        Calcium decay time constant in seconds. Default: 2.
    noise_std : float, optional
        Standard deviation of additive noise. Default: 0.1.

    Returns
    -------
    exp : Experiment
        Synthetic experiment with mixed selectivity.
    selectivity_info : dict
        Dictionary containing:
        - 'matrix': selectivity matrix
        - 'feature_names': ordered list of feature names
        - 'multifeature_map': multifeature definitions
    """
    def _derived_seed(offset):
        # Keep "no seed" as None; adding an offset to None raises TypeError.
        return None if seed is None else seed + offset

    if seed is not None:
        np.random.seed(seed)

    length = int(duration * fps)
    features_dict = {}

    # Generate discrete features (drawn from the global random state)
    if verbose:
        print(f'Generating {n_discrete_feats} discrete features...')
    for i in range(n_discrete_feats):
        binary_series = generate_binary_time_series(length, avg_islands=10,
                                                    avg_duration=int(5 * fps))
        features_dict[f'd_feat_{i}'] = binary_series

    # Generate continuous features
    if verbose:
        print(f'Generating {n_continuous_feats} continuous features...')
    for i in range(n_continuous_feats):
        # NOTE(review): assumes generate_fbm_time_series accepts seed=None - confirm.
        fbm_series = generate_fbm_time_series(length, hurst=0.3,
                                              seed=_derived_seed(i + 100))
        features_dict[f'c_feat_{i}'] = fbm_series

        # Create discretized pairs if requested
        if create_discrete_pairs:
            disc_series = discretize_via_roi(fbm_series, seed=_derived_seed(i + 200))
            features_dict[f'd_feat_from_c{i}'] = disc_series

    # Create multifeatures from existing continuous features
    multifeatures_to_create = []
    if n_multifeatures > 0 and n_continuous_feats >= 2:
        if verbose:
            print(f'Creating {n_multifeatures} multifeatures...')

        # Get all continuous features
        continuous_feats = [f for f in features_dict if 'c_feat' in f]

        # Pair consecutive continuous features: (0, 1), (2, 3), ...; the
        # number of pairs is capped by the features actually available.
        n_pairs = min(n_multifeatures, len(continuous_feats) // 2)
        for multi_idx in range(n_pairs):
            feat1 = continuous_feats[2 * multi_idx]
            feat2 = continuous_feats[2 * multi_idx + 1]

            if name_convention == 'str':
                # String key for the multifeature
                mf_name = f'multi{multi_idx}'
                multifeatures_to_create.append((mf_name, (feat1, feat2)))
            else:  # 'tuple' convention (deprecated)
                # Tuple key for the multifeature
                # TODO: this branch is flagged upstream as needing a fix
                multifeatures_to_create.append(((feat1, feat2), (feat1, feat2)))

    # Generate selectivity patterns
    all_feature_names = list(features_dict.keys())
    n_total_features = len(all_feature_names)

    if verbose:
        print(f'Generating selectivity patterns for {n_neurons} neurons...')
    selectivity_matrix = generate_multiselectivity_patterns(
        n_neurons, n_total_features,
        selectivity_prob=selectivity_prob,
        multi_select_prob=multi_select_prob,
        weights_mode=weights_mode,
        seed=_derived_seed(300)
    )

    # Generate neural signals
    calcium_signals, _ = generate_synthetic_data_mixed_selectivity(
        features_dict, n_neurons, selectivity_matrix,
        duration=duration, seed=_derived_seed(400), sampling_rate=fps,
        rate_0=rate_0, rate_1=rate_1, skip_prob=skip_prob,
        ampl_range=ampl_range, decay_time=decay_time, noise_std=noise_std,
        verbose=verbose
    )

    # Create TimeSeries objects
    dynamic_features = {}
    for feat_name, feat_data in features_dict.items():
        # A feature with at most 10 distinct values is treated as discrete
        # (this already covers 0/1 binary series).
        is_discrete = len(np.unique(feat_data)) <= 10
        dynamic_features[feat_name] = TimeSeries(feat_data, discrete=is_discrete)

    # Add multifeatures using aggregate_multiple_ts
    for mf_key, mf_components in multifeatures_to_create:
        # Collect the component TimeSeries that are continuous
        component_ts = [
            dynamic_features[component_name]
            for component_name in mf_components
            if component_name in dynamic_features
            and not dynamic_features[component_name].discrete
        ]

        # Only aggregate when every component is present and continuous
        if len(component_ts) == len(mf_components):
            dynamic_features[mf_key] = aggregate_multiple_ts(*component_ts)

    # Create experiment
    exp = Experiment('SyntheticMixedSelectivity',
                     calcium_signals,
                     None,
                     {},
                     {'fps': fps},
                     dynamic_features,
                     reconstruct_spikes=None)

    # Prepare selectivity info
    # Create multifeature map for return value
    multifeature_map = {}
    for i, (mf_key, mf_components) in enumerate(multifeatures_to_create):
        if isinstance(mf_key, str):
            # For string convention: components tuple -> multifeature name
            multifeature_map[mf_components] = mf_key
        else:
            # For tuple convention: components tuple -> generated name
            multifeature_map[mf_key] = f'multifeature_{i}'

    selectivity_info = {
        'matrix': selectivity_matrix,
        'feature_names': all_feature_names,
        'multifeature_map': multifeature_map
    }

    return exp, selectivity_info