Coverage for src/driada/integration/selectivity_mapper.py: 69.92%

123 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2SelectivityManifoldMapper: Bridge between INTENSE selectivity analysis and dimensionality reduction. 

3 

4This module provides tools for analyzing the relationship between single-neuron selectivity 

5profiles and population-level manifold structure. 

6""" 

7 

8import logging 

9from typing import Dict, List, Optional, Tuple, Union, Any 

10import numpy as np 

11from scipy import stats 

12from sklearn.decomposition import PCA 

13 

14from ..experiment import Experiment 

15from ..dim_reduction.data import MVData 

16from ..intense.pipelines import compute_embedding_selectivity 

17 

18 

class SelectivityManifoldMapper:
    """
    Analyzes relationships between neuronal selectivity and manifold structure.

    This class provides methods to:
    1. Create population embeddings and store them in the Experiment
    2. Compute selectivity of neurons to embedding components
    3. Analyze functional organization in the manifold

    Parameters
    ----------
    experiment : Experiment
        An Experiment object with computed selectivity results
    device : Optional[torch.device], default=None
        Device for computation (for future GPU support). It is only stored
        on the instance; no method of this class reads it.
    logger : Optional[logging.Logger], default=None
        Logger for debugging and info messages. When omitted, a logger named
        after the class is used.
    config : Optional[Dict], default=None
        Configuration dictionary for custom parameters

    Raises
    ------
    ValueError
        If the experiment has no ``calcium`` attribute or it is None.

    Examples
    --------
    >>> # Create mapper and generate embeddings
    >>> mapper = SelectivityManifoldMapper(exp)
    >>>
    >>> # Create and store PCA embedding
    >>> mapper.create_embedding('pca', n_components=10, neuron_selection='significant')
    >>>
    >>> # Analyze neuron selectivity to PCA components
    >>> results = mapper.analyze_embedding_selectivity('pca')
    >>>
    >>> # Get functional organization summary
    >>> summary = mapper.get_functional_organization('pca')
    """

    def __init__(
        self,
        experiment: Experiment,
        device: Optional[Any] = None,
        logger: Optional[logging.Logger] = None,
        config: Optional[Dict] = None
    ):
        self.experiment = experiment
        self.device = device
        self.logger = logger or logging.getLogger(self.__class__.__name__)
        self.config = config or {}

        # Validate experiment has required data
        if getattr(experiment, 'calcium', None) is None:
            raise ValueError("Experiment must have calcium data")

        # Selectivity analysis counts as "performed" when stats_tables exists
        # and is non-empty. Stored as a real bool rather than the table itself.
        self.has_selectivity = bool(getattr(experiment, 'stats_tables', None))

        # self.logger is always set above, so no guard is needed.
        self.logger.info(
            f"Initialized {self.__class__.__name__} with {experiment.n_cells} neurons"
        )

    def create_embedding(
        self,
        method: str,
        n_components: int = 2,
        data_type: str = 'calcium',
        neuron_selection: Optional[Union[str, List[int]]] = None,
        **dr_kwargs
    ) -> np.ndarray:
        """
        Create dimensionality reduction embedding and store it in the experiment.

        Parameters
        ----------
        method : str
            DR method name ('pca', 'umap', 'isomap', etc.)
        n_components : int
            Number of embedding dimensions
        data_type : str
            Type of data to use ('calcium' or 'spikes')
        neuron_selection : str, list or None
            How to select neurons:
            - None or 'all': Use all neurons
            - 'significant': Use only significantly selective neurons
              (falls back to all neurons, with a warning, if none are found)
            - List of integers: Use specific neuron indices
        **dr_kwargs
            Additional arguments for the DR method. The special key ``ds``
            (default 1) is removed from the kwargs and used as a temporal
            downsampling factor. Every remaining key is forwarded to
            ``MVData.get_embedding``; a ``dim`` key overrides ``n_components``.

        Returns
        -------
        embedding : np.ndarray
            The embedding array, shape (n_timepoints, n_components)

        Raises
        ------
        ValueError
            If ``neuron_selection`` is an unknown string, if 'significant' is
            requested without selectivity results, if ``data_type`` is not
            'calcium' or 'spikes', if ``ds`` is smaller than 1, or if the
            embedding contains fewer timepoints than were passed in.
        """
        # Select neurons. Strings are dispatched first so that array-like
        # selections are never compared against a string.
        if neuron_selection is None:
            neuron_indices = np.arange(self.experiment.n_cells)
        elif isinstance(neuron_selection, str):
            if neuron_selection == 'all':
                neuron_indices = np.arange(self.experiment.n_cells)
            elif neuron_selection == 'significant':
                if not self.has_selectivity:
                    raise ValueError("Cannot select significant neurons without selectivity analysis")
                sig_neurons = self.experiment.get_significant_neurons()
                neuron_indices = np.array(list(sig_neurons.keys()))
                if len(neuron_indices) == 0:
                    self.logger.warning("No significant neurons found, using all neurons")
                    neuron_indices = np.arange(self.experiment.n_cells)
            else:
                raise ValueError(
                    f"Unknown neuron_selection '{neuron_selection}': "
                    "expected None, 'all', 'significant' or a list of indices"
                )
        else:
            neuron_indices = np.array(neuron_selection)

        # Get neural data
        if data_type == 'calcium':
            neural_data = self.experiment.calcium.data[neuron_indices, :]
        elif data_type == 'spikes':
            neural_data = self.experiment.spikes.data[neuron_indices, :]
        else:
            raise ValueError("data_type must be 'calcium' or 'spikes'")

        # 'ds' is consumed here so it is not forwarded to the DR method.
        ds = dr_kwargs.pop('ds', 1)
        if ds < 1:
            raise ValueError(f"ds must be a positive integer, got {ds}")
        if ds > 1:
            neural_data = neural_data[:, ::ds]
            self.logger.info(f"Downsampling data by factor {ds}: {neural_data.shape[1]} timepoints")

        # Create MVData and compute embedding
        mvdata = MVData(data=neural_data)  # MVData expects (n_features, n_samples)

        # All remaining kwargs are forwarded as-is; 'dim' carries n_components.
        params = {'dim': n_components, **dr_kwargs}

        embedding_obj = mvdata.get_embedding(method=method, **params)
        embedding = embedding_obj.coords.T  # Transpose to (n_timepoints, n_components)

        # Compare against the number of samples actually handed to the DR
        # method. A strided slice yields ceil(n / ds) columns, so a floor
        # division of the frame count would under-count and could hide a
        # dropped timepoint.
        expected_frames = neural_data.shape[1]
        if embedding.shape[0] < expected_frames:
            n_missing = expected_frames - embedding.shape[0]
            raise ValueError(
                f"{method} embedding dropped {n_missing} timepoints due to graph disconnection. "
                f"This is not supported for INTENSE analysis. Try increasing n_neighbors or using a different method."
            )

        # Store metadata
        metadata = {
            'method': method,
            'n_components': n_components,
            'neuron_selection': neuron_selection,
            'neuron_indices': neuron_indices.tolist(),
            'n_neurons': len(neuron_indices),
            'dr_params': dict(dr_kwargs),
            'data_type': data_type,
            'ds': ds  # Store downsampling factor
        }

        # Store in experiment
        self.experiment.store_embedding(embedding, method, data_type, metadata)

        self.logger.info(
            f"Created {method} embedding with {n_components} components "
            f"using {len(neuron_indices)} neurons"
        )

        return embedding

    def analyze_embedding_selectivity(
        self,
        embedding_methods: Optional[Union[str, List[str]]] = None,
        data_type: str = 'calcium',
        **intense_kwargs
    ) -> Dict:
        """
        Analyze how neurons are selective to embedding components.

        Thin wrapper that forwards the stored experiment and all arguments to
        ``compute_embedding_selectivity``.

        Parameters
        ----------
        embedding_methods : str, list or None
            Embedding methods to analyze. If None, analyzes all stored embeddings
        data_type : str
            Data type ('calcium' or 'spikes')
        **intense_kwargs
            Additional arguments for compute_embedding_selectivity

        Returns
        -------
        results : dict
            Results from compute_embedding_selectivity
        """
        return compute_embedding_selectivity(
            self.experiment,
            embedding_methods=embedding_methods,
            data_type=data_type,
            **intense_kwargs
        )

    def get_functional_organization(
        self,
        method_name: str,
        data_type: str = 'calcium'
    ) -> Dict:
        """
        Analyze functional organization in the manifold.

        Parameters
        ----------
        method_name : str
            Name of the embedding method
        data_type : str
            Data type used for embedding

        Returns
        -------
        organization : dict
            Always contains:
            - 'component_importance': Per-component variance divided by the
              total variance (all zeros when the total variance is zero)
            - 'n_components': Number of embedding columns
            - 'n_neurons_used', 'neuron_indices': Taken from the embedding
              metadata, defaulting to all neurons of the experiment
            Additionally, when ``stats_tables[data_type]`` has an entry for
            ``"<method_name>_comp0"``:
            - 'neuron_participation': neuron index -> list of component
              indices the neuron is significantly selective to
            - 'component_specialization': Per-component selective neurons,
              their count and the fraction of all neurons they represent
            - 'functional_clusters': Groups of two or more neurons selective
              to exactly the same set of components, largest first
            - 'n_participating_neurons', 'mean_components_per_neuron'
        """
        # Get embedding and metadata
        embedding_dict = self.experiment.get_embedding(method_name, data_type)
        embedding = embedding_dict['data']
        metadata = embedding_dict.get('metadata') or {}
        neuron_indices = metadata.get('neuron_indices', list(range(self.experiment.n_cells)))
        n_components = embedding.shape[1]
        n_cells = self.experiment.n_cells

        # Compute component importance (share of total variance); a constant
        # embedding would otherwise produce NaNs from 0 / 0.
        component_var = np.var(embedding, axis=0)
        total_var = np.sum(component_var)
        if total_var > 0:
            component_importance = component_var / total_var
        else:
            component_importance = np.zeros_like(component_var)

        # The presence of the first component's key is used as the marker that
        # embedding selectivity was computed. stats_tables may be missing or
        # None (the constructor allows that), so normalise it to a dict.
        stats_tables = getattr(self.experiment, 'stats_tables', None) or {}
        stats_key = f"{method_name}_comp0"
        has_embedding_selectivity = (
            data_type in stats_tables and
            stats_key in stats_tables[data_type]
        )

        organization = {
            'component_importance': component_importance,
            'n_components': n_components,
            'n_neurons_used': len(neuron_indices),
            'neuron_indices': neuron_indices
        }

        if not has_embedding_selectivity:
            return organization

        # Hoisted out of the loops: significance results for this data type.
        # A missing table is treated as "nothing significant" instead of
        # raising a KeyError that callers could mistake for a missing embedding.
        sig_tables = getattr(self.experiment, 'significance_tables', None) or {}
        sig_for_type = sig_tables[data_type] if data_type in sig_tables else {}

        neuron_participation = {}
        component_specialization = {}

        for comp_idx in range(n_components):
            feat_name = f"{method_name}_comp{comp_idx}"
            feat_table = sig_for_type[feat_name] if feat_name in sig_for_type else {}

            # A neuron counts as selective to this component when its entry
            # has a truthy 'stage2' flag.
            selective_neurons = [
                neuron_idx for neuron_idx in range(n_cells)
                if neuron_idx in feat_table and feat_table[neuron_idx].get('stage2', False)
            ]

            # Track neuron participation
            for neuron_idx in selective_neurons:
                neuron_participation.setdefault(neuron_idx, []).append(comp_idx)

            component_specialization[comp_idx] = {
                'n_selective_neurons': len(selective_neurons),
                'selective_neurons': selective_neurons,
                'selectivity_rate': len(selective_neurons) / n_cells if n_cells else 0.0
            }

        # Identify functional clusters (neurons selective to same components)
        cluster_map = {}
        for neuron_idx, components in neuron_participation.items():
            cluster_map.setdefault(tuple(sorted(components)), []).append(neuron_idx)

        # Only keep clusters with multiple neurons
        functional_clusters = [
            {'components': list(components), 'neurons': neurons, 'size': len(neurons)}
            for components, neurons in cluster_map.items()
            if len(neurons) > 1
        ]

        # Sort clusters by size, largest first
        functional_clusters.sort(key=lambda x: x['size'], reverse=True)

        organization.update({
            'neuron_participation': neuron_participation,
            'component_specialization': component_specialization,
            'functional_clusters': functional_clusters,
            'n_participating_neurons': len(neuron_participation),
            'mean_components_per_neuron': (
                np.mean([len(comps) for comps in neuron_participation.values()])
                if neuron_participation else 0
            )
        })

        return organization

    def compare_embeddings(
        self,
        method_names: List[str],
        data_type: str = 'calcium'
    ) -> Dict:
        """
        Compare functional organization across different embedding methods.

        Methods for which ``get_functional_organization`` raises ``KeyError``
        are skipped with a warning; the comparison covers the remaining ones.

        Parameters
        ----------
        method_names : list
            List of embedding method names to compare
        data_type : str
            Data type used for embeddings

        Returns
        -------
        comparison : dict
            Comparison metrics between embeddings. 'participation_overlap'
            (Jaccard index of participating neurons for every pair of found
            methods, keyed ``"<m1>_vs_<m2>"``) is included only when every
            found method has embedding-selectivity results.

        Raises
        ------
        ValueError
            If fewer than two of the requested embeddings are found.
        """
        organizations = {}
        for method in method_names:
            try:
                organizations[method] = self.get_functional_organization(method, data_type)
            except KeyError:
                self.logger.warning(f"No embedding found for method '{method}'")

        if len(organizations) < 2:
            raise ValueError("Need at least 2 embeddings to compare")

        comparison = {
            'methods': list(organizations.keys()),
            'n_components': {m: org['n_components'] for m, org in organizations.items()},
            'n_participating_neurons': {m: org.get('n_participating_neurons', 0) for m, org in organizations.items()},
            'mean_components_per_neuron': {m: org.get('mean_components_per_neuron', 0) for m, org in organizations.items()},
            'n_functional_clusters': {m: len(org.get('functional_clusters', [])) for m, org in organizations.items()}
        }

        # Compare neuron participation overlap
        if all('neuron_participation' in org for org in organizations.values()):
            # Pair up only the methods that were actually found. Pairing the
            # raw method_names would look up skipped methods (KeyError) and
            # pair duplicated names with themselves.
            available = list(organizations)
            participation_overlap = {}

            for i, m1 in enumerate(available):
                neurons1 = set(organizations[m1]['neuron_participation'])
                for m2 in available[i + 1:]:
                    neurons2 = set(organizations[m2]['neuron_participation'])
                    union = neurons1 | neurons2
                    # Jaccard index; defined as 0 when neither method has
                    # any participating neuron.
                    overlap = len(neurons1 & neurons2) / len(union) if union else 0
                    participation_overlap[f"{m1}_vs_{m2}"] = overlap

            comparison['participation_overlap'] = participation_overlap

        return comparison

387 

388 return comparison