Coverage for src/driada/intense/disentanglement.py: 70.27%

148 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2Mixed selectivity disentanglement analysis for INTENSE. 

3 

4This module provides functions to analyze and disentangle mixed selectivity 

5in neural responses when neurons respond to multiple, potentially correlated 

6behavioral variables. 

7""" 

8 

9import numpy as np 

10from itertools import combinations 

11from ..information.info_base import get_mi, conditional_mi, MultiTimeSeries 

12 

13 

# Built-in semantic names for behavioral variables that are commonly analyzed
# together as one multifeature. Keys are tuples of component feature names,
# values are the aggregated name used in feature lists.
DEFAULT_MULTIFEATURE_MAP = dict([
    (('x', 'y'), 'place'),  # 2D spatial position analyzed jointly as "place"
])

19 

20 

def disentangle_pair(ts1, ts2, ts3, verbose=False, ds=1):
    """Disentangle mixed selectivity between two behavioral variables for a neuron.

    Determines which of two correlated behavioral variables (ts2, ts3) provides
    the primary information about neural activity (ts1) using interaction information
    and conditional mutual information analysis.

    Parameters
    ----------
    ts1 : TimeSeries
        Neural activity time series (e.g., calcium signal or spike train).
    ts2 : TimeSeries
        First behavioral variable.
    ts3 : TimeSeries
        Second behavioral variable.
    verbose : bool, optional
        If True, print detailed analysis results, including MI(ts2, ts3),
        which is only estimated in this case. Default: False.
    ds : int, optional
        Downsampling factor passed to every MI estimate. Default: 1.

    Returns
    -------
    int or float
        Disentanglement result:
        - 0: ts2 is the primary variable (ts3 is redundant)
        - 1: ts3 is the primary variable (ts2 is redundant)
        - 0.5: Both variables contribute - undistinguishable

    Notes
    -----
    The method uses interaction information (II), averaged over its two
    equivalent forms I(A;X|Y) - I(A;X) and I(A;Y|X) - I(A;Y):
    - If II < 0 (redundancy), identifies the "weakest link": a variable whose
      MI with the neuron is below |II| while the other variable's conditional
      MI is not below |II|.
    - Otherwise (synergy or II == 0), a variable is declared primary only if
      the other variable has zero MI with the neuron, or its own MI is more
      than twice the other's, and in both cases its conditional MI is larger.

    See README_INTENSE.md for theoretical background.
    """
    # Pairwise mutual information between the neuron and each behavior
    mi12 = get_mi(ts1, ts2, ds=ds)  # MI(neuron, behavior1)
    mi13 = get_mi(ts1, ts3, ds=ds)  # MI(neuron, behavior2)

    # Conditional mutual information
    cmi123 = conditional_mi(ts1, ts2, ts3, ds=ds)  # MI(neuron, behavior1 | behavior2)
    cmi132 = conditional_mi(ts1, ts3, ts2, ds=ds)  # MI(neuron, behavior2 | behavior1)

    # Interaction information (average of two equivalent formulas)
    # Using Williams & Beer convention: II = I(X;Y|Z) - I(X;Y)
    I_av = np.mean([cmi123 - mi12, cmi132 - mi13])
    abs_ii = np.abs(I_av)

    if verbose:
        # MI(behavior1, behavior2) is purely diagnostic: it never enters the
        # decision below, so the (potentially costly) estimate is only made
        # when it is actually going to be printed.
        mi23 = get_mi(ts2, ts3, ds=ds)

        print()
        print('MI(A,X):', mi12)
        print('MI(A,Y):', mi13)
        print('MI(X,Y):', mi23)

        print()
        print('MI(A,X|Y):', cmi123)
        print('MI(A,Y|X):', cmi132)

        print()
        print('MI(A,X|Y) / MI(A,X):', np.round(cmi123/mi12, 3) if mi12 > 0 else 'N/A')
        print('MI(A,Y|X) / MI(A,Y):', np.round(cmi132/mi13, 3) if mi13 > 0 else 'N/A')

        print()
        print('I(A,X,Y) 1:', cmi123 - mi12)
        print('I(A,X,Y) 2:', cmi132 - mi13)
        print('I(A,X,Y) av:', I_av)

        print()
        print('Analysis (X=behavior1, Y=behavior2):')
        print(f' Redundancy detected: {I_av < 0}')
        print(f' MI(A,X) < |II|: {mi12 < abs_ii}')
        print(f' MI(A,Y) < |II|: {mi13 < abs_ii}')

    if I_av < 0:  # Negative interaction information (redundancy)
        # A variable is a "weak link" when its own MI with the neuron is below
        # the redundancy magnitude while the other variable still carries
        # information beyond it.
        criterion1 = mi12 < abs_ii and not cmi132 < abs_ii
        criterion2 = mi13 < abs_ii and not cmi123 < abs_ii

        if criterion1 and not criterion2:
            return 1  # ts2 is redundant, ts3 is primary
        if criterion2 and not criterion1:
            return 0  # ts3 is redundant, ts2 is primary
        return 0.5  # Both contribute - undistinguishable

    # Non-negative interaction information (synergy): special cases only
    if mi13 == 0 and cmi123 > cmi132:
        return 0  # ts2 is primary

    if mi12 == 0 and cmi132 > cmi123:
        return 1  # ts3 is primary

    if mi13 > 0 and mi12/mi13 > 2.0 and cmi123 > cmi132:
        return 0  # ts2 is strongly dominant

    if mi12 > 0 and mi13/mi12 > 2.0 and cmi132 > cmi123:
        return 1  # ts3 is strongly dominant

    return 0.5  # Both contribute - undistinguishable

123 

124 

125def disentangle_all_selectivities(exp, feat_names, ds=1, multifeature_map=None, 

126 feat_feat_significance=None, cell_bunch=None): 

127 """Analyze mixed selectivity across all significant neuron-feature pairs. 

