Coverage for src/driada/intense/visual.py: 97.03%

236 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import numpy as np 

2 

3import matplotlib.pyplot as plt 

4from ..utils.plot import create_default_figure, make_beautiful 

5from ..utils.data import rescale 

6from scipy.stats import rankdata, gaussian_kde, wasserstein_distance 

7import seaborn as sns 

8 

9 

def plot_pc_activity(exp, cell_ind, ds=None, ax=None):
    """
    Plot place cell activity overlaid on spatial trajectory.

    The (x, y) trajectory is drawn as a scatter coloured by the log-compressed,
    rescaled calcium trace of the neuron. Frames where the neuron's spike data
    is non-zero are overlaid as black stars. The title shows the relative MI
    and p-value read from ``exp.stats_table[('x', 'y')][cell_ind]``.

    Parameters
    ----------
    exp : Experiment
        Experiment object with spatial data (``x``, ``y``), ``stats_table``
        and neurons.
    cell_ind : int
        Index of the neuron to plot.
    ds : int, optional
        Downsampling factor for the trajectory scatter. Default: 5.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates a new figure with height 6 and a
        width scaled by the x/y extent of the trajectory.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes with the plot.
    """
    pc_stats = exp.stats_table[('x', 'y')][cell_ind]
    pval = None if pc_stats['pval'] is None else np.round(pc_stats['pval'], 7)
    rel_mi_beh = None if pc_stats['rel_mi_beh'] is None else np.round(pc_stats['rel_mi_beh'], 4)

    if ds is None:
        ds = 5

    x = np.asarray(exp.x.data)
    y = np.asarray(exp.y.data)

    if ax is None:
        lenx = x.max() - x.min()
        leny = y.max() - y.min()
        # Width follows the arena's x/y aspect (height fixed at 6), so a tall
        # arena gets a tall figure instead of a wide one. A zero or undefined
        # extent falls back to a square figure rather than dividing by zero.
        xyratio = lenx / leny if lenx > 0 and leny > 0 else 1.0
        fig, ax = create_default_figure(6 * xyratio, 6)

    neuron = exp.neurons[cell_ind]
    # Log-compress the calcium trace before rescaling; the offset avoids log(0).
    neur = rescale(np.log(neuron.ca.data + 1e-10))
    spinds = np.where(neuron.sp.data != 0)[0]

    ax.scatter(x[::ds], y[::ds], c=neur[::ds], cmap='plasma', alpha=0.8)
    ax.scatter(x[spinds], y[spinds], c='k', alpha=1, marker='*', linewidth=2, s=100)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(f'Cell {cell_ind}, Rel MI={rel_mi_beh}, pval={pval}')

    return ax

55 

56 

