Coverage for src/driada/utils/visual.py: 86.32%
285 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1"""
Visualization utilities for DRIADA
==================================

This module provides reusable visualization functions for embedding comparisons,
trajectory plots, and component interpretation in dimensionality reduction analyses.
7"""
9import numpy as np
10import matplotlib.pyplot as plt
11import matplotlib.gridspec as gridspec
12import seaborn as sns
13from typing import Dict, List, Tuple, Optional, Union, Any, Callable
14from scipy.stats import gaussian_kde
15import pandas as pd
# Default resolution (dots per inch) used as the `dpi` default by the plotting
# functions in this module when they save a figure.
DEFAULT_DPI = 150
def _add_density_contours(ax: plt.Axes, embedding: np.ndarray) -> None:
    """Overlay faint KDE density contours of the first two components on *ax*.

    The density is evaluated on a 50x50 grid spanning the current axis limits,
    so this must be called after the scatter plot has been drawn.  The contours
    are a best-effort decoration: if the density estimate cannot be computed
    they are skipped and the panel is left as it is.
    """
    try:
        kde = gaussian_kde(embedding[:, :2].T)
        x_min, x_max = ax.get_xlim()
        y_min, y_max = ax.get_ylim()
        X, Y = np.mgrid[x_min:x_max:50j, y_min:y_max:50j]
        grid_points = np.vstack([X.ravel(), Y.ravel()])
        Z = np.reshape(kde(grid_points).T, X.shape)
        ax.contour(X, Y, Z, colors='gray', alpha=0.3, linewidths=0.5)
    except (np.linalg.LinAlgError, ValueError):
        # Degenerate input (e.g. singular covariance, too few or non-finite
        # points).  Only these failures are ignored; anything else, including
        # KeyboardInterrupt, now propagates instead of being silently eaten.
        pass