128  

129 For each neuron that responds to multiple features, determines which 

130 features provide primary vs redundant information using disentanglement 

131 analysis. Only analyzes feature pairs that show significant correlation 

132 in the behavioral data. 

133  

134 Parameters 

135 ---------- 

136 exp : Experiment 

137 Experiment object containing neural and behavioral data. 

138 feat_names : list of str 

139 List of feature names to analyze. Should match features in experiment 

140 and any aggregated names from multifeature_map. 

141 ds : int, optional 

142 Downsampling factor. Default: 1. 

143 multifeature_map : dict, optional 

144 Mapping from multifeature tuples to aggregated names and their 

145 corresponding MultiTimeSeries. If None, uses DEFAULT_MULTIFEATURE_MAP. 

146 Example: { 

147 ('x', 'y'): 'place', 

148 ('speed', 'head_direction'): 'locomotion', 

149 ('lick', 'reward'): 'consummatory' 

150 } 

151 feat_feat_significance : ndarray, optional 

152 Binary significance matrix from compute_feat_feat_significance. 

153 If provided, only feature pairs marked as significant (value=1) 

154 will be analyzed for disentanglement. Non-significant pairs are 

155 assumed to represent true mixed selectivity. 

156 cell_bunch : list or None, optional 

157 List of cell IDs to analyze. If None, analyzes all cells. 

158 Default: None. 

159  

160 Returns 

161 ------- 

162 disent_matrix : ndarray 

163 Matrix where element [i,j] indicates how many times feature i 

164 was primary when paired with feature j across all neurons. 

165 count_matrix : ndarray 

166 Matrix where element [i,j] indicates how many neuron-feature 

167 pairs were tested for features i and j. 

168  

169 Notes 

170 ----- 

171 The analysis is performed only on neurons with significant selectivity 

172 to at least 2 features. If feat_feat_significance is provided, only 

173 behaviorally correlated feature pairs are analyzed for redundancy. 

174 Non-significant pairs indicate true mixed selectivity. 

