Coverage for src/driada/intense/disentanglement.py: 70.27%
148 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1"""
2Mixed selectivity disentanglement analysis for INTENSE.
4This module provides functions to analyze and disentangle mixed selectivity
5in neural responses when neurons respond to multiple, potentially correlated
6behavioral variables.
7"""
9import numpy as np
10from itertools import combinations
11from ..information.info_base import get_mi, conditional_mi, MultiTimeSeries
# Default multifeature mapping for common behavioral variable combinations.
# Keys are tuples of component feature names; each value is the aggregated
# (semantic) name under which the combined variable is referred to.
DEFAULT_MULTIFEATURE_MAP = {
    ('x', 'y'): 'place',  # spatial location multifeature
}
def disentangle_pair(ts1, ts2, ts3, verbose=False, ds=1):
    """Resolve which of two behavioral variables a neuron is primarily tuned to.

    Given neural activity ``ts1`` and two behavioral variables ``ts2`` and
    ``ts3``, combine pairwise mutual information (MI), conditional MI and the
    interaction information (II) derived from them to decide whether one
    variable accounts for the neuron's apparent selectivity to the other.

    Parameters
    ----------
    ts1 : TimeSeries
        Neural activity time series (e.g., calcium signal or spike train).
    ts2 : TimeSeries
        First behavioral variable.
    ts3 : TimeSeries
        Second behavioral variable.
    verbose : bool, optional
        If True, print every intermediate quantity. Default: False.
    ds : int, optional
        Downsampling factor forwarded to the MI estimators. Default: 1.

    Returns
    -------
    float
        Disentanglement result:
        - 0: ts2 is the primary variable (ts3 is redundant)
        - 1: ts3 is the primary variable (ts2 is redundant)
        - 0.5: both variables contribute and cannot be told apart

    Notes
    -----
    II is estimated as the mean of ``MI(ts1, ts2 | ts3) - MI(ts1, ts2)`` and
    ``MI(ts1, ts3 | ts2) - MI(ts1, ts3)``.

    - II < 0 (redundancy): a variable is the "weak link" when its MI with the
      neuron is below ``|II|`` while the other variable's conditional MI is
      not. If exactly one variable is a weak link, the other one is primary.
    - II >= 0 (synergy): a variable is primary only if its conditional MI is
      strictly larger than the other's and the other variable's MI with the
      neuron is either exactly zero or less than half of its own.

    See README_INTENSE.md for theoretical background.
    """
    # Pairwise MI: neuron vs. each behavior, and behavior vs. behavior.
    mi_ax = get_mi(ts1, ts2, ds=ds)
    mi_ay = get_mi(ts1, ts3, ds=ds)
    mi_xy = get_mi(ts2, ts3, ds=ds)

    # Conditional MI of the neuron with one behavior given the other.
    cmi_ax_given_y = conditional_mi(ts1, ts2, ts3, ds=ds)
    cmi_ay_given_x = conditional_mi(ts1, ts3, ts2, ds=ds)

    # Two equivalent II formulas (II = I(X;Y|Z) - I(X;Y)), averaged.
    ii_via_x = cmi_ax_given_y - mi_ax
    ii_via_y = cmi_ay_given_x - mi_ay
    ii_mean = np.mean([ii_via_x, ii_via_y])
    ii_abs = np.abs(ii_mean)

    if verbose:
        ratio_x = np.round(cmi_ax_given_y / mi_ax, 3) if mi_ax > 0 else 'N/A'
        ratio_y = np.round(cmi_ay_given_x / mi_ay, 3) if mi_ay > 0 else 'N/A'
        report_sections = [
            [('MI(A,X):', mi_ax), ('MI(A,Y):', mi_ay), ('MI(X,Y):', mi_xy)],
            [('MI(A,X|Y):', cmi_ax_given_y), ('MI(A,Y|X):', cmi_ay_given_x)],
            [('MI(A,X|Y) / MI(A,X):', ratio_x),
             ('MI(A,Y|X) / MI(A,Y):', ratio_y)],
            [('I(A,X,Y) 1:', ii_via_x), ('I(A,X,Y) 2:', ii_via_y),
             ('I(A,X,Y) av:', ii_mean)],
        ]
        # Every section is preceded by a blank line.
        for section in report_sections:
            print()
            for label, value in section:
                print(label, value)

        print()
        print('Analysis (X=behavior1, Y=behavior2):')
        print(f' Redundancy detected: {ii_mean < 0}')
        print(f' MI(A,X) < |II|: {mi_ax < ii_abs}')
        print(f' MI(A,Y) < |II|: {mi_ay < ii_abs}')

    if ii_mean < 0:
        # Redundancy: look for a single "weak link" among the two behaviors.
        # `not a < b` is kept (rather than `a >= b`) so NaN handling is intact.
        x_is_weak = bool(mi_ax < ii_abs and not cmi_ay_given_x < ii_abs)
        y_is_weak = bool(mi_ay < ii_abs and not cmi_ax_given_y < ii_abs)
        if x_is_weak == y_is_weak:
            return 0.5  # Both contribute - undistinguishable
        # Exactly one weak link: the other variable is primary.
        return 1 if x_is_weak else 0

    # Synergy (II >= 0): only clear-cut dominance yields a winner.
    if cmi_ax_given_y > cmi_ay_given_x:
        if mi_ay == 0 or (mi_ay > 0 and mi_ax / mi_ay > 2.0):
            return 0  # ts2 is primary / strongly dominant
    elif cmi_ay_given_x > cmi_ax_given_y:
        if mi_ax == 0 or (mi_ax > 0 and mi_ay / mi_ax > 2.0):
            return 1  # ts3 is primary / strongly dominant

    return 0.5  # Both contribute - undistinguishable