def _draw_embedding_trajectory(ax: plt.Axes, embedding: np.ndarray, style: Dict) -> None:
    """Draw the sample-to-sample path of *embedding* on *ax* with direction cues.

    *style* must provide the keys 'color', 'alpha', 'linewidth',
    'arrow_spacing', 'arrow_scale', 'start_marker', 'end_marker' and
    'marker_size'.
    """
    ax.plot(
        embedding[:, 0], embedding[:, 1],
        color=style['color'],
        alpha=style['alpha'],
        linewidth=style['linewidth']
    )

    # 'arrow_spacing' acts as a target arrow count: the step between arrows is
    # the trajectory length divided by it (at least one sample).
    n_samples = len(embedding)
    step = max(1, n_samples // style['arrow_spacing'])

    # j never exceeds n_samples - step - 1, so embedding[j + 1] always exists.
    for j in range(0, n_samples - step, step):
        dx = embedding[j + 1, 0] - embedding[j, 0]
        dy = embedding[j + 1, 1] - embedding[j, 1]

        # Only plot an arrow if the movement is significant
        if np.sqrt(dx**2 + dy**2) > 0.001:
            ax.arrow(
                embedding[j, 0], embedding[j, 1],
                dx * style['arrow_scale'],
                dy * style['arrow_scale'],
                head_width=0.02, head_length=0.02,
                fc='red', ec='red', alpha=0.6
            )

    # Mark start and end points above the path (zorder=5)
    ax.scatter(
        embedding[0, 0], embedding[0, 1],
        c='green', s=style['marker_size'],
        marker=style['start_marker'],
        edgecolors='black', linewidth=2, label='Start', zorder=5
    )
    ax.scatter(
        embedding[-1, 0], embedding[-1, 1],
        c='red', s=style['marker_size'],
        marker=style['end_marker'],
        edgecolors='black', linewidth=2, label='End', zorder=5
    )


def plot_embedding_comparison(
    embeddings: Dict[str, np.ndarray],
    features: Optional[Dict[str, np.ndarray]] = None,
    feature_names: Optional[Dict[str, str]] = None,
    methods: Optional[List[str]] = None,
    with_trajectory: bool = True,
    compute_metrics: bool = True,
    trajectory_kwargs: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Create comprehensive embedding comparison figure with behavioral features and trajectories.

    One column is drawn per method.  Row 1 colours the first two embedding
    components by ``features['angle']``, row 2 by ``features['speed']``, and the
    optional row 3 shows the temporal trajectory.  When a feature is missing the
    corresponding panel falls back to an uncoloured scatter plot.

    Parameters
    ----------
    embeddings : dict
        Dictionary mapping method names to embedding arrays (n_samples, n_components).
        Only the first two components are plotted.
    features : dict, optional
        Dictionary mapping feature names to feature arrays.
        Only the keys 'angle' and 'speed' are used.  The colour mapping and
        colourbar tick labels for 'angle' assume radians in [-pi, pi].
    feature_names : dict, optional
        Dictionary mapping feature keys to display names
        (default: 'angle' -> 'Head Direction', 'speed' -> 'Speed')
    methods : list of str, optional
        List of methods to plot (if None, uses all keys in embeddings).
        Methods absent from `embeddings` leave their column empty.
    with_trajectory : bool, default True
        Whether to include trajectory visualization as a third row
    compute_metrics : bool, default True
        Whether to add density contours (angle panel) and 25/50/75th
        percentile markers on the colourbar (speed panel)
    trajectory_kwargs : dict, optional
        Overrides for the trajectory style: 'linewidth', 'alpha', 'color',
        'arrow_spacing', 'arrow_scale', 'start_marker', 'end_marker',
        'marker_size'
    figsize : tuple, optional
        Figure size (width, height). If None, computed based on number of methods
    save_path : str, optional
        Path to save the figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure
    """
    if methods is None:
        methods = list(embeddings.keys())

    n_methods = len(methods)
    n_rows = 3 if with_trajectory else 2

    # Set figure size
    if figsize is None:
        figsize = (6 * n_methods, 5 * n_rows)

    # Create figure and grid
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(n_rows, n_methods, hspace=0.3, wspace=0.3)

    # Default feature names
    if feature_names is None:
        feature_names = {
            'angle': 'Head Direction',
            'speed': 'Speed'
        }

    # Default trajectory style, overridden by any caller-supplied values
    traj_style = {
        'linewidth': 0.8,
        'alpha': 0.3,
        'color': 'k',
        'arrow_spacing': 20,
        'arrow_scale': 0.3,
        'start_marker': 'o',
        'end_marker': 's',
        'marker_size': 100
    }
    if trajectory_kwargs is not None:
        traj_style.update(trajectory_kwargs)

    for i, method in enumerate(methods):
        if method not in embeddings:
            continue

        embedding = embeddings[method]

        # First row: colored by angle/position
        ax1 = fig.add_subplot(gs[0, i])

        if features is not None and 'angle' in features:
            angle = features['angle']
            # Normalize angle to [0, 1] for color mapping
            angle_norm = (angle + np.pi) / (2 * np.pi)

            scatter = ax1.scatter(
                embedding[:, 0], embedding[:, 1],
                c=angle_norm, cmap='hsv', alpha=0.7, s=2,
                vmin=0, vmax=1, edgecolors='none'
            )

            cbar = plt.colorbar(scatter, ax=ax1, label=feature_names.get('angle', 'Angle'))
            # Set colorbar ticks to show actual angles
            cbar.set_ticks([0, 0.25, 0.5, 0.75, 1])
            cbar.set_ticklabels(['-π', '-π/2', '0', 'π/2', 'π'])

            # Add density contours if requested
            if compute_metrics:
                _add_density_contours(ax1, embedding)
        else:
            ax1.scatter(embedding[:, 0], embedding[:, 1], alpha=0.6, s=1)

        ax1.set_xlabel('Component 0')
        ax1.set_ylabel('Component 1')
        ax1.set_title(f'{method.upper()} - {feature_names.get("angle", "Position")}')
        ax1.grid(True, alpha=0.3)

        # Second row: colored by speed or second feature
        ax2 = fig.add_subplot(gs[1, i])

        if features is not None and 'speed' in features:
            speed = features['speed']

            scatter = ax2.scatter(
                embedding[:, 0], embedding[:, 1],
                c=speed, cmap='viridis', alpha=0.7, s=2,
                edgecolors='none'
            )

            cbar = plt.colorbar(scatter, ax=ax2, label=feature_names.get('speed', 'Speed'))

            # Add percentile markers on the colorbar if requested
            if compute_metrics:
                percentile_levels = [25, 50, 75]
                speed_percentiles = np.percentile(speed, percentile_levels)
                for p, val in zip(percentile_levels, speed_percentiles):
                    cbar.ax.axhline(y=val, color='red', alpha=0.3, linewidth=0.5)
                    cbar.ax.text(
                        1.05, val, f'{p}%',
                        transform=cbar.ax.get_yaxis_transform(),
                        fontsize=8, va='center'
                    )
        else:
            ax2.scatter(embedding[:, 0], embedding[:, 1], alpha=0.6, s=1)

        ax2.set_xlabel('Component 0')
        ax2.set_ylabel('Component 1')
        ax2.set_title(f'{method.upper()} - {feature_names.get("speed", "Feature 2")}')
        ax2.grid(True, alpha=0.3)

        # Third row: trajectory visualization
        if with_trajectory:
            ax3 = fig.add_subplot(gs[2, i])
            _draw_embedding_trajectory(ax3, embedding, traj_style)

            ax3.set_xlabel('Component 0')
            ax3.set_ylabel('Component 1')
            ax3.set_title(f'{method.upper()} - Trajectory')
            ax3.grid(True, alpha=0.3)
            ax3.legend(loc='best', fontsize=8)
            ax3.set_aspect('equal', adjustable='datalim')

    # Set main title
    title = 'Population Embeddings: Behavioral Features'
    if with_trajectory:
        title += ' and Trajectories'
    plt.suptitle(title, fontsize=16)

    # Save if requested
    if save_path:
        plt.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig
def plot_trajectories(
    embeddings: Dict[str, np.ndarray],
    methods: Optional[List[str]] = None,
    trajectory_kwargs: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Draw the temporal trajectory of each embedding in its own side-by-side panel.

    Each panel shows the path through the first two components, red arrows
    indicating the direction of travel, and highlighted start (green) and
    end (red) samples.

    Parameters
    ----------
    embeddings : dict
        Method name -> embedding array; the first two columns are plotted
    methods : list of str, optional
        Methods to plot, in panel order (default: all keys of `embeddings`).
        Names missing from `embeddings` are skipped.
    trajectory_kwargs : dict, optional
        Style overrides: 'linewidth', 'alpha', 'color', 'arrow_spacing',
        'arrow_scale', 'start_marker', 'end_marker', 'marker_size'
    figsize : tuple, optional
        Figure size (default: 6 inches wide per method, 5 inches tall)
    save_path : str, optional
        If given, the figure is written to this path
    dpi : int, default DEFAULT_DPI
        Resolution used when saving

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure
    """
    method_list = list(embeddings.keys()) if methods is None else methods
    n_panels = len(method_list)

    fig = plt.figure(figsize=(6 * n_panels, 5) if figsize is None else figsize)

    # Built-in style, then caller overrides on top
    style = {
        'linewidth': 0.8,
        'alpha': 0.3,
        'color': 'k',
        'arrow_spacing': 20,
        'arrow_scale': 0.3,
        'start_marker': 'o',
        'end_marker': 's',
        'marker_size': 100
    }
    style.update({} if trajectory_kwargs is None else trajectory_kwargs)

    for panel_no, name in enumerate(method_list, start=1):
        if name not in embeddings:
            continue

        points = embeddings[name]
        xs, ys = points[:, 0], points[:, 1]
        ax = fig.add_subplot(1, n_panels, panel_no)

        # Path through the embedding
        ax.plot(
            xs, ys,
            color=style['color'],
            alpha=style['alpha'],
            linewidth=style['linewidth']
        )

        # Direction arrows: 'arrow_spacing' sets roughly how many are drawn
        n_points = len(points)
        stride = max(1, n_points // style['arrow_spacing'])

        # k + 1 is always a valid index because k < n_points - stride
        for k in range(0, n_points - stride, stride):
            step_x = points[k + 1, 0] - points[k, 0]
            step_y = points[k + 1, 1] - points[k, 1]

            # Skip near-stationary steps
            if not np.sqrt(step_x**2 + step_y**2) > 0.001:
                continue

            ax.arrow(
                points[k, 0], points[k, 1],
                step_x * style['arrow_scale'],
                step_y * style['arrow_scale'],
                head_width=0.02, head_length=0.02,
                fc='red', ec='red', alpha=0.6
            )

        # Endpoints: start first, then end
        endpoint_specs = (
            (0, 'green', style['start_marker'], 'Start'),
            (-1, 'red', style['end_marker'], 'End'),
        )
        for row, face_color, marker, legend_label in endpoint_specs:
            ax.scatter(
                points[row, 0], points[row, 1],
                c=face_color, s=style['marker_size'],
                marker=marker,
                edgecolors='black', linewidth=2, label=legend_label, zorder=5
            )

        ax.set_xlabel('Component 0')
        ax.set_ylabel('Component 1')
        ax.set_title(f'{name.upper()} - Trajectory')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='best', fontsize=8)
        ax.set_aspect('equal', adjustable='datalim')

    plt.suptitle('Temporal Trajectories in Embedding Space', fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig
def plot_component_interpretation(
    mi_matrices: Dict[str, np.ndarray],
    feature_names: List[str],
    methods: Optional[List[str]] = None,
    n_components: Optional[int] = None,
    metadata: Optional[Dict[str, Dict]] = None,
    compute_metrics: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Plot, per method, an annotated heatmap of MI between features and components.

    Parameters
    ----------
    mi_matrices : dict
        Method name -> MI matrix of shape (n_features, n_components)
    feature_names : list of str
        Row labels; one annotated row is drawn per entry
    methods : list of str, optional
        Methods to plot (default: all keys of `mi_matrices`); names missing
        from `mi_matrices` are skipped
    n_components : int, optional
        Maximum number of components (columns) to show; defaults to 5.
        Always capped at the number of columns available.
    metadata : dict, optional
        Per-method metadata. For a method named 'pca' (case-insensitive) the
        'explained_variance_ratio' entry is printed below the heatmap.
    compute_metrics : bool, default True
        Whether the metadata-based annotation is drawn
    figsize : tuple, optional
        Figure size (default: 8 inches wide per method, 6 inches tall)
    save_path : str, optional
        If given, the figure is written to this path
    dpi : int, default DEFAULT_DPI
        Resolution used when saving

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure
    """
    method_list = list(mi_matrices.keys()) if methods is None else methods
    n_panels = len(method_list)
    n_features = len(feature_names)

    fig = plt.figure(figsize=(8 * n_panels, 6) if figsize is None else figsize)

    for panel_idx, name in enumerate(method_list):
        if name not in mi_matrices:
            continue

        full_matrix = mi_matrices[name]
        upper_name = name.upper()
        is_pca = name.lower() == 'pca'

        # Number of columns shown: requested (or 5), capped by what exists
        wanted = 5 if n_components is None else n_components
        n_shown = min(wanted, full_matrix.shape[1])

        ax = plt.subplot(1, n_panels, panel_idx + 1)

        # Heatmap; the colour scale tops out at the largest shown value,
        # or at 1 when nothing is positive
        shown = full_matrix[:, :n_shown]
        peak = np.max(shown)
        color_max = peak if peak > 0 else 1

        image = ax.imshow(
            shown, aspect='auto', cmap='YlOrRd',
            vmin=0, vmax=color_max
        )

        # Axis ticks and labels. PCA components are 'PC<i>'; every other
        # method is labelled with its upper-cased name.
        ax.set_xticks(range(n_shown))
        label_prefix = 'PC' if is_pca else upper_name
        ax.set_xticklabels([f'{label_prefix}{k}' for k in range(n_shown)])
        ax.set_yticks(range(n_features))
        ax.set_yticklabels(feature_names)
        ax.set_xlabel(f'{upper_name} Components')
        ax.set_title(f'{upper_name} Component-Feature MI')

        # Write each MI value into its cell; dark cells get white text
        contrast_threshold = color_max * 0.5
        for r, c in np.ndindex(n_features, n_shown):
            cell = shown[r, c]
            ax.text(
                c, r, f'{cell:.3f}',
                ha="center", va="center",
                color="black" if cell < contrast_threshold else "white",
                fontsize=9
            )

        plt.colorbar(image, ax=ax, label='Mean MI (bits)')

        # Optional PCA explained-variance line under the panel
        if not (compute_metrics and metadata is not None and name in metadata):
            continue
        info = metadata[name]
        if is_pca and 'explained_variance_ratio' in info:
            ratios = info['explained_variance_ratio'][:n_shown]
            summary = 'Var explained: ' + ', '.join([f'{v*100:.1f}%' for v in ratios])
            ax.text(
                0.5, -0.15, summary,
                transform=ax.transAxes,
                ha='center', va='top',
                fontsize=8, style='italic'
            )

    plt.suptitle('Component Interpretation: Mutual Information between Components and Features', fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig
def plot_embeddings_grid(
    embeddings: Dict[str, Dict[str, np.ndarray]],
    labels: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] = None,
    methods: Optional[List[str]] = None,
    scenarios: Optional[List[str]] = None,
    metrics: Optional[Dict[str, Dict[str, Dict[str, float]]]] = None,
    colormap: str = 'viridis',
    figsize: Optional[Tuple[float, float]] = None,
    n_cols: int = 4,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> Optional[plt.Figure]:
    """
    Create grid of embeddings for multiple methods and scenarios.

    Parameters
    ----------
    embeddings : dict of dict
        Nested dictionary: {method: {scenario: embedding_array}}.
        Entries whose embedding is None are skipped.
    labels : array or dict, optional
        Color labels for points. Either a single array used for every panel,
        or a nested dict {method: {scenario: labels}}. Panels without labels
        are colored by sample index.
    methods : list, optional
        Methods to plot (default: all in embeddings)
    scenarios : list, optional
        Scenarios to plot (default: all available)
    metrics : dict, optional
        Nested dict of metrics: {method: {scenario: {metric_name: value}}}.
        Up to two float-valued metrics are appended to each panel title.
    colormap : str
        Colormap for scatter plots
    figsize : tuple, optional
        Figure size (default: 4 x 4 inches per grid cell)
    n_cols : int
        Number of columns in grid
    save_path : str, optional
        Path to save figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The generated figure, or None (after printing a message) when there
        is no embedding to plot
    """
    if methods is None:
        methods = list(embeddings.keys())

    # Collect all (method, scenario) pairs that have an embedding
    all_plots = []
    for method in methods:
        if method not in embeddings:
            continue
        per_method = embeddings[method]
        if scenarios is None:
            method_scenarios = list(per_method.keys())
        else:
            method_scenarios = [s for s in scenarios if s in per_method]

        all_plots.extend(
            (method, scenario)
            for scenario in method_scenarios
            if per_method[scenario] is not None
        )

    if not all_plots:
        print("No valid embeddings to plot")
        return None

    # Calculate grid dimensions (ceil division)
    n_plots = len(all_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols

    if figsize is None:
        figsize = (4 * n_cols, 4 * n_rows)

    # squeeze=False always yields a 2-D (n_rows, n_cols) array of Axes.
    # Without it a 1x1 grid returns a bare Axes object and the former
    # `axes.reshape(...)` call raised AttributeError.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()  # row-major, matching the fill order below

    # Plot each embedding
    for ax, (method, scenario) in zip(flat_axes, all_plots):
        embedding = embeddings[method][scenario]

        # Get labels for coloring; fall back to sample index
        if labels is None:
            color_labels = np.arange(len(embedding))
        elif isinstance(labels, dict):
            if method in labels and scenario in labels[method]:
                color_labels = labels[method][scenario]
            else:
                color_labels = np.arange(len(embedding))
        else:
            color_labels = labels

        ax.scatter(
            embedding[:, 0], embedding[:, 1],
            c=color_labels, cmap=colormap,
            s=10, alpha=0.7, edgecolors='none'
        )

        # Add title with metrics if available
        title = f'{method} - {scenario}'
        if metrics and method in metrics and scenario in metrics[method]:
            # np.floating covers NumPy scalars such as np.float32, which are
            # not subclasses of the built-in float and were silently dropped.
            metric_strs = [
                f'{metric_name}: {value:.3f}'
                for metric_name, value in metrics[method][scenario].items()
                if isinstance(value, (float, np.floating))
            ]
            if metric_strs:
                title += '\n' + ', '.join(metric_strs[:2])  # Show max 2 metrics

        ax.set_title(title, fontsize=10)
        ax.set_xlabel('Component 0')
        ax.set_ylabel('Component 1')
        ax.grid(True, alpha=0.3)

    # Hide unused subplots in the last row
    for unused_ax in flat_axes[n_plots:]:
        unused_ax.set_visible(False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig
# Note: plot_quality_metrics_comparison and plot_quality_vs_speed_tradeoff were
# removed because they are too specific to certain examples and are not reused elsewhere.
def plot_neuron_selectivity_summary(
    selectivity_counts: Dict[str, int],
    total_neurons: int,
    colors: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (8, 6),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Create bar plot summarizing neuron selectivity categories.

    Each bar shows the neuron count of one category and is annotated with that
    count as a percentage of `total_neurons`.

    Parameters
    ----------
    selectivity_counts : dict
        Dictionary mapping category names to counts. May be empty, in which
        case an empty chart is produced.
    total_neurons : int
        Total number of neurons. If 0, all percentages are shown as 0.0%.
    colors : dict, optional
        Dictionary mapping category names to colors. Categories without an
        entry are drawn in 'steelblue'.
    figsize : tuple
        Figure size
    save_path : str, optional
        Path to save figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    if colors is None:
        # Default colors for common categories
        colors = {
            'Spatial': 'darkgreen',
            'spatial': 'darkgreen',
            'position_2d': 'darkgreen',
            'x_position': 'green',
            'y_position': 'lightgreen',
            'head_direction': 'blue',
            'speed': 'orange',
            'task_type': 'red',
            'reward': 'purple',
            'Non-spatial': 'gray',
            'non_spatial': 'gray',
            'Non-selective': 'lightgray'
        }

    fig, ax = plt.subplots(figsize=figsize)

    categories = list(selectivity_counts.keys())
    counts = list(selectivity_counts.values())

    # Get colors for each category
    bar_colors = [colors.get(cat, 'steelblue') for cat in categories]

    # Create bars
    bars = ax.bar(categories, counts, color=bar_colors, alpha=0.7)

    # Add percentage labels just above each bar
    for bar, count in zip(bars, counts):
        # Guard against an empty population instead of raising ZeroDivisionError
        percentage = count / total_neurons * 100 if total_neurons else 0.0
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 1,
                f'{percentage:.1f}%', ha='center', va='bottom')

    ax.set_ylabel('Number of neurons')
    ax.set_title('Neuron Selectivity Categories')

    # Leave 15% headroom for the labels. `default=0` keeps an empty dict from
    # raising ValueError, and the fallback of 1 avoids identical (0, 0) limits
    # when there is nothing to show.
    y_top = max(counts, default=0) * 1.15
    ax.set_ylim(0, y_top if y_top > 0 else 1)

    # Add total count as text
    ax.text(0.02, 0.98, f'Total neurons: {total_neurons}',
            transform=ax.transAxes, ha='left', va='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig
def plot_component_selectivity_heatmap(
    selectivity_matrix: np.ndarray,
    methods: List[str],
    n_components_per_method: Optional[Dict[str, int]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI
) -> plt.Figure:
    """
    Draw one neuron-by-component selectivity heatmap per method, side by side.

    The columns of `selectivity_matrix` are consumed left to right: each method
    takes the next block of columns, in the order given by `methods`.

    Parameters
    ----------
    selectivity_matrix : ndarray
        Matrix of shape (n_neurons, total_components) with MI values
    methods : list of str
        DR method names, one panel each
    n_components_per_method : dict, optional
        Number of columns belonging to each method. If None, the columns are
        split evenly (integer division) across the methods.
    figsize : tuple, optional
        Figure size (default: 5 inches wide per method, 8 inches tall)
    save_path : str, optional
        If given, the figure is written to this path
    dpi : int, default DEFAULT_DPI
        Resolution used when saving

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    _, total_components = selectivity_matrix.shape
    n_panels = len(methods)

    if n_components_per_method is None:
        # Even split of the available columns
        per_method = total_components // n_panels
        n_components_per_method = dict.fromkeys(methods, per_method)

    if figsize is None:
        figsize = (5 * n_panels, 8)

    # squeeze=False gives a (1, n_panels) array even for a single method,
    # so the first row can always be iterated.
    fig, axes_grid = plt.subplots(1, n_panels, figsize=figsize, squeeze=False)

    col_lo = 0
    for ax, name in zip(axes_grid[0], methods):
        width = n_components_per_method[name]
        col_hi = col_lo + width

        # This method's columns, transposed so components run along y
        block = selectivity_matrix[:, col_lo:col_hi]
        image = ax.imshow(block.T, aspect='auto', cmap='hot',
                          interpolation='nearest')

        ax.set_xlabel('Neuron ID')
        ax.set_ylabel('Component')
        ax.set_title(f'{name.upper()} Component Selectivity')

        colorbar = plt.colorbar(image, ax=ax)
        colorbar.set_label('Mutual Information (bits)')

        # One tick per component, labelled C0, C1, ...
        ax.set_yticks(range(width))
        ax.set_yticklabels([f'C{k}' for k in range(width)])

        col_lo = col_hi

    plt.suptitle('Neuron Selectivity to Embedding Components', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=dpi, bbox_inches='tight')

    return fig