175 """ 

176 # Use default multifeature mapping if none provided 

177 if multifeature_map is None: 

178 multifeature_map = DEFAULT_MULTIFEATURE_MAP.copy() 

179 

180 # Initialize result matrices 

181 n_features = len(feat_names) 

182 disent_matrix = np.zeros((n_features, n_features)) 

183 count_matrix = np.zeros((n_features, n_features)) 

184 

185 # Create MultiTimeSeries for each multifeature 

186 multifeature_ts = {} 

187 for mf_tuple, agg_name in multifeature_map.items(): 

188 if agg_name in feat_names: 

189 # Get individual TimeSeries for each component 

190 component_ts = [] 

191 for component in mf_tuple: 

192 if hasattr(exp, component): 

193 component_ts.append(getattr(exp, component)) 

194 else: 

195 raise ValueError(f"Component '{component}' not found in experiment") 

196 

197 # Create MultiTimeSeries 

198 multifeature_ts[agg_name] = MultiTimeSeries(component_ts) 

199 

200 # Get neurons with significant selectivity to multiple features 

201 sneur = exp.get_significant_neurons(min_nspec=2, cbunch=cell_bunch) 

202 

203 for neuron, sels in sneur.items(): 

204 neur_ts = exp.neurons[neuron].ca 

205 

206 # Test all pairs of features this neuron responds to 

207 for sel_comb in combinations(sels, 2): 

208 try: 

209 sel_comb = list(sel_comb) 

210 feat_ts = [] 

211 finds = [] 

212 

213 # Get time series for each feature 

214 for fname in sel_comb: 

215 # Check if this is a multifeature tuple 

216 if isinstance(fname, tuple) and fname in multifeature_map: 

217 agg_name = multifeature_map[fname] 

218 if agg_name in feat_names: 

219 feat_ts.append(multifeature_ts[agg_name]) 

220 finds.append(feat_names.index(agg_name)) 

221 else: 

222 raise ValueError(f"Aggregated name '{agg_name}' not in feat_names") 

223 else: 

224 # Regular single feature 

225 if hasattr(exp, fname): 

226 feat_ts.append(getattr(exp, fname)) 

227 finds.append(feat_names.index(fname)) 

228 else: 

229 raise ValueError(f"Feature '{fname}' not found in experiment") 

230 

231 # Get feature indices 

232 ind1 = finds[0] 

233 ind2 = finds[1] 

234 

235 # Check if this feature pair has significant behavioral correlation 

236 if feat_feat_significance is not None: 

237 if feat_feat_significance[ind1, ind2] == 0: 

238 # Features are not significantly correlated 

239 # Skip disentanglement - this is true mixed selectivity 

240 count_matrix[ind1, ind2] += 1 

241 count_matrix[ind2, ind1] += 1 

242 # Add 0.5 to each to indicate undistinguishable contributions 

243 disent_matrix[ind1, ind2] += 0.5 

244 disent_matrix[ind2, ind1] += 0.5 

245 continue 

246 

247 # Perform disentanglement analysis only for significant pairs 

248 disres = disentangle_pair(neur_ts, feat_ts[0], feat_ts[1], 

249 ds=ds, verbose=False) 

250 

251 # Update matrices 

252 count_matrix[ind1, ind2] += 1 

253 count_matrix[ind2, ind1] += 1 

254 

255 if disres == 0: 

256 disent_matrix[ind1, ind2] += 1 # Feature 1 is primary 

257 elif disres == 1: 

258 disent_matrix[ind2, ind1] += 1 # Feature 2 is primary 

259 elif disres == 0.5: 

260 disent_matrix[ind1, ind2] += 0.5 # Both contribute 

261 disent_matrix[ind2, ind1] += 0.5 

262 

263 except Exception as e: 

264 print(f'ERROR processing neuron {neuron}, features {sel_comb}: {str(e)}') 

265 continue 

266 

267 return disent_matrix, count_matrix 

268 

269 

def create_multifeature_map(exp, mapping_dict):
    """Create a multifeature mapping with validation.

    Every component name must exist as an attribute of ``exp``. Keys are
    normalized to sorted tuples so that e.g. ('y', 'x') and ('x', 'y') refer
    to the same multifeature.

    Parameters
    ----------
    exp : Experiment
        Experiment object to validate feature existence.
    mapping_dict : dict
        Dictionary mapping tuples of features to aggregated names.
        Example: {('x', 'y'): 'place', ('speed', 'head_direction'): 'locomotion'}

    Returns
    -------
    dict
        Validated multifeature mapping with sorted tuple keys.

    Raises
    ------
    ValueError
        If any component features don't exist in the experiment, or if two
        keys with the same components (after sorting) map to different
        aggregated names.
    """
    validated_map = {}

    for mf_tuple, agg_name in mapping_dict.items():
        # Validate that all components exist
        for component in mf_tuple:
            if not hasattr(exp, component):
                raise ValueError(f"Component '{component}' in multifeature {mf_tuple} "
                                 f"not found in experiment")

        # Ensure tuple is sorted for consistency
        sorted_tuple = tuple(sorted(mf_tuple))

        # Sorting can make distinct keys collide; refuse to silently drop one
        # of two conflicting aggregated names.
        existing = validated_map.get(sorted_tuple)
        if existing is not None and existing != agg_name:
            raise ValueError(f"Multifeature {mf_tuple} maps to '{agg_name}', but the "
                             f"same components {sorted_tuple} are already mapped "
                             f"to '{existing}'")

        validated_map[sorted_tuple] = agg_name

    return validated_map

305 

306 

307def get_disentanglement_summary(disent_matrix, count_matrix, feat_names, 

308 feat_feat_significance=None): 

309 """Generate a summary of disentanglement results. 