def disentangle_all_selectivities(exp, feat_names, ds=1, multifeature_map=None,
                                  feat_feat_significance=None, cell_bunch=None):
    """Run pairwise disentanglement for every multi-selective neuron.

    Neurons are obtained from ``exp.get_significant_neurons(min_nspec=2,
    cbunch=cell_bunch)``. For each neuron, every pair of its selective
    features is scored and the outcome is accumulated into two
    feature-by-feature matrices.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neural and behavioral data. Features
        are read from it as attributes; neurons via ``exp.neurons``.
    feat_names : list of str
        Feature names defining the matrix axes. Should include any
        aggregated names from ``multifeature_map`` that are to be analyzed.
    ds : int, optional
        Downsampling factor passed to ``disentangle_pair``. Default: 1.
    multifeature_map : dict, optional
        Mapping from tuples of component feature names to an aggregated
        name, e.g. ``{('x', 'y'): 'place'}``. If None, a copy of
        DEFAULT_MULTIFEATURE_MAP is used.
    feat_feat_significance : ndarray, optional
        Significance matrix indexed like ``feat_names``. A pair whose entry
        equals 0 is not passed to ``disentangle_pair``; it is counted as
        tested and credited 0.5 to each feature (true mixed selectivity).
    cell_bunch : list or None, optional
        Cell IDs to analyze, forwarded as ``cbunch``. Default: None.

    Returns
    -------
    disent_matrix : ndarray
        Element [i, j] accumulates 1 each time feature i was primary when
        paired with feature j, and 0.5 for each undistinguishable outcome.
    count_matrix : ndarray
        Element [i, j] is the number of neuron-level tests of features
        i and j (symmetric).

    Raises
    ------
    ValueError
        If a component of a multifeature whose aggregated name is in
        ``feat_names`` is not an attribute of ``exp``.

    Notes
    -----
    Any exception raised while handling a single neuron/feature pair is
    printed and that pair is skipped; it does not abort the analysis.
    """
    if multifeature_map is None:
        multifeature_map = DEFAULT_MULTIFEATURE_MAP.copy()

    size = len(feat_names)
    disent_matrix = np.zeros((size, size))
    count_matrix = np.zeros((size, size))

    # Build one MultiTimeSeries per aggregated feature that is requested.
    aggregated_ts = {}
    for components, label in multifeature_map.items():
        if label not in feat_names:
            continue
        parts = []
        for comp in components:
            if not hasattr(exp, comp):
                raise ValueError(f"Component '{comp}' not found in experiment")
            parts.append(getattr(exp, comp))
        aggregated_ts[label] = MultiTimeSeries(parts)

    def _resolve(fname):
        # Map one selectivity entry to (time series, index in feat_names).
        if isinstance(fname, tuple) and fname in multifeature_map:
            label = multifeature_map[fname]
            if label not in feat_names:
                raise ValueError(f"Aggregated name '{label}' not in feat_names")
            return aggregated_ts[label], feat_names.index(label)
        if not hasattr(exp, fname):
            raise ValueError(f"Feature '{fname}' not found in experiment")
        return getattr(exp, fname), feat_names.index(fname)

    multi_selective = exp.get_significant_neurons(min_nspec=2, cbunch=cell_bunch)

    for neuron, sels in multi_selective.items():
        neur_ts = exp.neurons[neuron].ca

        for pair in combinations(sels, 2):
            sel_comb = list(pair)
            try:
                (first_ts, ind1), (second_ts, ind2) = [
                    _resolve(fname) for fname in sel_comb
                ]

                # Behaviorally uncorrelated features are not disentangled:
                # they are scored directly as "both contribute".
                uncorrelated = (feat_feat_significance is not None
                                and feat_feat_significance[ind1, ind2] == 0)
                if uncorrelated:
                    outcome = 0.5
                else:
                    outcome = disentangle_pair(neur_ts, first_ts, second_ts,
                                               ds=ds, verbose=False)

                count_matrix[ind1, ind2] += 1
                count_matrix[ind2, ind1] += 1

                if outcome == 0:
                    disent_matrix[ind1, ind2] += 1  # first feature is primary
                elif outcome == 1:
                    disent_matrix[ind2, ind1] += 1  # second feature is primary
                elif outcome == 0.5:
                    disent_matrix[ind1, ind2] += 0.5  # both contribute
                    disent_matrix[ind2, ind1] += 0.5

            except Exception as err:
                # Best-effort: report the failing pair and move on.
                print(f'ERROR processing neuron {neuron}, features {sel_comb}: {str(err)}')

    return disent_matrix, count_matrix
