Coverage for src/driada/experiment/synthetic/mixed_selectivity.py: 92.06%

126 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2Mixed selectivity generation for synthetic neural data. 

3 

4This module contains functions for generating synthetic neural data with mixed 

5selectivity, where neurons can respond to multiple features simultaneously. 

6""" 

7 

8import numpy as np 

9import tqdm 

10from .core import generate_pseudo_calcium_signal 

11from .time_series import ( 

12 generate_binary_time_series, generate_fbm_time_series, 

13 discretize_via_roi, delete_one_islands, apply_poisson_to_binary_series 

14) 

15from ..exp_base import Experiment 

16from ...information.info_base import TimeSeries, aggregate_multiple_ts 

17 

18 

def generate_multiselectivity_patterns(n_neurons, n_features, mode='random',
                                       selectivity_prob=0.3, multi_select_prob=0.4,
                                       weights_mode='random', seed=None):
    """
    Generate selectivity patterns for neurons with mixed selectivity support.

    Each neuron is independently either non-selective (an all-zero column) or
    selective to 1-3 distinct features whose weights sum to 1.

    Parameters
    ----------
    n_neurons : int
        Number of neurons.
    n_features : int
        Number of features.
    mode : str, optional
        Pattern generation mode. Default: 'random'.
        NOTE: this argument is accepted for API compatibility but is not
        read by the implementation; patterns are always drawn randomly.
    selectivity_prob : float, optional
        Probability of a neuron being selective to any feature. Default: 0.3.
        A value of 0 yields no selective neurons, 1 makes every neuron selective.
    multi_select_prob : float, optional
        Probability of selective neuron having mixed selectivity. Default: 0.4.
    weights_mode : str, optional
        Weight generation mode: 'random', 'dominant', 'equal'. Default: 'random'.
        Any value other than 'equal' or 'dominant' is treated as 'random'.
    seed : int, optional
        Random seed for reproducibility. When given, NumPy's *global* random
        state is reseeded via ``np.random.seed``.

    Returns
    -------
    selectivity_matrix : ndarray
        Matrix of shape (n_features, n_neurons) with selectivity weights.
        Non-zero values indicate selectivity strength.
    """
    if seed is not None:
        np.random.seed(seed)

    selectivity_matrix = np.zeros((n_features, n_neurons))

    for j in range(n_neurons):
        # Decide if neuron is selective. rand() is drawn from [0, 1), so `>=`
        # guarantees that selectivity_prob == 0 never produces a selective
        # neuron while selectivity_prob == 1 always does.
        if np.random.rand() >= selectivity_prob:
            continue

        # Decide if neuron has mixed selectivity
        if np.random.rand() < multi_select_prob:
            # Mixed selectivity: 2 features (70%) or 3 features (30%)
            n_select = np.random.choice([2, 3], p=[0.7, 0.3])
        else:
            # Single selectivity
            n_select = 1

        # Never try to select more features than are available
        n_select = min(n_select, n_features)
        if n_select == 0:
            continue
        selected_features = np.random.choice(n_features, n_select, replace=False)

        # Assign weights (all three modes produce weights summing to 1)
        if weights_mode == 'equal':
            weights = np.ones(n_select) / n_select
        elif weights_mode == 'dominant':
            # One feature dominates: its Dirichlet concentration is 5, others 1
            weights = np.random.dirichlet([5] + [1] * (n_select - 1))
        else:  # 'random' (also the fallback for unrecognised modes)
            weights = np.random.dirichlet(np.ones(n_select))

        # Set weights in matrix
        selectivity_matrix[selected_features, j] = weights

    return selectivity_matrix

85 

86 

def generate_mixed_selective_signal(features, weights, duration, sampling_rate,
                                    rate_0=0.1, rate_1=1.0, skip_prob=0.1,
                                    ampl_range=(0.5, 2), decay_time=2, noise_std=0.1,
                                    seed=None):
    """
    Generate neural signal selective to multiple features.

    Every feature with a non-zero weight is turned into a 0/1 activation,
    the activations are summed with their weights, and the sum is thresholded
    into a single binary activation that drives event and calcium generation.

    Parameters
    ----------
    features : list of arrays
        List of feature time series. Each is expected to contain
        ``int(duration * sampling_rate)`` samples so that it can be added to
        the combined activation.
    weights : array-like
        Weights for each feature contribution. Features whose weight is
        exactly 0 are skipped.
    duration : float
        Signal duration in seconds.
    sampling_rate : float
        Sampling rate in Hz.
    rate_0, rate_1 : float, optional
        Event rates; they are divided by `sampling_rate` before being passed
        to ``apply_poisson_to_binary_series``.
    skip_prob : float, optional
        Passed to ``delete_one_islands``.
    ampl_range, decay_time, noise_std : optional
        Forwarded to ``generate_pseudo_calcium_signal``.
    seed : int, optional
        When given, NumPy's global random state is reseeded with it, and
        ``seed``, ``seed + 1``, ... are handed to ``discretize_via_roi`` for
        successive weighted features.

    Returns
    -------
    signal : array
        Generated calcium signal.
    """
    if seed is not None:
        np.random.seed(seed)

    length = int(duration * sampling_rate)
    combined_activation = np.zeros(length)

    # Seed handed to the discretizer; advanced once per weighted feature so
    # different features do not share a discretization seed.
    roi_seed = seed

    # Combine feature activations
    for feat, weight in zip(features, weights):
        if weight == 0:
            continue

        # Accept plain sequences as well as ndarrays
        feat = np.asarray(feat)

        # A feature is already binary when all of its values are 0 or 1.
        # This includes constant all-0 / all-1 series, which must not be
        # re-discretized as though they were continuous.
        unique_vals = np.unique(feat)
        is_binary = (0 < len(unique_vals) <= 2
                     and set(unique_vals).issubset({0, 1}))

        if is_binary:
            binary_activation = feat.astype(float)
        else:
            # Use ROI-based discretization for continuous features
            binary_activation = discretize_via_roi(feat, seed=roi_seed)
            binary_activation = binary_activation.astype(float)

        # Weight the activation
        combined_activation += weight * binary_activation
        if roi_seed is not None:
            roi_seed += 1

    # Threshold to get final binary activation; the threshold itself is drawn
    # uniformly from [0.3, 0.7) using the global random state.
    threshold = np.random.uniform(0.3, 0.7)
    final_activation = (combined_activation >= threshold).astype(int)

    # Add stochasticity
    mod_activation = delete_one_islands(final_activation, skip_prob)

    # Generate Poisson events (rates converted from per-second to per-sample)
    poisson_series = apply_poisson_to_binary_series(mod_activation,
                                                    rate_0 / sampling_rate,
                                                    rate_1 / sampling_rate)

    # Generate calcium signal
    calcium_signal = generate_pseudo_calcium_signal(duration=duration,
                                                    events=poisson_series,
                                                    sampling_rate=sampling_rate,
                                                    amplitude_range=ampl_range,
                                                    decay_time=decay_time,
                                                    noise_std=noise_std)

    return calcium_signal

158 

159 

def generate_synthetic_data_mixed_selectivity(features_dict, n_neurons, selectivity_matrix,
                                              duration=600, seed=42, sampling_rate=20.0,
                                              rate_0=0.1, rate_1=1.0, skip_prob=0.0,
                                              ampl_range=(0.5, 2), decay_time=2, noise_std=0.1,
                                              verbose=True):
    """
    Generate synthetic data with mixed selectivity support.

    Parameters
    ----------
    features_dict : dict
        Dictionary of feature_name: feature_array pairs. Row ``i`` of
        `selectivity_matrix` refers to the i-th key in iteration order.
    n_neurons : int
        Number of neurons to generate.
    selectivity_matrix : ndarray
        Matrix of shape (n_features, n_neurons) with selectivity weights.
        Only strictly positive entries are treated as selectivity.
    duration : float, optional
        Signal duration in seconds. Default: 600.
    seed : int or None, optional
        Base seed; neuron ``j`` is generated with ``seed + j``. With None no
        seed is passed on. Default: 42.
    sampling_rate : float, optional
        Sampling rate in Hz. Default: 20.0.
    rate_0, rate_1, skip_prob, ampl_range, decay_time, noise_std : optional
        Forwarded to ``generate_mixed_selective_signal``; `noise_std` is also
        the standard deviation of the noise used for non-selective neurons.
    verbose : bool, optional
        If True, print a status message and show a progress bar. Default: True.

    Returns
    -------
    all_signals : ndarray
        Neural signals of shape (n_neurons, n_timepoints).
    ground_truth : ndarray
        Ground truth selectivity matrix (same as input selectivity_matrix).
    """
    feature_names = list(features_dict.keys())
    feature_arrays = [features_dict[name] for name in feature_names]
    n_timepoints = int(duration * sampling_rate)

    if verbose:
        print('Generating mixed-selective neural signals...')

    # np.vstack cannot stack an empty list, so handle "no neurons" explicitly
    if n_neurons <= 0:
        return np.empty((0, n_timepoints)), selectivity_matrix

    all_signals = []

    # The progress bar is part of the verbose output and is silenced with it
    for j in tqdm.tqdm(range(n_neurons), disable=not verbose):
        # Get selectivity pattern for this neuron
        weights = selectivity_matrix[:, j]
        selective_features = np.where(weights > 0)[0]

        if len(selective_features) == 0:
            # Non-selective neuron - just noise.
            # NOTE: drawn from NumPy's global random state; `seed` is not
            # applied here.
            signal = np.random.normal(0, noise_std, n_timepoints)
        else:
            # Get features and weights
            selected_feat_arrays = [feature_arrays[i] for i in selective_features]
            selected_weights = weights[selective_features]

            # Generate mixed selective signal
            signal = generate_mixed_selective_signal(
                selected_feat_arrays, selected_weights,
                duration, sampling_rate,
                rate_0, rate_1, skip_prob,
                ampl_range, decay_time, noise_std,
                seed=seed + j if seed is not None else None
            )

        all_signals.append(signal)

    return np.vstack(all_signals), selectivity_matrix

218 

219 

def generate_synthetic_exp_with_mixed_selectivity(n_discrete_feats=4, n_continuous_feats=4,
                                                  n_neurons=50, n_multifeatures=2,
                                                  create_discrete_pairs=True,
                                                  selectivity_prob=0.8, multi_select_prob=0.5,
                                                  weights_mode='random', duration=1200,
                                                  seed=42, fps=20, verbose=True,
                                                  name_convention='str',
                                                  rate_0=0.1, rate_1=1.0, skip_prob=0.1,
                                                  ampl_range=(0.5, 2), decay_time=2, noise_std=0.1):
    """
    Generate synthetic experiment with mixed selectivity and multifeatures.

    Parameters
    ----------
    n_discrete_feats : int
        Number of discrete features to generate.
    n_continuous_feats : int
        Number of continuous features to generate.
    n_neurons : int
        Number of neurons to generate.
    n_multifeatures : int
        Number of multifeature combinations to create.
    create_discrete_pairs : bool
        If True, create discretized versions of continuous features.
    selectivity_prob : float
        Probability of a neuron being selective.
    multi_select_prob : float
        Probability of mixed selectivity for selective neurons.
    weights_mode : str
        Weight generation mode: 'random', 'dominant', 'equal'.
    duration : float
        Experiment duration in seconds.
    seed : int or None
        Random seed. Sub-steps use fixed offsets from it (+100, +200, +300,
        +400). With None, no seed is passed to any sub-step.
    fps : float
        Sampling rate.
    verbose : bool
        Print progress messages.
    name_convention : str, optional
        Naming convention for multifeatures. Options:
        - 'str' (default): Use generated string keys 'multi0', 'multi1', ...
        - anything else: use the tuple of component names as key [DEPRECATED]
    rate_0 : float, optional
        Baseline spike rate in Hz. Default: 0.1.
    rate_1 : float, optional
        Active spike rate in Hz. Default: 1.0.
    skip_prob : float, optional
        Probability of skipping spikes. Default: 0.1.
    ampl_range : tuple, optional
        Range of spike amplitudes. Default: (0.5, 2).
    decay_time : float, optional
        Calcium decay time constant in seconds. Default: 2.
    noise_std : float, optional
        Standard deviation of additive noise. Default: 0.1.

    Returns
    -------
    exp : Experiment
        Synthetic experiment with mixed selectivity.
    selectivity_info : dict
        Dictionary containing:
        - 'matrix': selectivity matrix
        - 'feature_names': ordered list of feature names
        - 'multifeature_map': multifeature definitions
    """
    def sub_seed(offset):
        # Derive a per-step seed; stays None when the run is unseeded so that
        # `seed=None` does not fail on `None + int`.
        return None if seed is None else seed + offset

    if seed is not None:
        np.random.seed(seed)

    length = int(duration * fps)
    features_dict = {}

    # Generate discrete features
    if verbose:
        print(f'Generating {n_discrete_feats} discrete features...')
    for i in range(n_discrete_feats):
        binary_series = generate_binary_time_series(length, avg_islands=10,
                                                    avg_duration=int(5 * fps))
        features_dict[f'd_feat_{i}'] = binary_series

    # Generate continuous features
    if verbose:
        print(f'Generating {n_continuous_feats} continuous features...')
    for i in range(n_continuous_feats):
        # NOTE(review): assumes generate_fbm_time_series accepts seed=None
        # for the unseeded case - confirm against its definition.
        fbm_series = generate_fbm_time_series(length, hurst=0.3, seed=sub_seed(i + 100))
        features_dict[f'c_feat_{i}'] = fbm_series

        # Create discretized pairs if requested
        if create_discrete_pairs:
            disc_series = discretize_via_roi(fbm_series, seed=sub_seed(i + 200))
            features_dict[f'd_feat_from_c{i}'] = disc_series

    # Create multifeatures from existing continuous features
    multifeatures_to_create = []
    if n_multifeatures > 0 and n_continuous_feats >= 2:
        if verbose:
            print(f'Creating {n_multifeatures} multifeatures...')

        # Get all continuous features
        continuous_feats = [f for f in features_dict.keys() if 'c_feat' in f]

        # Pair consecutive continuous features (0,1), (2,3), ... and stop
        # after n_multifeatures pairs or when no complete pair is left.
        pair_starts = range(0, len(continuous_feats) - 1, 2)
        for multi_idx, start in zip(range(n_multifeatures), pair_starts):
            feat1 = continuous_feats[start]
            feat2 = continuous_feats[start + 1]

            if name_convention == 'str':
                # String key for the multifeature
                mf_name = f'multi{multi_idx}'
                multifeatures_to_create.append((mf_name, (feat1, feat2)))
            else:  # 'tuple' convention (deprecated)
                # Tuple key for the multifeature
                # TODO: this need fixing
                multifeatures_to_create.append(((feat1, feat2), (feat1, feat2)))

    # Generate selectivity patterns
    all_feature_names = list(features_dict.keys())
    n_total_features = len(all_feature_names)

    if verbose:
        print(f'Generating selectivity patterns for {n_neurons} neurons...')
    selectivity_matrix = generate_multiselectivity_patterns(
        n_neurons, n_total_features,
        selectivity_prob=selectivity_prob,
        multi_select_prob=multi_select_prob,
        weights_mode=weights_mode,
        seed=sub_seed(300)
    )

    # Generate neural signals
    calcium_signals, _ = generate_synthetic_data_mixed_selectivity(
        features_dict, n_neurons, selectivity_matrix,
        duration=duration, seed=sub_seed(400), sampling_rate=fps,
        rate_0=rate_0, rate_1=rate_1, skip_prob=skip_prob,
        ampl_range=ampl_range, decay_time=decay_time, noise_std=noise_std,
        verbose=verbose
    )

    # Create TimeSeries objects
    dynamic_features = {}
    for feat_name, feat_data in features_dict.items():
        # A feature with at most 10 distinct values is treated as discrete
        # (this already covers binary 0/1 features).
        is_discrete = len(np.unique(feat_data)) <= 10
        dynamic_features[feat_name] = TimeSeries(feat_data, discrete=is_discrete)

    # Add multifeatures using aggregate_multiple_ts
    for mf_key, mf_components in multifeatures_to_create:
        # Get component TimeSeries (continuous ones only)
        component_ts = [
            dynamic_features[component_name]
            for component_name in mf_components
            if component_name in dynamic_features
            and not dynamic_features[component_name].discrete
        ]

        # Create MultiTimeSeries only if all components are continuous
        if len(component_ts) == len(mf_components):
            dynamic_features[mf_key] = aggregate_multiple_ts(*component_ts)

    # Create experiment
    exp = Experiment('SyntheticMixedSelectivity',
                     calcium_signals,
                     None,
                     {},
                     {'fps': fps},
                     dynamic_features,
                     reconstruct_spikes=None)

    # Prepare selectivity info
    # Create multifeature map for return value
    multifeature_map = {}
    for i, (mf_key, mf_components) in enumerate(multifeatures_to_create):
        if isinstance(mf_key, str):
            # For string convention: components tuple -> multifeature name
            multifeature_map[mf_components] = mf_key
        else:
            # For tuple convention: components tuple -> generated name
            multifeature_map[mf_key] = f'multifeature_{i}'

    selectivity_info = {
        'matrix': selectivity_matrix,
        'feature_names': all_feature_names,
        'multifeature_map': multifeature_map
    }

    return exp, selectivity_info