310  

311 Parameters 

312 ---------- 

313 disent_matrix : ndarray 

314 Disentanglement result matrix from disentangle_all_selectivities. 

315 count_matrix : ndarray 

316 Count matrix from disentangle_all_selectivities. 

317 feat_names : list of str 

318 Feature names corresponding to matrix indices. 

319 feat_feat_significance : ndarray, optional 

320 Binary significance matrix indicating which feature pairs 

321 were analyzed for disentanglement. 

322  

323 Returns 

324 ------- 

325 dict 

326 Summary statistics including: 

327 - Primary feature percentages for each pair 

328 - Total counts for each pair 

329 - Overall redundancy vs independence rates 

330 - Breakdown by significant vs non-significant feature pairs 

331 """ 

332 summary = { 

333 'feature_pairs': {}, 

334 'overall_stats': {} 

335 } 

336 

337 n_features = len(feat_names) 

338 total_redundant = 0 

339 total_undistinguishable = 0 

340 total_pairs = 0 

341 

342 for i in range(n_features): 

343 for j in range(i + 1, n_features): 

344 if count_matrix[i, j] > 0: 

345 n_total = count_matrix[i, j] 

346 n_i_primary = disent_matrix[i, j] 

347 n_j_primary = disent_matrix[j, i] 

348 

349 # Account for 0.5 contributions (undistinguishable) 

350 n_undistinguishable = (n_i_primary + n_j_primary - n_total) * 2 

351 n_redundant = n_total - n_undistinguishable 

352 

353 pair_key = f"{feat_names[i]}_vs_{feat_names[j]}" 

354 summary['feature_pairs'][pair_key] = { 

355 'total_neurons': int(n_total), 

356 f'{feat_names[i]}_primary': n_i_primary / n_total * 100, 

357 f'{feat_names[j]}_primary': n_j_primary / n_total * 100, 

358 'undistinguishable_pct': n_undistinguishable / n_total * 100, 

359 'redundant_pct': n_redundant / n_total * 100 

360 } 

361 

362 total_redundant += n_redundant 

363 total_undistinguishable += n_undistinguishable 

364 total_pairs += n_total 

365 

366 if total_pairs > 0: 

367 summary['overall_stats'] = { 

368 'total_neuron_pairs': int(total_pairs), 

369 'redundancy_rate': total_redundant / total_pairs * 100, 

370 'undistinguishable_rate': total_undistinguishable / total_pairs * 100 

371 } 

372 

373 # Add breakdown by behavioral significance if provided 

374 if feat_feat_significance is not None: 

375 sig_pairs = 0 

376 nonsig_pairs = 0 

377 for i in range(n_features): 

378 for j in range(i + 1, n_features): 

379 if count_matrix[i, j] > 0: 

380 if feat_feat_significance[i, j] == 1: 

381 sig_pairs += count_matrix[i, j] 

382 else: 

383 nonsig_pairs += count_matrix[i, j] 

384 

385 summary['overall_stats']['significant_behavior_pairs'] = int(sig_pairs) 

386 summary['overall_stats']['nonsignificant_behavior_pairs'] = int(nonsig_pairs) 

387 summary['overall_stats']['true_mixed_selectivity_rate'] = ( 

388 nonsig_pairs / total_pairs * 100 if total_pairs > 0 else 0 

389 ) 

390 

391 return summary