def plot_neuron_feature_density(exp, data_type, cell_id, featname, ind1=0, ind2=100000, ds=1, shift=None, ax=None, compute_wsd=False):
    """
    Plot density distribution of neural activity conditioned on feature values.

    For a binary feature, the log10 distributions of the positive calcium
    signal in the two feature states are compared with 1D KDE curves. For a
    non-binary feature, a 2D Gaussian KDE of log10(signal) vs log(feature) is
    rendered with ``pcolormesh``.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neurons and features.
    data_type : str
        Type of neural data: 'calcium' or 'spikes'.
    cell_id : int
        Index of the neuron.
    featname : str
        Name of the behavioral feature (looked up with ``getattr(exp, featname)``).
    ind1 : int, optional
        Start frame index. Default: 0.
    ind2 : int, optional
        End frame index, clipped to ``exp.n_frames``. Default: 100000.
    ds : int, optional
        Downsampling factor. Default: 1.
    shift : int, optional
        Temporal shift (not implemented; ignored). Default: None.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    compute_wsd : bool, optional
        Whether to compute Wasserstein distance for binary features. Default: False.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes with the plot.

    Raises
    ------
    ValueError
        If ``data_type`` is not 'calcium' or 'spikes'.
    NotImplementedError
        If a binary feature is combined with spike data.
    """
    if data_type not in ('calcium', 'spikes'):
        raise ValueError(f"data_type must be 'calcium' or 'spikes', got {data_type!r}")

    feature = getattr(exp, featname)
    # Fail before any figure is created for the unsupported combination.
    if feature.is_binary and data_type == 'spikes':
        raise NotImplementedError('Binary feature density plot for spike data not yet implemented')

    ind2 = min(exp.n_frames, ind2)
    neuron = exp.neurons[cell_id]
    source = neuron.ca if data_type == 'calcium' else neuron.sp
    sig = source.scdata[ind1:ind2][::ds]
    bdata = feature.scdata[ind1:ind2][::ds]

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))

    if feature.is_binary:
        # Rank-rescale so the two feature states map to the min/max values,
        # then split the (strictly positive) signal into the two groups.
        rbdata = rescale(rankdata(bdata))
        positive = sig > 0
        vals0 = np.log10(sig[(rbdata == np.min(rbdata)) & positive])
        vals1 = np.log10(sig[(rbdata == np.max(rbdata)) & positive])

        if compute_wsd and len(vals0) > 0 and len(vals1) > 0:
            wsd = wasserstein_distance(vals0, vals1)
            title_text = f'wsd={wsd:.3f}'
        else:
            title_text = ''

        _ = sns.kdeplot(vals0, ax=ax, c='b', label=f'{featname}=0', linewidth=3, bw_adjust=0.5)
        _ = sns.kdeplot(vals1, ax=ax, c='r', label=f'{featname}=1', linewidth=3, bw_adjust=0.5)
        ax.legend(loc='upper right')
        ax.set_xlabel('log(dF/F)', fontsize=20)
        ax.set_ylabel('density', fontsize=20)
        if title_text:
            ax.set_title(title_text)

    else:
        # A tiny random jitter avoids log(0) on exact zeros and breaks exact
        # ties before the KDE. Note: signal uses log10, feature uses natural log.
        x0 = np.log10(sig + np.random.random(size=len(sig)) * 1e-8)
        y0 = np.log(bdata + np.random.random(size=len(bdata)) * 1e-8)

        nbins = 100
        kde = gaussian_kde(np.vstack([x0, y0]))
        xi, yi = np.mgrid[x0.min():x0.max():nbins * 1j, y0.min():y0.max():nbins * 1j]
        zi = kde(np.vstack([xi.flatten(), yi.flatten()]))

        ax.set_title('Density')
        ax.pcolormesh(xi, yi, zi.reshape(xi.shape), shading='auto', cmap='coolwarm')
        ax.set_xlabel('log(signals)', fontsize=20)
        ax.set_ylabel(f'log({featname})', fontsize=20)

    return ax

143 

144 

def plot_neuron_feature_pair(exp, cell_id, featname, ind1=0, ind2=100000, ds=1,
                             add_density_plot=True, ax=None, title=None):
    """
    Plot neural activity time series alongside behavioral feature.

    The scaled calcium trace is drawn in blue. For a discrete feature, frames
    where the raw feature value equals 1 are marked with red dots on the trace.
    For a continuous feature, the rank-rescaled feature is drawn in red. A
    density subplot from ``plot_neuron_feature_density`` can be added on the
    right.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neurons and features.
    cell_id : int
        Index of the neuron.
    featname : str
        Name of the behavioral feature.
    ind1 : int, optional
        Start frame index. Default: 0.
    ind2 : int, optional
        End frame index, clipped to ``exp.n_frames``. Default: 100000.
    ds : int, optional
        Downsampling factor. Default: 1.
    add_density_plot : bool, optional
        Whether to add density subplot. Default: True. Forced to False when
        ``ax`` is provided.
    ax : matplotlib.axes.Axes, optional
        Axes to plot the time series on. When given, no density subplot is drawn.
    title : str, optional
        Custom title for the plot. Default: ``'<signature> Neuron <id>, feature <name>'``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the plot(s).
    """
    ind2 = min(exp.n_frames, ind2)
    frames = np.arange(ind1, ind2)[::ds]
    ca = exp.neurons[cell_id].ca.scdata[ind1:ind2][::ds]
    feature = getattr(exp, featname)

    if ax is None:
        if add_density_plot:
            fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6), width_ratios=[0.6, 0.4])
            ax1 = make_beautiful(ax1)
        else:
            fig, ax0 = plt.subplots(figsize=(10, 6))
            ax1 = None
    else:
        # A single externally supplied axis cannot host the density subplot.
        ax0 = ax
        ax1 = None
        add_density_plot = False
        fig = ax0.figure

    ax0 = make_beautiful(ax0)

    ax0.plot(frames, ca, c='b', linewidth=2, alpha=0.5, label=f'neuron {cell_id}')
    if feature.discrete:
        # Use the raw (unscaled) feature data to find frames where it equals 1.
        active_indices = np.where(feature.data[ind1:ind2][::ds] == 1)[0]
        if len(active_indices) > 0:
            ax0.scatter(frames[active_indices], ca[active_indices],
                        c='r', s=50, alpha=0.7, zorder=10, label=f'{featname}=1')
    else:
        bdata = feature.scdata[ind1:ind2][::ds]
        rbdata = rescale(rankdata(bdata))
        ax0.plot(frames, rbdata, c='r', linewidth=2, alpha=0.5)

    if add_density_plot:
        plot_neuron_feature_density(exp, 'calcium', cell_id, featname, ind1=ind1, ind2=ind2, ds=ds, ax=ax1, compute_wsd=False)

    ax0.set_xlabel('timeframes', fontsize=20)
    ax0.set_ylabel('Signal/behavior', fontsize=20)

    # Only discrete features add labelled artists worth a legend.
    if feature.discrete:
        ax0.legend(loc='upper right')

    if title is None:
        title = f'{exp.signature} Neuron {cell_id}, feature {featname}'

    fig.suptitle(title, fontsize=20)
    # Lay out the figure we actually drew on, not whatever pyplot considers current.
    fig.tight_layout()

    return fig

