Coverage for src/driada/utils/visual.py: 86.32%

285 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2Visualization utilities for DRIADA 

3================================== 

4 

5This module provides reusable visualization functions for embedding comparisons, 

6trajectory plots, and component interpretation in dimensionality reduction analyses. 

7""" 

8 

import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import gaussian_kde

16 

# Resolution (dots per inch) used when a plotting function in this module
# saves its figure via ``save_path`` and the caller passes no explicit ``dpi``.
DEFAULT_DPI: int = 150

19 

20 

def plot_embedding_comparison(
    embeddings: Dict[str, np.ndarray],
    features: Optional[Dict[str, np.ndarray]] = None,
    feature_names: Optional[Dict[str, str]] = None,
    methods: Optional[List[str]] = None,
    with_trajectory: bool = True,
    compute_metrics: bool = True,
    trajectory_kwargs: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Create comprehensive embedding comparison figure with behavioral features and trajectories.

    One column is drawn per method, using only the first two embedding
    components. Row 1 is coloured by ``features['angle']``, row 2 by
    ``features['speed']``, and the optional row 3 shows the sample-ordered
    trajectory with direction arrows and start/end markers.

    Parameters
    ----------
    embeddings : dict
        Dictionary mapping method names to embedding arrays
        (n_samples, n_components). At least two components are required.
    features : dict, optional
        Dictionary mapping feature names to per-sample feature arrays.
        The keys used are 'angle' (radians; wrapped onto one period so the
        cyclic colormap is valid for both [-pi, pi) and [0, 2*pi) inputs)
        and 'speed'. A row whose feature is missing is drawn uncoloured.
    feature_names : dict, optional
        Dictionary mapping feature keys to display names
        (default: {'angle': 'Head Direction', 'speed': 'Speed'}).
    methods : list of str, optional
        Methods to plot, in column order (default: all keys in embeddings).
        Names without an entry in ``embeddings`` are skipped and do not
        leave an empty column.
    with_trajectory : bool, default True
        Whether to include trajectory visualization as a third row.
    compute_metrics : bool, default True
        Whether to add density contours (row 1) and speed quartile markers
        on the colorbar (row 2).
    trajectory_kwargs : dict, optional
        Overrides for the trajectory style. Recognised keys: 'linewidth',
        'alpha', 'color', 'arrow_spacing' (approximate number of arrows),
        'arrow_scale', 'start_marker', 'end_marker', 'marker_size'.
    figsize : tuple, optional
        Figure size (width, height). If None, ``(6 * n_methods, 5 * n_rows)``.
    save_path : str, optional
        Path to save the figure.
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure.

    Raises
    ------
    ValueError
        If none of the requested methods has an entry in ``embeddings``.
    """
    if methods is None:
        methods = list(embeddings.keys())
    # Drop unknown methods up front so the grid has no empty columns.
    methods = [m for m in methods if m in embeddings]
    if not methods:
        raise ValueError("No embeddings found for the requested methods")

    n_methods = len(methods)
    n_rows = 3 if with_trajectory else 2

    if figsize is None:
        figsize = (6 * n_methods, 5 * n_rows)

    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(n_rows, n_methods, hspace=0.3, wspace=0.3)

    if feature_names is None:
        feature_names = {
            'angle': 'Head Direction',
            'speed': 'Speed'
        }

    # Defaults first, caller overrides on top (caller's dict is not mutated).
    traj_style = {
        'linewidth': 0.8,
        'alpha': 0.3,
        'color': 'k',
        'arrow_spacing': 20,
        'arrow_scale': 0.3,
        'start_marker': 'o',
        'end_marker': 's',
        'marker_size': 100
    }
    traj_style.update(trajectory_kwargs or {})

    for i, method in enumerate(methods):
        embedding = np.asarray(embeddings[method])
        x, y = embedding[:, 0], embedding[:, 1]

        # Row 1: coloured by angle / position
        ax1 = fig.add_subplot(gs[0, i])

        if features is not None and 'angle' in features:
            angle = np.asarray(features['angle'])
            # Map onto [0, 1): -pi -> 0, 0 -> 0.5. Wrapping (rather than a
            # plain shift) keeps angles supplied in [0, 2*pi) within
            # vmin/vmax; 'hsv' is cyclic, so 0 and 1 share the same colour.
            angle_norm = np.mod(angle + np.pi, 2 * np.pi) / (2 * np.pi)

            scatter = ax1.scatter(
                x, y,
                c=angle_norm, cmap='hsv', alpha=0.7, s=2,
                vmin=0, vmax=1, edgecolors='none'
            )

            cbar = fig.colorbar(scatter, ax=ax1, label=feature_names.get('angle', 'Angle'))
            # Label the colorbar with the angle values, not the [0, 1] scale
            cbar.set_ticks([0, 0.25, 0.5, 0.75, 1])
            cbar.set_ticklabels(['-π', '-π/2', '0', 'π/2', 'π'])

            if compute_metrics:
                try:
                    kde = gaussian_kde(embedding[:, :2].T)
                    x_min, x_max = ax1.get_xlim()
                    y_min, y_max = ax1.get_ylim()
                    X, Y = np.mgrid[x_min:x_max:50j, y_min:y_max:50j]
                    positions_grid = np.vstack([X.ravel(), Y.ravel()])
                    Z = np.reshape(kde(positions_grid).T, X.shape)
                    ax1.contour(X, Y, Z, colors='gray', alpha=0.3, linewidths=0.5)
                except (np.linalg.LinAlgError, ValueError):
                    # KDE cannot be fitted (e.g. too few or degenerate
                    # points): contours are optional, so skip them.
                    pass
        else:
            ax1.scatter(x, y, alpha=0.6, s=1)

        ax1.set_xlabel('Component 0')
        ax1.set_ylabel('Component 1')
        ax1.set_title(f'{method.upper()} - {feature_names.get("angle", "Position")}')
        ax1.grid(True, alpha=0.3)

        # Row 2: coloured by speed
        ax2 = fig.add_subplot(gs[1, i])

        if features is not None and 'speed' in features:
            speed = np.asarray(features['speed'])

            scatter = ax2.scatter(
                x, y,
                c=speed, cmap='viridis', alpha=0.7, s=2,
                edgecolors='none'
            )

            cbar = fig.colorbar(scatter, ax=ax2, label=feature_names.get('speed', 'Speed'))

            if compute_metrics:
                # Mark the speed quartiles on the colorbar
                quartiles = (25, 50, 75)
                for p, val in zip(quartiles, np.percentile(speed, quartiles)):
                    cbar.ax.axhline(y=val, color='red', alpha=0.3, linewidth=0.5)
                    cbar.ax.text(
                        1.05, val, f'{p}%',
                        transform=cbar.ax.get_yaxis_transform(),
                        fontsize=8, va='center'
                    )
        else:
            ax2.scatter(x, y, alpha=0.6, s=1)

        ax2.set_xlabel('Component 0')
        ax2.set_ylabel('Component 1')
        ax2.set_title(f'{method.upper()} - {feature_names.get("speed", "Feature 2")}')
        ax2.grid(True, alpha=0.3)

        # Row 3: trajectory in sample order
        if with_trajectory:
            ax3 = fig.add_subplot(gs[2, i])

            ax3.plot(
                x, y,
                color=traj_style['color'],
                alpha=traj_style['alpha'],
                linewidth=traj_style['linewidth']
            )

            n_samples = len(embedding)
            step = max(1, n_samples // traj_style['arrow_spacing'])
            scale = traj_style['arrow_scale']

            # j < n_samples - step guarantees j + 1 is a valid index.
            for j in range(0, n_samples - step, step):
                dx = x[j + 1] - x[j]
                dy = y[j + 1] - y[j]

                # Only plot arrow if movement is significant
                if np.hypot(dx, dy) > 0.001:
                    ax3.arrow(
                        x[j], y[j],
                        dx * scale, dy * scale,
                        head_width=0.02, head_length=0.02,
                        fc='red', ec='red', alpha=0.6
                    )

            if n_samples:
                ax3.scatter(
                    x[0], y[0],
                    c='green', s=traj_style['marker_size'],
                    marker=traj_style['start_marker'],
                    edgecolors='black', linewidth=2, label='Start', zorder=5
                )
                ax3.scatter(
                    x[-1], y[-1],
                    c='red', s=traj_style['marker_size'],
                    marker=traj_style['end_marker'],
                    edgecolors='black', linewidth=2, label='End', zorder=5
                )

            ax3.set_xlabel('Component 0')
            ax3.set_ylabel('Component 1')
            ax3.set_title(f'{method.upper()} - Trajectory')
            ax3.grid(True, alpha=0.3)
            ax3.legend(loc='best', fontsize=8)
            ax3.set_aspect('equal', adjustable='datalim')

    title = 'Population Embeddings: Behavioral Features'
    if with_trajectory:
        title += ' and Trajectories'
    fig.suptitle(title, fontsize=16)

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig

242 

243 

def plot_trajectories(
    embeddings: Dict[str, np.ndarray],
    methods: Optional[List[str]] = None,
    trajectory_kwargs: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Create figure showing trajectories in embedding space for multiple methods.

    Each method gets one panel in a single row. The first two components are
    drawn as a line in sample order, with red direction arrows, a green
    start marker and a red end marker.

    Parameters
    ----------
    embeddings : dict
        Dictionary mapping method names to embedding arrays
        (n_samples, n_components), with at least two components.
    methods : list of str, optional
        Methods to plot, in panel order (default: all keys in embeddings).
        Names without an entry in ``embeddings`` are skipped and do not
        leave an empty panel.
    trajectory_kwargs : dict, optional
        Style overrides. Recognised keys: 'linewidth', 'alpha', 'color',
        'arrow_spacing' (approximate number of arrows), 'arrow_scale',
        'start_marker', 'end_marker', 'marker_size'.
    figsize : tuple, optional
        Figure size (default: ``(6 * n_methods, 5)``).
    save_path : str, optional
        Path to save the figure.
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure.

    Raises
    ------
    ValueError
        If none of the requested methods has an entry in ``embeddings``.
    """
    if methods is None:
        methods = list(embeddings.keys())
    # Drop unknown methods so the panel row has no gaps.
    methods = [m for m in methods if m in embeddings]
    if not methods:
        raise ValueError("No embeddings found for the requested methods")

    n_methods = len(methods)

    if figsize is None:
        figsize = (6 * n_methods, 5)

    fig = plt.figure(figsize=figsize)

    # Defaults first, caller overrides on top (caller's dict is not mutated).
    style = {
        'linewidth': 0.8,
        'alpha': 0.3,
        'color': 'k',
        'arrow_spacing': 20,
        'arrow_scale': 0.3,
        'start_marker': 'o',
        'end_marker': 's',
        'marker_size': 100
    }
    style.update(trajectory_kwargs or {})

    for i, method in enumerate(methods):
        embedding = np.asarray(embeddings[method])
        x, y = embedding[:, 0], embedding[:, 1]
        ax = fig.add_subplot(1, n_methods, i + 1)

        ax.plot(
            x, y,
            color=style['color'],
            alpha=style['alpha'],
            linewidth=style['linewidth']
        )

        # Direction arrows, roughly style['arrow_spacing'] of them
        n_samples = len(embedding)
        step = max(1, n_samples // style['arrow_spacing'])
        scale = style['arrow_scale']

        # j < n_samples - step guarantees j + 1 is a valid index.
        for j in range(0, n_samples - step, step):
            dx = x[j + 1] - x[j]
            dy = y[j + 1] - y[j]

            # Skip negligible steps
            if np.hypot(dx, dy) > 0.001:
                ax.arrow(
                    x[j], y[j],
                    dx * scale, dy * scale,
                    head_width=0.02, head_length=0.02,
                    fc='red', ec='red', alpha=0.6
                )

        if n_samples:
            ax.scatter(
                x[0], y[0],
                c='green', s=style['marker_size'],
                marker=style['start_marker'],
                edgecolors='black', linewidth=2, label='Start', zorder=5
            )
            ax.scatter(
                x[-1], y[-1],
                c='red', s=style['marker_size'],
                marker=style['end_marker'],
                edgecolors='black', linewidth=2, label='End', zorder=5
            )

        ax.set_xlabel('Component 0')
        ax.set_ylabel('Component 1')
        ax.set_title(f'{method.upper()} - Trajectory')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='best', fontsize=8)
        ax.set_aspect('equal', adjustable='datalim')

    fig.suptitle('Temporal Trajectories in Embedding Space', fontsize=16)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig

362 

363 

def plot_component_interpretation(
    mi_matrices: Dict[str, np.ndarray],
    feature_names: List[str],
    methods: Optional[List[str]] = None,
    n_components: Optional[int] = None,
    metadata: Optional[Dict[str, Dict]] = None,
    compute_metrics: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Create figure showing mutual information between embedding components and features.

    One annotated heatmap per method (features on the y-axis, components on
    the x-axis), each with its own colour scale starting at 0.

    Parameters
    ----------
    mi_matrices : dict
        Dictionary mapping method names to MI matrices (n_features, n_components).
    feature_names : list of str
        Names of features for y-axis labels; must have one entry per MI
        matrix row.
    methods : list of str, optional
        Methods to plot, in panel order (default: all keys in mi_matrices).
        Names without an entry in ``mi_matrices`` are skipped and do not
        leave an empty panel.
    n_components : int, optional
        Number of components to show (default: min(5, available)).
    metadata : dict, optional
        Dictionary of metadata for each method. For a method named 'pca'
        (case-insensitive), an 'explained_variance_ratio' sequence is shown
        under the heatmap.
    compute_metrics : bool, default True
        Whether to show additional metrics (e.g., explained variance).
    figsize : tuple, optional
        Figure size (default: ``(8 * n_methods, 6)``).
    save_path : str, optional
        Path to save the figure.
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure.

    Raises
    ------
    ValueError
        If no requested method has an MI matrix, or if ``feature_names``
        does not match the number of rows of an MI matrix.
    """
    if methods is None:
        methods = list(mi_matrices.keys())
    # Drop unknown methods so the panel row has no gaps.
    methods = [m for m in methods if m in mi_matrices]
    if not methods:
        raise ValueError("No MI matrices found for the requested methods")

    n_methods = len(methods)

    if figsize is None:
        figsize = (8 * n_methods, 6)

    fig = plt.figure(figsize=figsize)

    for idx, method in enumerate(methods):
        mi_matrix = np.asarray(mi_matrices[method])

        n_features = mi_matrix.shape[0]
        if n_features != len(feature_names):
            raise ValueError(
                f"MI matrix for '{method}' has {n_features} rows but "
                f"{len(feature_names)} feature names were given"
            )

        # Number of components to show
        n_comp_available = mi_matrix.shape[1]
        n_requested = 5 if n_components is None else n_components
        n_comp_show = min(n_requested, n_comp_available)

        ax = fig.add_subplot(1, n_methods, idx + 1)

        mi_subset = mi_matrix[:, :n_comp_show]
        peak = np.max(mi_subset) if mi_subset.size else 0
        # Avoid a degenerate [0, 0] colour range when all MI values are <= 0
        max_mi = peak if peak > 0 else 1

        im = ax.imshow(
            mi_subset, aspect='auto', cmap='YlOrRd',
            vmin=0, vmax=max_mi
        )

        # PCA components are called PC<i>; other methods use the upper-cased name
        prefix = 'PC' if method.lower() == 'pca' else method.upper()
        ax.set_xticks(range(n_comp_show))
        ax.set_xticklabels([f'{prefix}{i}' for i in range(n_comp_show)])
        ax.set_yticks(range(n_features))
        ax.set_yticklabels(feature_names)
        ax.set_xlabel(f'{method.upper()} Components')
        ax.set_title(f'{method.upper()} Component-Feature MI')

        # Annotate each cell; switch to white text on dark (high-MI) cells
        for i in range(n_features):
            for j in range(n_comp_show):
                text_color = "black" if mi_subset[i, j] < max_mi * 0.5 else "white"
                ax.text(
                    j, i, f'{mi_subset[i, j]:.3f}',
                    ha="center", va="center",
                    color=text_color, fontsize=9
                )

        fig.colorbar(im, ax=ax, label='Mean MI (bits)')

        # Method-specific metrics, if supplied
        if compute_metrics and metadata is not None and method in metadata:
            method_meta = metadata[method]

            if method.lower() == 'pca' and 'explained_variance_ratio' in method_meta:
                var_exp = method_meta['explained_variance_ratio'][:n_comp_show]
                var_text = 'Var explained: ' + ', '.join(f'{v*100:.1f}%' for v in var_exp)
                ax.text(
                    0.5, -0.15, var_text,
                    transform=ax.transAxes,
                    ha='center', va='top',
                    fontsize=8, style='italic'
                )

    fig.suptitle('Component Interpretation: Mutual Information between Components and Features', fontsize=16)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig

493 

494 

def plot_embeddings_grid(
    embeddings: Dict[str, Dict[str, np.ndarray]],
    labels: Optional[Union[np.ndarray, Dict[str, Dict[str, np.ndarray]]]] = None,
    methods: Optional[List[str]] = None,
    scenarios: Optional[List[str]] = None,
    metrics: Optional[Dict[str, Dict[str, Dict[str, float]]]] = None,
    colormap: str = 'viridis',
    figsize: Optional[Tuple[float, float]] = None,
    n_cols: int = 4,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> Optional[plt.Figure]:
    """
    Create grid of embeddings for multiple methods and scenarios.

    Each (method, scenario) pair with a non-None embedding gets one scatter
    panel of its first two components, filled row by row.

    Parameters
    ----------
    embeddings : dict of dict
        Nested dictionary: {method: {scenario: embedding_array}}.
    labels : array or dict, optional
        Color labels for points. Either one array used for every panel, or a
        nested dict {method: {scenario: array}}. Panels without labels are
        coloured by sample index.
    methods : list, optional
        Methods to plot (default: all in embeddings).
    scenarios : list, optional
        Scenarios to plot (default: all available per method).
    metrics : dict, optional
        Nested dict of metrics: {method: {scenario: {metric_name: value}}}.
        Up to two floating-point metrics (Python or NumPy floats) are shown
        in each panel title.
    colormap : str
        Colormap for scatter plots.
    figsize : tuple, optional
        Figure size (default: ``(4 * n_cols, 4 * n_rows)``).
    n_cols : int
        Number of columns in grid (must be >= 1).
    save_path : str, optional
        Path to save figure.
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The generated figure, or None (with a warning) if there is nothing
        to plot.

    Raises
    ------
    ValueError
        If ``n_cols`` is smaller than 1.
    """
    if methods is None:
        methods = list(embeddings.keys())

    # Collect all (method, scenario) pairs that have an embedding
    all_plots = []
    for method in methods:
        if method not in embeddings:
            continue
        per_method = embeddings[method]
        if scenarios is None:
            method_scenarios = list(per_method.keys())
        else:
            method_scenarios = [s for s in scenarios if s in per_method]
        all_plots.extend(
            (method, scenario)
            for scenario in method_scenarios
            if per_method[scenario] is not None
        )

    if not all_plots:
        warnings.warn("No valid embeddings to plot")
        return None

    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    n_plots = len(all_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols

    if figsize is None:
        figsize = (4 * n_cols, 4 * n_rows)

    # squeeze=False always yields a 2-D array of Axes, including the 1x1 case
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()

    for ax, (method, scenario) in zip(flat_axes, all_plots):
        embedding = embeddings[method][scenario]

        # Choose colour labels for this panel
        if labels is None:
            color_labels = np.arange(len(embedding))
        elif isinstance(labels, dict):
            if method in labels and scenario in labels[method]:
                color_labels = labels[method][scenario]
            else:
                color_labels = np.arange(len(embedding))
        else:
            color_labels = labels

        ax.scatter(
            embedding[:, 0], embedding[:, 1],
            c=color_labels, cmap=colormap,
            s=10, alpha=0.7, edgecolors='none'
        )

        title = f'{method} - {scenario}'
        if metrics and method in metrics and scenario in metrics[method]:
            # np.floating covers float32 etc., which are not Python floats
            metric_strs = [
                f'{metric_name}: {value:.3f}'
                for metric_name, value in metrics[method][scenario].items()
                if isinstance(value, (float, np.floating))
            ]
            if metric_strs:
                title += '\n' + ', '.join(metric_strs[:2])  # Show max 2 metrics

        ax.set_title(title, fontsize=10)
        ax.set_xlabel('Component 0')
        ax.set_ylabel('Component 1')
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for ax in flat_axes[n_plots:]:
        ax.set_visible(False)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig

624 

625 

# Note: plot_quality_metrics_comparison and plot_quality_vs_speed_tradeoff were
# removed because they were too specific to individual examples and were not
# reused elsewhere.

628 

629 

def plot_neuron_selectivity_summary(
    selectivity_counts: Dict[str, int],
    total_neurons: int,
    colors: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (8, 6),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Create bar plot summarizing neuron selectivity categories.

    One bar per category (in dictionary order), annotated with its share of
    ``total_neurons`` in percent.

    Parameters
    ----------
    selectivity_counts : dict
        Dictionary mapping category names to counts.
    total_neurons : int
        Total number of neurons (denominator for the percentages; must be > 0).
    colors : dict, optional
        Dictionary mapping category names to colors. Categories without an
        entry are drawn in 'steelblue'.
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure.
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure.

    Returns
    -------
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``total_neurons`` is not positive.
    """
    if total_neurons <= 0:
        raise ValueError(f"total_neurons must be positive, got {total_neurons}")

    if colors is None:
        # Default colors for common categories
        colors = {
            'Spatial': 'darkgreen',
            'spatial': 'darkgreen',
            'position_2d': 'darkgreen',
            'x_position': 'green',
            'y_position': 'lightgreen',
            'head_direction': 'blue',
            'speed': 'orange',
            'task_type': 'red',
            'reward': 'purple',
            'Non-spatial': 'gray',
            'non_spatial': 'gray',
            'Non-selective': 'lightgray'
        }

    fig, ax = plt.subplots(figsize=figsize)

    categories = list(selectivity_counts.keys())
    counts = list(selectivity_counts.values())
    bar_colors = [colors.get(cat, 'steelblue') for cat in categories]

    bars = ax.bar(categories, counts, color=bar_colors, alpha=0.7)

    # Percentage labels, offset in points (not data units) so they stay
    # just above each bar whatever the count scale.
    for bar, count in zip(bars, counts):
        percentage = count / total_neurons * 100
        ax.annotate(
            f'{percentage:.1f}%',
            xy=(bar.get_x() + bar.get_width() / 2., bar.get_height()),
            xytext=(0, 3), textcoords='offset points',
            ha='center', va='bottom'
        )

    ax.set_ylabel('Number of neurons')
    ax.set_title('Neuron Selectivity Categories')
    # 15% headroom for the labels; at least 1 so empty/all-zero input still
    # yields a valid (non-degenerate) y-range.
    ax.set_ylim(0, max(max(counts, default=0), 1) * 1.15)

    ax.text(0.02, 0.98, f'Total neurons: {total_neurons}',
            transform=ax.transAxes, ha='left', va='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig

710 

711 

def plot_component_selectivity_heatmap(
    selectivity_matrix: np.ndarray,
    methods: List[str],
    n_components_per_method: Optional[Dict[str, int]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Create heatmap showing neuron selectivity to embedding components.

    The columns of ``selectivity_matrix`` are consumed left to right: the
    first ``n_components_per_method[methods[0]]`` columns belong to the first
    method, the next block to the second, and so on. Each method gets one
    panel with components on the y-axis and neurons on the x-axis.

    Parameters
    ----------
    selectivity_matrix : ndarray
        Matrix of shape (n_neurons, total_components) with MI values.
    methods : list of str
        List of DR method names (at least one), in column-block order.
    n_components_per_method : dict, optional
        Number of components for each method. If None, columns are split
        equally; any remainder columns are ignored with a warning.
    figsize : tuple, optional
        Figure size (default: ``(5 * len(methods), 8)``).
    save_path : str, optional
        Path to save figure.
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure.

    Returns
    -------
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``methods`` is empty, if there are fewer columns than methods in
        the equal-split case, or if the requested components exceed the
        number of columns.
    """
    n_methods = len(methods)
    if n_methods == 0:
        raise ValueError("methods must contain at least one method name")

    _, total_components = selectivity_matrix.shape

    if n_components_per_method is None:
        # Equal split across methods
        n_comp_each, remainder = divmod(total_components, n_methods)
        if n_comp_each == 0:
            raise ValueError(
                f"selectivity_matrix has {total_components} columns, fewer "
                f"than the {n_methods} methods"
            )
        if remainder:
            warnings.warn(
                f"{remainder} trailing column(s) of selectivity_matrix are not "
                f"assigned to any method and will not be plotted"
            )
        n_components_per_method = {m: n_comp_each for m in methods}

    n_required = sum(n_components_per_method[m] for m in methods)
    if n_required > total_components:
        raise ValueError(
            f"n_components_per_method requires {n_required} columns but "
            f"selectivity_matrix has only {total_components}"
        )

    if figsize is None:
        figsize = (5 * n_methods, 8)

    # squeeze=False -> always a (1, n_methods) array, even for one method
    fig, axes = plt.subplots(1, n_methods, figsize=figsize, squeeze=False)

    comp_start = 0
    for ax, method in zip(axes[0], methods):
        n_comp = n_components_per_method[method]

        # This method's block of columns
        method_matrix = selectivity_matrix[:, comp_start:comp_start + n_comp]

        im = ax.imshow(method_matrix.T, aspect='auto', cmap='hot',
                       interpolation='nearest')

        ax.set_xlabel('Neuron ID')
        ax.set_ylabel('Component')
        ax.set_title(f'{method.upper()} Component Selectivity')

        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('Mutual Information (bits)')

        ax.set_yticks(range(n_comp))
        ax.set_yticklabels([f'C{i}' for i in range(n_comp)])

        comp_start += n_comp

    fig.suptitle('Neuron Selectivity to Embedding Components', fontsize=14)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig