Coverage for src/driada/integration/selectivity_mapper.py: 69.92%
123 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1"""
2SelectivityManifoldMapper: Bridge between INTENSE selectivity analysis and dimensionality reduction.
4This module provides tools for analyzing the relationship between single-neuron selectivity
5profiles and population-level manifold structure.
6"""
8import logging
9from typing import Dict, List, Optional, Tuple, Union, Any
10import numpy as np
11from scipy import stats
12from sklearn.decomposition import PCA
14from ..experiment import Experiment
15from ..dim_reduction.data import MVData
16from ..intense.pipelines import compute_embedding_selectivity
19class SelectivityManifoldMapper:
20 """
21 Analyzes relationships between neuronal selectivity and manifold structure.
23 This class provides methods to:
24 1. Create population embeddings and store them in the Experiment
25 2. Compute selectivity of neurons to embedding components
26 3. Analyze functional organization in the manifold
28 Parameters
29 ----------
30 experiment : Experiment
31 An Experiment object with computed selectivity results
32 device : Optional[torch.device], default=None
33 Device for computation (for future GPU support)
34 logger : Optional[logging.Logger], default=None
35 Logger for debugging and info messages
36 config : Optional[Dict], default=None
37 Configuration dictionary for custom parameters
39 Examples
40 --------
41 >>> # Create mapper and generate embeddings
42 >>> mapper = SelectivityManifoldMapper(exp)
43 >>>
44 >>> # Create and store PCA embedding
45 >>> mapper.create_embedding('pca', n_components=10, neuron_selection='significant')
46 >>>
47 >>> # Analyze neuron selectivity to PCA components
48 >>> results = mapper.analyze_embedding_selectivity('pca')
49 >>>
50 >>> # Get functional organization summary
51 >>> summary = mapper.get_functional_organization('pca')
52 """
54 def __init__(
55 self,
56 experiment: Experiment,
57 device: Optional[Any] = None,
58 logger: Optional[logging.Logger] = None,
59 config: Optional[Dict] = None
60 ):
61 self.experiment = experiment
62 self.device = device
63 self.logger = logger or logging.getLogger(self.__class__.__name__)
64 self.config = config or {}
66 # Validate experiment has required data
67 if not hasattr(experiment, 'calcium') or experiment.calcium is None:
68 raise ValueError("Experiment must have calcium data")
70 # Check if selectivity analysis has been performed
71 self.has_selectivity = hasattr(experiment, 'stats_tables') and experiment.stats_tables
73 if self.logger:
74 self.logger.info(
75 f"Initialized {self.__class__.__name__} with {experiment.n_cells} neurons"
76 )
def create_embedding(
    self,
    method: str,
    n_components: int = 2,
    data_type: str = 'calcium',
    neuron_selection: Optional[Union[str, List[int]]] = None,
    **dr_kwargs
) -> np.ndarray:
    """
    Create dimensionality reduction embedding and store it in the experiment.

    Parameters
    ----------
    method : str
        DR method name, forwarded to ``MVData.get_embedding``.
    n_components : int
        Number of embedding dimensions (passed to the DR method as ``dim``).
    data_type : str
        Type of data to use ('calcium' or 'spikes').
    neuron_selection : str, list or None
        How to select neurons:
        - None or 'all': Use all neurons
        - 'significant': Use only significantly selective neurons; falls
          back to all neurons (with a warning) when none are found
        - Sequence or array: Used to index the neuron axis directly
    **dr_kwargs
        Additional arguments for the DR method. The optional key ``ds``
        (positive integer, default 1) is removed from these arguments and
        used as a temporal downsampling factor (every ``ds``-th timepoint
        is kept).

    Returns
    -------
    embedding : np.ndarray
        Transposed ``coords`` of the embedding object returned by MVData;
        the first axis is compared against the number of timepoints.

    Raises
    ------
    ValueError
        If 'significant' is requested without selectivity results, if
        ``neuron_selection`` is an unknown string, if ``data_type`` is not
        'calcium' or 'spikes', if ``ds`` is smaller than 1, or if the
        embedding has fewer rows than the timepoints given to the DR method.
    """
    n_cells = self.experiment.n_cells

    # Select neurons. String selectors are matched only when the argument is
    # actually a string: comparing an array-like against 'all' is elementwise
    # and has no single truth value.
    if neuron_selection is None:
        neuron_indices = np.arange(n_cells)
    elif isinstance(neuron_selection, str):
        if neuron_selection == 'all':
            neuron_indices = np.arange(n_cells)
        elif neuron_selection == 'significant':
            if not self.has_selectivity:
                raise ValueError("Cannot select significant neurons without selectivity analysis")
            sig_neurons = self.experiment.get_significant_neurons()
            neuron_indices = np.array(list(sig_neurons.keys()))
            if len(neuron_indices) == 0:
                self.logger.warning("No significant neurons found, using all neurons")
                neuron_indices = np.arange(n_cells)
        else:
            raise ValueError(
                f"Unknown neuron_selection '{neuron_selection}'; "
                f"expected None, 'all', 'significant' or a list of indices"
            )
    else:
        neuron_indices = np.array(neuron_selection)

    # Get neural data
    if data_type == 'calcium':
        neural_data = self.experiment.calcium.data[neuron_indices, :]
    elif data_type == 'spikes':
        neural_data = self.experiment.spikes.data[neuron_indices, :]
    else:
        raise ValueError("data_type must be 'calcium' or 'spikes'")

    # Apply downsampling if requested ('ds' must not reach the DR method)
    ds = dr_kwargs.pop('ds', 1)
    if ds < 1:
        raise ValueError(f"ds must be a positive integer, got {ds}")
    if ds > 1:
        neural_data = neural_data[:, ::ds]
        self.logger.info(f"Downsampling data by factor {ds}: {neural_data.shape[1]} timepoints")

    # Number of timepoints actually handed to the DR method. A strided slice
    # keeps ceil(n / ds) columns, so this is taken from the data itself
    # instead of being recomputed with floor division.
    expected_frames = neural_data.shape[1]

    # Create MVData and compute embedding
    mvdata = MVData(data=neural_data)  # MVData expects (n_features, n_samples)

    # 'dim' first so that an explicit 'dim' in dr_kwargs still takes precedence
    params = {'dim': n_components, **dr_kwargs}

    embedding_obj = mvdata.get_embedding(method=method, **params)
    embedding = embedding_obj.coords.T  # Transpose to (n_timepoints, n_components)

    # Check that the embedding kept every timepoint it was given
    if embedding.shape[0] < expected_frames:
        n_missing = expected_frames - embedding.shape[0]
        raise ValueError(
            f"{method} embedding dropped {n_missing} timepoints due to graph disconnection. "
            f"This is not supported for INTENSE analysis. Try increasing n_neighbors or using a different method."
        )

    # Store metadata
    metadata = {
        'method': method,
        'n_components': n_components,
        'neuron_selection': neuron_selection,
        'neuron_indices': neuron_indices.tolist(),
        'n_neurons': len(neuron_indices),
        'dr_params': dr_kwargs,
        'data_type': data_type,
        'ds': ds  # Store downsampling factor
    }

    # Store in experiment
    self.experiment.store_embedding(embedding, method, data_type, metadata)

    self.logger.info(
        f"Created {method} embedding with {n_components} components "
        f"using {len(neuron_indices)} neurons"
    )

    return embedding