228 

229 

def plot_disentanglement_heatmap(disent_matrix, count_matrix, feat_names,
                                 title=None, figsize=(12, 10), dpi=100,
                                 cmap=None, vmin=0, vmax=100,
                                 cbar_label='Disentanglement score (%)',
                                 fontsize=14, title_fontsize=18,
                                 show_grid=True, grid_alpha=0.3):
    """Plot disentanglement analysis results as a heatmap.

    Creates a heatmap showing the relative disentanglement scores between
    feature pairs. Each cell (i,j) shows the percentage of neurons where
    feature i was primary when paired with feature j.

    Parameters
    ----------
    disent_matrix : array_like
        Disentanglement matrix from disentangle_all_selectivities.
    count_matrix : array_like
        Count matrix from disentangle_all_selectivities.
    feat_names : list of str
        Feature names corresponding to matrix indices.
    title : str, optional
        Plot title. Default: 'Disentanglement Analysis'.
    figsize : tuple, optional
        Figure size (width, height). Default: (12, 10).
    dpi : int, optional
        Figure DPI. Default: 100.
    cmap : str or Colormap, optional
        Colormap to use. Default: custom red-gray-green gradient.
    vmin : float, optional
        Minimum value for colormap. Default: 0.
    vmax : float, optional
        Maximum value for colormap. Default: 100.
    cbar_label : str, optional
        Colorbar label. Default: 'Disentanglement score (%)'.
    fontsize : int, optional
        Font size for tick labels. Default: 14.
    title_fontsize : int, optional
        Font size for title. Default: 18.
    show_grid : bool, optional
        Whether to draw grid lines along the cell borders. Default: True.
    grid_alpha : float, optional
        Grid transparency. Default: 0.3.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the heatmap.
    ax : matplotlib.axes.Axes
        Axes containing the heatmap.

    Notes
    -----
    The default colormap is diverging:
    - Red indicates low disentanglement (feature is redundant)
    - Gray indicates balanced contribution (~50%)
    - Green indicates high disentanglement (feature is primary)

    Cells with a zero count are masked, so the axes background shows through.
    """
    from matplotlib.colors import LinearSegmentedColormap
    import pandas as pd

    # asarray makes `counts == 0` an element-wise mask even for list input.
    disent = np.asarray(disent_matrix, dtype=float)
    counts = np.asarray(count_matrix, dtype=float)

    # Relative disentanglement as a percentage; undefined where count is zero.
    with np.errstate(divide='ignore', invalid='ignore'):
        rel_disent_matrix = np.divide(disent, counts) * 100
    rel_disent_matrix[counts == 0] = np.nan

    if cmap is None:
        # Red -> Gray -> Green; gray (not white) at 50% keeps balanced cells
        # distinguishable from masked no-data cells.
        colors = [(1, 0, 0), (0.7, 0.7, 0.7), (0, 1, 0)]
        cmap = LinearSegmentedColormap.from_list("disentanglement_cmap", colors, N=100)

    df_heatmap = pd.DataFrame(rel_disent_matrix, columns=feat_names, index=feat_names)

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    sns.heatmap(df_heatmap,
                ax=ax,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                cbar_kws={'label': cbar_label},
                mask=np.isnan(rel_disent_matrix),
                square=True,
                linewidths=0.5,
                linecolor='gray')

    n_feats = len(feat_names)

    if show_grid:
        # Major ticks sit at cell centres (i + 0.5); drawing the grid on
        # integer-position minor ticks makes the lines follow cell borders
        # instead of cutting through every cell.
        ax.set_xticks(np.arange(n_feats + 1), minor=True)
        ax.set_yticks(np.arange(n_feats + 1), minor=True)
        ax.tick_params(which='minor', length=0)
        ax.grid(True, which='minor', linestyle='-', alpha=grid_alpha, color='black')

    if title is None:
        title = 'Disentanglement Analysis'
    ax.set_title(title, fontsize=title_fontsize, pad=20)

    centres = np.arange(n_feats) + 0.5
    ax.set_xticks(centres)
    ax.set_xticklabels(feat_names, fontsize=fontsize, rotation=45, ha='right')
    ax.set_yticks(centres)
    ax.set_yticklabels(feat_names, fontsize=fontsize, rotation=0)

    ax.set_xlabel('Feature (as secondary)', fontsize=fontsize + 2)
    ax.set_ylabel('Feature (as primary)', fontsize=fontsize + 2)

    fig.tight_layout()

    return fig, ax