def create_multifeature_map(exp, mapping_dict):
    """Build a multifeature mapping after checking it against an experiment.

    Parameters
    ----------
    exp : Experiment
        Experiment object; every component name must be one of its attributes.
    mapping_dict : dict
        Dictionary mapping tuples of features to aggregated names.
        Example: {('x', 'y'): 'place', ('speed', 'head_direction'): 'locomotion'}

    Returns
    -------
    dict
        New mapping with the same aggregated names, keyed by the component
        tuples in sorted order.

    Raises
    ------
    ValueError
        If any component feature is not an attribute of the experiment.
    """
    checked = {}

    for components, label in mapping_dict.items():
        for name in components:
            if hasattr(exp, name):
                continue
            raise ValueError(f"Component '{name}' in multifeature {components} "
                             f"not found in experiment")

        # Keys are stored with their components sorted for consistency.
        checked[tuple(sorted(components))] = label

    return checked
def get_disentanglement_summary(disent_matrix, count_matrix, feat_names,
                                feat_feat_significance=None):
    """Condense disentanglement matrices into per-pair and overall statistics.

    Parameters
    ----------
    disent_matrix : ndarray
        Disentanglement result matrix from disentangle_all_selectivities.
    count_matrix : ndarray
        Count matrix from disentangle_all_selectivities.
    feat_names : list of str
        Feature names corresponding to matrix indices.
    feat_feat_significance : ndarray, optional
        Significance matrix; when given, tested pairs are additionally
        split into entries equal to 1 and all other entries.

    Returns
    -------
    dict
        ``'feature_pairs'`` maps ``"<fi>_vs_<fj>"`` (i < j, tested pairs
        only) to the test count and percentage scores for that pair.
        ``'overall_stats'`` holds totals and rates over all tested pairs
        (left empty when nothing was tested and no significance matrix is
        given), plus the significant / non-significant breakdown when
        ``feat_feat_significance`` is provided.
    """
    pair_stats = {}
    overall = {}

    redundant_sum = 0
    undist_sum = 0
    tested_sum = 0

    # Upper-triangle index pairs (i < j), in row-major order.
    index_pairs = list(combinations(range(len(feat_names)), 2))

    for i, j in index_pairs:
        tested = count_matrix[i, j]
        if not tested > 0:
            continue

        score_i = disent_matrix[i, j]
        score_j = disent_matrix[j, i]

        # NOTE(review): for matrices filled by disentangle_all_selectivities
        # each test appears to add exactly 1 to score_i + score_j, which
        # would make this expression always 0 - confirm the intended formula.
        undist = (score_i + score_j - tested) * 2
        redundant = tested - undist

        name_i, name_j = feat_names[i], feat_names[j]
        pair_stats[f"{name_i}_vs_{name_j}"] = {
            'total_neurons': int(tested),
            f'{name_i}_primary': score_i / tested * 100,
            f'{name_j}_primary': score_j / tested * 100,
            'undistinguishable_pct': undist / tested * 100,
            'redundant_pct': redundant / tested * 100,
        }

        redundant_sum += redundant
        undist_sum += undist
        tested_sum += tested

    if tested_sum > 0:
        overall = {
            'total_neuron_pairs': int(tested_sum),
            'redundancy_rate': redundant_sum / tested_sum * 100,
            'undistinguishable_rate': undist_sum / tested_sum * 100,
        }

    # Optional breakdown by behavioral significance of the feature pair.
    if feat_feat_significance is not None:
        sig_tests = 0
        nonsig_tests = 0
        for i, j in index_pairs:
            tested = count_matrix[i, j]
            if tested > 0:
                if feat_feat_significance[i, j] == 1:
                    sig_tests += tested
                else:
                    nonsig_tests += tested

        overall['significant_behavior_pairs'] = int(sig_tests)
        overall['nonsignificant_behavior_pairs'] = int(nonsig_tests)
        overall['true_mixed_selectivity_rate'] = (
            nonsig_tests / tested_sum * 100 if tested_sum > 0 else 0
        )

    return {'feature_pairs': pair_stats, 'overall_stats': overall}