def analyze_embedding_selectivity(
    self,
    embedding_methods: Optional[Union[str, List[str]]] = None,
    data_type: str = 'calcium',
    **intense_kwargs
) -> Dict:
    """
    Run embedding-selectivity analysis on this mapper's experiment.

    Thin delegation to ``compute_embedding_selectivity``: the experiment held
    by the mapper is passed positionally and every argument is forwarded
    unchanged.

    Parameters
    ----------
    embedding_methods : str, list or None
        Embedding methods to analyze, forwarded as-is (None is forwarded
        as None).
    data_type : str
        Data type, forwarded as-is (default 'calcium').
    **intense_kwargs
        Extra keyword arguments handed through to
        ``compute_embedding_selectivity``.

    Returns
    -------
    results : dict
        Whatever ``compute_embedding_selectivity`` returns.
    """
    return compute_embedding_selectivity(
        self.experiment,
        embedding_methods=embedding_methods,
        data_type=data_type,
        **intense_kwargs
    )
def get_functional_organization(
    self,
    method_name: str,
    data_type: str = 'calcium'
) -> Dict:
    """
    Analyze functional organization in the manifold.

    Parameters
    ----------
    method_name : str
        Name of the embedding method
    data_type : str
        Data type used for embedding

    Returns
    -------
    organization : dict
        Always contains:
        - 'component_importance': per-component variance divided by the
          total variance (all zeros when the total variance is zero)
        - 'n_components': number of embedding columns
        - 'n_neurons_used', 'neuron_indices': taken from the embedding
          metadata, defaulting to all neurons of the experiment
        Additionally, when ``stats_tables[data_type]`` contains the key
        ``f"{method_name}_comp0"``:
        - 'neuron_participation': neuron index -> list of component indices
          the neuron is significantly selective to ('stage2' flag)
        - 'component_specialization': per component, the selective neurons,
          their count and the fraction of all neurons they represent
        - 'functional_clusters': groups of at least two neurons selective to
          the same set of components, largest first
        - 'n_participating_neurons', 'mean_components_per_neuron'
    """
    from collections import defaultdict

    experiment = self.experiment

    # Get embedding and metadata (metadata may be absent or None)
    embedding_dict = experiment.get_embedding(method_name, data_type)
    embedding = embedding_dict['data']
    metadata = embedding_dict.get('metadata') or {}
    n_cells = experiment.n_cells
    neuron_indices = metadata.get('neuron_indices', list(range(n_cells)))
    n_components = embedding.shape[1]

    # Compute component importance (variance explained). A constant embedding
    # has zero total variance; report zeros instead of dividing by zero.
    component_var = np.var(embedding, axis=0)
    total_var = np.sum(component_var)
    if total_var == 0:
        component_importance = np.zeros_like(component_var)
    else:
        component_importance = component_var / total_var

    organization = {
        'component_importance': component_importance,
        'n_components': n_components,
        'n_neurons_used': len(neuron_indices),
        'neuron_indices': neuron_indices
    }

    # Selectivity results are considered available when the first component
    # has an entry in the stats tables for this data type.
    stats_tables = getattr(experiment, 'stats_tables', None) or {}
    data_stats = stats_tables.get(data_type) or {}
    if f"{method_name}_comp0" not in data_stats:
        return organization

    # The significance tables are looked up defensively: their presence is
    # not implied by the stats-table check above.
    significance_tables = getattr(experiment, 'significance_tables', None) or {}
    significance = significance_tables.get(data_type) or {}

    neuron_participation = {}
    component_specialization = {}

    for comp_idx in range(n_components):
        # Per-component table fetched once instead of once per neuron
        feat_table = significance.get(f"{method_name}_comp{comp_idx}") or {}

        # Neurons whose entry for this component passed stage 2
        selective_neurons = [
            neuron_idx
            for neuron_idx in range(n_cells)
            if neuron_idx in feat_table and feat_table[neuron_idx].get('stage2', False)
        ]

        # Track neuron participation
        for neuron_idx in selective_neurons:
            neuron_participation.setdefault(neuron_idx, []).append(comp_idx)

        component_specialization[comp_idx] = {
            'n_selective_neurons': len(selective_neurons),
            'selective_neurons': selective_neurons,
            'selectivity_rate': len(selective_neurons) / n_cells if n_cells else 0.0
        }

    # Identify functional clusters (neurons selective to same components)
    cluster_map = defaultdict(list)
    for neuron_idx, components in neuron_participation.items():
        cluster_map[tuple(sorted(components))].append(neuron_idx)

    # Only keep clusters with multiple neurons
    functional_clusters = [
        {'components': list(components), 'neurons': neurons, 'size': len(neurons)}
        for components, neurons in cluster_map.items()
        if len(neurons) > 1
    ]

    # Sort clusters by size (stable, so equal sizes keep discovery order)
    functional_clusters.sort(key=lambda cluster: cluster['size'], reverse=True)

    if neuron_participation:
        mean_components = np.mean([len(comps) for comps in neuron_participation.values()])
    else:
        mean_components = 0

    organization.update({
        'neuron_participation': neuron_participation,
        'component_specialization': component_specialization,
        'functional_clusters': functional_clusters,
        'n_participating_neurons': len(neuron_participation),
        'mean_components_per_neuron': mean_components
    })

    return organization