346 

347 

def plot_disentanglement_summary(disent_matrix, count_matrix, feat_names,
                                 experiments=None, title_prefix='',
                                 figsize=(14, 10), dpi=100):
    """Plot comprehensive disentanglement analysis with multiple views.

    Creates a figure with three panels:
    1. Disentanglement heatmap (percentage of neurons where the row feature
       was primary against the column feature).
    2. Feature dominance scores: for each feature, the sum over partner
       features of disent/count, ignoring pairs with undefined ratios.
    3. Pairwise interaction counts for every upper-triangle pair with a
       non-zero count.

    Parameters
    ----------
    disent_matrix : ndarray or list of ndarray
        Disentanglement matrix(es). If list, matrices are summed.
    count_matrix : ndarray or list of ndarray
        Count matrix(es). If list, matrices are summed.
    feat_names : list of str
        Feature names corresponding to matrix indices.
    experiments : list of str, optional
        Experiment names if multiple matrices provided. Currently not used
        by the plot.
    title_prefix : str, optional
        Prefix for the main title.
    figsize : tuple, optional
        Figure size. Default: (14, 10).
    dpi : int, optional
        Figure DPI. Default: 100.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing all subplots.
    """
    from matplotlib.colors import LinearSegmentedColormap
    import pandas as pd

    # Aggregate multiple experiments by summing their matrices.
    if isinstance(disent_matrix, list):
        total_disent = np.sum(disent_matrix, axis=0)
        total_count = np.sum(count_matrix, axis=0)
        n_exp = len(disent_matrix)
    else:
        total_disent = np.asarray(disent_matrix)
        total_count = np.asarray(count_matrix)
        n_exp = 1

    # One division shared by the heatmap and the dominance panel.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = np.divide(total_disent, total_count)
    rel_disent_matrix = ratio * 100
    rel_disent_matrix[total_count == 0] = np.nan

    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = fig.add_gridspec(2, 2, height_ratios=[3, 1], width_ratios=[3, 1])

    # --- Main heatmap ---
    ax_main = fig.add_subplot(gs[0, 0])
    # Gray midpoint (matching plot_disentanglement_heatmap): a white midpoint
    # would make balanced (50%) cells look identical to masked no-data cells.
    colors = [(1, 0, 0), (0.7, 0.7, 0.7), (0, 1, 0)]
    cmap = LinearSegmentedColormap.from_list("disentanglement_cmap", colors, N=100)

    df_heatmap = pd.DataFrame(rel_disent_matrix, columns=feat_names, index=feat_names)
    sns.heatmap(df_heatmap, ax=ax_main, cmap=cmap, vmin=0, vmax=100,
                cbar_kws={'label': 'Disentanglement score (%)'},
                mask=np.isnan(rel_disent_matrix), square=True,
                linewidths=0.5, linecolor='gray')
    ax_main.set_title('Disentanglement Heatmap')

    # --- Feature dominance (how often each feature is primary) ---
    ax_dom = fig.add_subplot(gs[0, 1])
    dominance_scores = np.nansum(ratio, axis=1)
    y_pos = np.arange(len(feat_names))
    ax_dom.barh(y_pos, dominance_scores, color='green', alpha=0.7)
    ax_dom.set_yticks(y_pos)
    ax_dom.set_yticklabels(feat_names)
    ax_dom.set_xlabel('Dominance Score')
    ax_dom.set_title('Feature Dominance')
    ax_dom.grid(True, alpha=0.3)

    # --- Interaction counts for each unordered pair with data ---
    ax_counts = fig.add_subplot(gs[1, :])
    n_feats = len(feat_names)
    pairs = [(i, j) for i in range(n_feats) for j in range(i + 1, n_feats)
             if total_count[i, j] > 0]
    pair_counts = [total_count[i, j] for i, j in pairs]
    pair_labels = [f'{feat_names[i]}-{feat_names[j]}' for i, j in pairs]

    x_pos = np.arange(len(pair_counts))
    ax_counts.bar(x_pos, pair_counts, color='blue', alpha=0.7)
    ax_counts.set_xticks(x_pos)
    ax_counts.set_xticklabels(pair_labels, rotation=45, ha='right')
    ax_counts.set_ylabel('Number of neurons')
    ax_counts.set_title('Pairwise interaction counts')
    ax_counts.grid(True, axis='y', alpha=0.3)

    if n_exp > 1:
        title = f'{title_prefix}Disentanglement Analysis ({n_exp} experiments)'
    else:
        title = f'{title_prefix}Disentanglement Analysis'
    fig.suptitle(title, fontsize=16, y=0.98)

    fig.tight_layout()
    return fig

456 

457 

def plot_selectivity_heatmap(exp, significant_neurons,
                             metric='mi', cmap='viridis', use_log_scale=False,
                             vmin=None, vmax=None, figsize=(10, 8),
                             significance_threshold=None, ax=None):
    """Create a heatmap showing metric values for selective neuron-feature pairs.

    Rows are neurons ``0 .. exp.n_cells - 1`` and columns are the string-named
    dynamic features of ``exp`` in sorted order (non-string feature keys are
    skipped). A cell holds the ``'me'`` value from
    ``exp.get_neuron_feature_pair_stats(..., mode='calcium')`` for each pair
    listed in ``significant_neurons``, and 0 elsewhere.

    Parameters
    ----------
    exp : Experiment
        The experiment object containing all data and results
    significant_neurons : dict
        Dictionary mapping neuron IDs to lists of significant features
    metric : str, optional
        Which metric to display ('mi' for mutual information, 'corr' for correlation).
        Only affects labels. Default: 'mi'
    cmap : str, optional
        Colormap to use. Default: 'viridis'
    use_log_scale : bool, optional
        Whether to use log10 scale for metric values; non-positive entries are
        masked. Default: False
    vmin : float, optional
        Minimum value for colormap. If None, auto-determined from data
    vmax : float, optional
        Maximum value for colormap. If None, auto-determined from data
    figsize : tuple, optional
        Figure size (ignored if ax provided). Default: (10, 8)
    significance_threshold : float, optional
        If provided, only show pairs with p-value below this threshold
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the heatmap
    ax : matplotlib.axes.Axes
        Axes containing the heatmap
    stats : dict
        Dictionary containing statistics about the data:
        - n_selective: number of entries in ``significant_neurons``
        - n_pairs: total number of pairs listed in ``significant_neurons``
        - selectivity_rate: percentage of selective neurons
        - metric_values: list of all plotted metric values
        - sparsity: percentage of matrix entries holding no plotted pair
    """
    all_features = sorted(f for f in exp.dynamic_features.keys() if isinstance(f, str))
    n_neurons = exp.n_cells
    n_features = len(all_features)

    # 0 marks a non-selective (or filtered-out) pair.
    selectivity_matrix = np.zeros((n_neurons, n_features))
    all_metric_values = []

    for cell_id in range(n_neurons):
        if cell_id not in significant_neurons:
            continue
        sig_feats = significant_neurons[cell_id]
        for feat_idx, feat_name in enumerate(all_features):
            if feat_name not in sig_feats:
                continue
            pair_stats = exp.get_neuron_feature_pair_stats(cell_id, feat_name, mode='calcium')

            if significance_threshold is not None:
                pval = pair_stats.get('pval', None)
                # Skip if pval is None (failed stage 1) or above threshold
                if pval is None or pval > significance_threshold:
                    continue

            # 'me' holds the value of whichever metric was used for the analysis.
            value = pair_stats.get('me', 0)
            selectivity_matrix[cell_id, feat_idx] = value
            all_metric_values.append(value)

    if use_log_scale and all_metric_values:
        epsilon = 1e-10
        with np.errstate(divide='ignore', invalid='ignore'):
            selectivity_matrix = np.log10(selectivity_matrix + epsilon)
        # Former zeros (and non-positive values) become NaN so they are masked.
        selectivity_matrix[selectivity_matrix < np.log10(epsilon * 2)] = np.nan

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Data-driven colour limits; explicit vmin/vmax from the caller always win.
    if not all_metric_values:
        auto_vmin, auto_vmax = 0, 1
    elif use_log_scale:
        positive = [v for v in all_metric_values if v > 0]
        if positive:
            auto_vmin, auto_vmax = np.log10(min(positive)), np.log10(max(positive))
        else:
            auto_vmin, auto_vmax = 0, 1
    else:
        auto_vmin, auto_vmax = 0, max(all_metric_values)
    if vmin is None:
        vmin = auto_vmin
    if vmax is None:
        vmax = auto_vmax

    masked_matrix = np.ma.masked_invalid(selectivity_matrix)
    im = ax.imshow(masked_matrix, cmap=cmap, aspect='auto', interpolation='nearest',
                   vmin=vmin, vmax=vmax)

    ax.set_xticks(range(n_features))
    ax.set_xticklabels(all_features, rotation=45, ha='right')
    label_step = max(1, n_neurons // 20)  # show ~20 neuron labels
    ax.set_yticks(range(0, n_neurons, label_step))
    ax.set_yticklabels(range(0, n_neurons, label_step))

    ax.set_xlabel('Features', fontsize=12)
    ax.set_ylabel('Neurons', fontsize=12)
    metric_name = 'Mutual Information' if metric == 'mi' else 'Correlation'
    scale_text = ' (log₁₀)' if use_log_scale else ''
    ax.set_title(f'Neuronal Selectivity: {metric_name}{scale_text}', fontsize=14, fontweight='bold')

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label(f'{metric_name}{scale_text}', rotation=270, labelpad=20)

    n_selective = len(significant_neurons)
    n_pairs = sum(len(features) for features in significant_neurons.values())
    selectivity_rate = (n_selective / n_neurons) * 100 if n_neurons else 0.0
    # Sparsity is measured on the matrix itself, so pairs that were filtered
    # out or have no column (non-string features) do not distort it.
    n_entries = n_neurons * n_features
    sparsity = (1 - len(all_metric_values) / n_entries) * 100 if n_entries else 100.0

    summary_lines = [
        f'Selective neurons: {n_selective}/{exp.n_cells} ({selectivity_rate:.1f}%)',
        f'Total selective pairs: {n_pairs}'
    ]
    if all_metric_values:
        summary_lines.extend([
            f'{metric.upper()} range: [{min(all_metric_values):.3f}, {max(all_metric_values):.3f}]',
            f'Mean {metric.upper()}: {np.mean(all_metric_values):.3f}'
        ])

    summary_text = '\n'.join(summary_lines)
    # Lower-right corner of the figure, away from the colorbar.
    fig.text(0.98, 0.02, summary_text, transform=fig.transFigure,
             fontsize=10, verticalalignment='bottom', horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Minor ticks at cell borders give a readability grid between cells.
    ax.set_xticks(np.arange(n_features) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_neurons) - 0.5, minor=True)
    ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    fig.tight_layout()

    stats = {
        'n_selective': n_selective,
        'n_pairs': n_pairs,
        'selectivity_rate': selectivity_rate,
        'metric_values': all_metric_values,
        'sparsity': sparsity
    }

    return fig, ax, stats