def compare_embeddings(
    self,
    method_names: List[str],
    data_type: str = 'calcium'
) -> Dict:
    """
    Compare functional organization across different embedding methods.

    Parameters
    ----------
    method_names : list
        List of embedding method names to compare. Methods whose lookup
        raises ``KeyError`` are skipped with a warning.
    data_type : str
        Data type used for embeddings

    Returns
    -------
    comparison : dict
        Per-method summaries keyed by method name ('n_components',
        'n_participating_neurons', 'mean_components_per_neuron',
        'n_functional_clusters'), the list of compared 'methods', and -
        only when every compared method has neuron participation data -
        'participation_overlap': the Jaccard index of the participating
        neuron sets for each pair of compared methods, keyed
        ``"<m1>_vs_<m2>"``.

    Raises
    ------
    ValueError
        If fewer than two of the requested embeddings could be loaded.
    """
    from itertools import combinations

    organizations = {}
    for method in method_names:
        try:
            organizations[method] = self.get_functional_organization(method, data_type)
        except KeyError:
            self.logger.warning(f"No embedding found for method '{method}'")

    if len(organizations) < 2:
        raise ValueError("Need at least 2 embeddings to compare")

    comparison = {
        'methods': list(organizations.keys()),
        'n_components': {m: org['n_components'] for m, org in organizations.items()},
        'n_participating_neurons': {m: org.get('n_participating_neurons', 0) for m, org in organizations.items()},
        'mean_components_per_neuron': {m: org.get('mean_components_per_neuron', 0) for m, org in organizations.items()},
        'n_functional_clusters': {m: len(org.get('functional_clusters', [])) for m, org in organizations.items()}
    }

    # Compare neuron participation overlap
    if all('neuron_participation' in org for org in organizations.values()):
        participation_overlap = {}

        # Pair only the methods that were actually loaded; pairing the raw
        # method_names would index a skipped method and raise KeyError.
        for m1, m2 in combinations(organizations, 2):
            neurons1 = set(organizations[m1]['neuron_participation'])
            neurons2 = set(organizations[m2]['neuron_participation'])
            union = neurons1 | neurons2

            # Two empty sets have no defined Jaccard index; report 0
            overlap = len(neurons1 & neurons2) / len(union) if union else 0

            participation_overlap[f"{m1}_vs_{m2}"] = overlap

        comparison['participation_overlap'] = participation_overlap

    return comparison