Coverage for src/driada/intense/visual.py: 97.03%
236 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import numpy as np
3import matplotlib.pyplot as plt
4from ..utils.plot import create_default_figure, make_beautiful
5from ..utils.data import rescale
6from scipy.stats import rankdata, gaussian_kde, wasserstein_distance
7import seaborn as sns
def plot_pc_activity(exp, cell_ind, ds=None, ax=None):
    """
    Plot place cell activity overlaid on spatial trajectory.

    The (downsampled) trajectory ``(exp.x.data, exp.y.data)`` is drawn as a
    scatter coloured by the rescaled log calcium signal of the neuron; frames
    with non-zero spike data are overlaid as black stars (not downsampled).

    Parameters
    ----------
    exp : Experiment
        Experiment object with spatial data and neurons. Must provide
        ``x``/``y`` features, ``neurons`` and a ``stats_table`` entry for
        the ``('x', 'y')`` feature containing 'pval' and 'rel_mi_beh'.
    cell_ind : int
        Index of the neuron to plot.
    ds : int, optional
        Downsampling factor for the trajectory scatter. Default: 5.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates a new figure whose longer side
        follows the longer spatial extent of the trajectory.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes with the plot.
    """
    pc_stats = exp.stats_table[('x', 'y')][cell_ind]
    pval = None if pc_stats['pval'] is None else np.round(pc_stats['pval'], 7)
    rel_mi_beh = None if pc_stats['rel_mi_beh'] is None else np.round(pc_stats['rel_mi_beh'], 4)

    if ds is None:
        ds = 5

    xdata = np.asarray(exp.x.data)
    ydata = np.asarray(exp.y.data)

    if ax is None:
        lenx = np.nanmax(xdata) - np.nanmin(xdata)
        leny = np.nanmax(ydata) - np.nanmin(ydata)
        # Stretch the figure along the longer spatial axis so the trajectory
        # is not distorted; a degenerate extent falls back to a square figure
        # instead of dividing by zero.
        if lenx > 0 and leny > 0:
            if lenx >= leny:
                width, height = 6 * lenx / leny, 6
            else:
                width, height = 6, 6 * leny / lenx
        else:
            width, height = 6, 6
        _, ax = create_default_figure(width, height)

    # Small offset keeps log() finite for zero-valued calcium samples.
    neur = rescale(np.log(exp.neurons[cell_ind].ca.data + 1e-10))
    spinds = np.where(exp.neurons[cell_ind].sp.data != 0)[0]

    ax.scatter(xdata[::ds], ydata[::ds], c=neur[::ds], cmap='plasma', alpha=0.8)
    ax.scatter(xdata[spinds], ydata[spinds], c='k', alpha=1, marker='*', linewidth=2, s=100)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(f'Cell {cell_ind}, Rel MI={rel_mi_beh}, pval={pval}')

    return ax
def plot_neuron_feature_density(exp, data_type, cell_id, featname, ind1=0, ind2=100000, ds=1, shift=None, ax=None, compute_wsd=False):
    """
    Plot density distribution of neural activity conditioned on feature values.

    For binary features, draws one KDE of log10 signal per feature state
    (frames with the lowest / highest rescaled feature rank), using only
    frames with positive signal. For non-binary features, draws a 2-D
    Gaussian KDE of (log10 signal, natural-log feature) as a colour mesh.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neurons and features.
    data_type : str
        Type of neural data: 'calcium' or 'spikes'.
    cell_id : int
        Index of the neuron.
    featname : str
        Name of the behavioral feature.
    ind1 : int, optional
        Start frame index. Default: 0.
    ind2 : int, optional
        End frame index (clipped to ``exp.n_frames``). Default: 100000.
    ds : int, optional
        Downsampling factor. Default: 1.
    shift : int, optional
        Temporal shift (not implemented, ignored). Default: None.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.
    compute_wsd : bool, optional
        Whether to compute Wasserstein distance for binary features and show
        it in the title. Default: False.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes with the plot.

    Raises
    ------
    ValueError
        If ``data_type`` is neither 'calcium' nor 'spikes'.
    NotImplementedError
        If ``data_type`` is 'spikes' and the feature is binary.
    """
    if data_type not in ('calcium', 'spikes'):
        raise ValueError(f"data_type must be 'calcium' or 'spikes', got {data_type!r}")

    feature = getattr(exp, featname)
    # Fail before slicing data or creating a figure for the unsupported case.
    if feature.is_binary and data_type == 'spikes':
        raise NotImplementedError('Binary feature density plot for spike data not yet implemented')

    ind2 = min(exp.n_frames, ind2)
    neuron = exp.neurons[cell_id]
    signal_source = neuron.ca if data_type == 'calcium' else neuron.sp
    sig = signal_source.scdata[ind1:ind2][::ds]
    bdata = feature.scdata[ind1:ind2][::ds]

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))

    if feature.is_binary:
        rbdata = rescale(rankdata(bdata))
        positive = sig > 0  # log10 is only defined for positive signal
        vals0 = np.log10(sig[(rbdata == np.min(rbdata)) & positive])
        vals1 = np.log10(sig[(rbdata == np.max(rbdata)) & positive])

        if compute_wsd and len(vals0) > 0 and len(vals1) > 0:
            wsd = wasserstein_distance(vals0, vals1)
            title_text = f'wsd={wsd:.3f}'
        else:
            title_text = ''

        sns.kdeplot(vals0, ax=ax, c='b', label=f'{featname}=0', linewidth=3, bw_adjust=0.5)
        sns.kdeplot(vals1, ax=ax, c='r', label=f'{featname}=1', linewidth=3, bw_adjust=0.5)
        ax.legend(loc='upper right')
        ax.set_xlabel('log(dF/F)', fontsize=20)
        ax.set_ylabel('density', fontsize=20)
        if title_text:
            ax.set_title(title_text)
    else:
        # Tiny random jitter keeps the logs finite and breaks ties so the
        # KDE covariance is not singular.
        # NOTE(review): signal uses log10 but the feature uses natural log;
        # both axes are labelled "log(...)". Kept as-is to preserve output.
        x0 = np.log10(sig + np.random.random(size=len(sig)) * 1e-8)
        y0 = np.log(bdata + np.random.random(size=len(bdata)) * 1e-8)

        jdata = np.vstack([x0, y0]).T
        nbins = 100
        k = gaussian_kde(jdata.T)
        xi, yi = np.mgrid[x0.min():x0.max():nbins * 1j, y0.min():y0.max():nbins * 1j]
        zi = k(np.vstack([xi.flatten(), yi.flatten()]))

        ax.set_title('Density')
        ax.pcolormesh(xi, yi, zi.reshape(xi.shape), shading='auto', cmap='coolwarm')
        ax.set_xlabel('log(signals)', fontsize=20)
        ax.set_ylabel(f'log({featname})', fontsize=20)

    return ax
def plot_neuron_feature_pair(exp, cell_id, featname, ind1=0, ind2=100000, ds=1,
                             add_density_plot=True, ax=None, title=None):
    """
    Plot neural activity time series alongside behavioral feature.

    The scaled calcium trace is drawn in blue. For discrete features, frames
    where the raw feature equals 1 are marked with red dots; otherwise the
    rescaled rank of the feature is drawn as a red line. Optionally a density
    subplot from ``plot_neuron_feature_density`` is added on the right.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neurons and features.
    cell_id : int
        Index of the neuron.
    featname : str
        Name of the behavioral feature.
    ind1 : int, optional
        Start frame index. Default: 0.
    ind2 : int, optional
        End frame index (clipped to ``exp.n_frames``). Default: 100000.
    ds : int, optional
        Downsampling factor. Default: 1.
    add_density_plot : bool, optional
        Whether to add density subplot. Default: True. Forced to False when
        ``ax`` is provided.
    ax : matplotlib.axes.Axes, optional
        Axes to plot the time series on. When given, no density subplot is
        drawn and the suptitle/layout are applied to ``ax.figure``.
    title : str, optional
        Custom title for the plot. Default: experiment signature, neuron and
        feature name.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the plot(s).
    """
    ind2 = min(exp.n_frames, ind2)
    frames = np.arange(ind1, ind2)[::ds]
    ca = exp.neurons[cell_id].ca.scdata[ind1:ind2][::ds]
    feature = getattr(exp, featname)

    if ax is None:
        if add_density_plot:
            fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6), width_ratios=[0.6, 0.4])
            ax1 = make_beautiful(ax1)
        else:
            fig, ax0 = plt.subplots(figsize=(10, 6))
            ax1 = None
    else:
        # An externally provided axis is used for the time series only;
        # there is nowhere to put the density plot.
        ax0 = ax
        ax1 = None
        add_density_plot = False
        fig = ax0.figure

    ax0 = make_beautiful(ax0)

    ax0.plot(frames, ca, c='b', linewidth=2, alpha=0.5, label=f'neuron {cell_id}')
    if feature.discrete:
        # Use the raw (unscaled) data to find frames where the feature is 1.
        active_indices = np.where(feature.data[ind1:ind2][::ds] == 1)[0]
        if len(active_indices) > 0:
            ax0.scatter(frames[active_indices], ca[active_indices],
                        c='r', s=50, alpha=0.7, zorder=10, label=f'{featname}=1')
    else:
        rbdata = rescale(rankdata(feature.scdata[ind1:ind2][::ds]))
        ax0.plot(frames, rbdata, c='r', linewidth=2, alpha=0.5)

    if add_density_plot:
        plot_neuron_feature_density(exp, 'calcium', cell_id, featname, ind1=ind1, ind2=ind2, ds=ds, ax=ax1, compute_wsd=False)

    ax0.set_xlabel('timeframes', fontsize=20)
    ax0.set_ylabel('Signal/behavior', fontsize=20)

    # Only the discrete branch adds a feature label worth a legend.
    if feature.discrete:
        ax0.legend(loc='upper right')

    if title is None:
        title = f'{exp.signature} Neuron {cell_id}, feature {featname}'

    fig.suptitle(title, fontsize=20)
    # Lay out the figure we actually drew on, not whichever one is current.
    fig.tight_layout()

    return fig
def plot_disentanglement_heatmap(disent_matrix, count_matrix, feat_names,
                                 title=None, figsize=(12, 10), dpi=100,
                                 cmap=None, vmin=0, vmax=100,
                                 cbar_label='Disentanglement score (%)',
                                 fontsize=14, title_fontsize=18,
                                 show_grid=True, grid_alpha=0.3):
    """Plot disentanglement analysis results as a heatmap.

    Creates a heatmap showing the relative disentanglement scores between
    feature pairs. Each cell (i, j) shows
    ``100 * disent_matrix[i, j] / count_matrix[i, j]``, the percentage of
    neurons where feature i was primary when paired with feature j.

    Parameters
    ----------
    disent_matrix : array_like
        Disentanglement matrix from disentangle_all_selectivities.
    count_matrix : array_like
        Count matrix from disentangle_all_selectivities, same shape as
        ``disent_matrix``.
    feat_names : list of str
        Feature names corresponding to matrix indices.
    title : str, optional
        Plot title. Default: 'Disentanglement Analysis'.
    figsize : tuple, optional
        Figure size (width, height). Default: (12, 10).
    dpi : int, optional
        Figure DPI. Default: 100.
    cmap : str or Colormap, optional
        Colormap to use. Default: custom red-gray-green gradient with gray
        at the midpoint.
    vmin : float, optional
        Minimum value for colormap. Default: 0.
    vmax : float, optional
        Maximum value for colormap. Default: 100.
    cbar_label : str, optional
        Colorbar label. Default: 'Disentanglement score (%)'.
    fontsize : int, optional
        Font size for tick labels (axis labels use fontsize + 2). Default: 14.
    title_fontsize : int, optional
        Font size for title. Default: 18.
    show_grid : bool, optional
        Whether to draw grid lines along the cell boundaries. Default: True.
    grid_alpha : float, optional
        Grid transparency. Default: 0.3.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the heatmap.
    ax : matplotlib.axes.Axes
        Axes containing the heatmap.

    Notes
    -----
    With the default colormap and limits:
    - Red indicates low disentanglement (feature is redundant)
    - Gray indicates balanced contribution (~50%)
    - Green indicates high disentanglement (feature is primary)

    Cells whose count is zero are masked and left unpainted, so they show
    the axes background colour.
    """
    import seaborn as sns
    from matplotlib.colors import LinearSegmentedColormap
    import pandas as pd

    # Coerce inputs so element-wise comparison works for list input too
    # (a plain list compared to 0 yields a single False and masks nothing).
    disent = np.asarray(disent_matrix, dtype=float)
    counts = np.asarray(count_matrix, dtype=float)

    with np.errstate(divide='ignore', invalid='ignore'):
        rel_disent_matrix = disent / counts * 100
    rel_disent_matrix[counts == 0] = np.nan

    if cmap is None:
        # Red -> Gray -> Green; gray at 50% marks equal selectivity and stays
        # distinguishable from masked (no-data) cells.
        colors = [(1, 0, 0), (0.7, 0.7, 0.7), (0, 1, 0)]
        cmap = LinearSegmentedColormap.from_list("disentanglement_cmap", colors, N=100)

    df_heatmap = pd.DataFrame(rel_disent_matrix, columns=feat_names, index=feat_names)

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    sns.heatmap(df_heatmap,
                ax=ax,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                cbar_kws={'label': cbar_label},
                mask=np.isnan(rel_disent_matrix),
                square=True,
                linewidths=0.5,
                linecolor='gray')

    if title is None:
        title = 'Disentanglement Analysis'
    ax.set_title(title, fontsize=title_fontsize, pad=20)

    n_feats = len(feat_names)
    # Major ticks sit at cell centres so the labels are centred on cells.
    ax.set_xticks(np.arange(n_feats) + 0.5)
    ax.set_xticklabels(feat_names, fontsize=fontsize, rotation=45, ha='right')
    ax.set_yticks(np.arange(n_feats) + 0.5)
    ax.set_yticklabels(feat_names, fontsize=fontsize, rotation=0)

    if show_grid:
        # Put grid lines on minor ticks at the cell boundaries; gridding the
        # major (centre) ticks would draw lines through the middle of cells.
        ax.set_xticks(np.arange(n_feats + 1), minor=True)
        ax.set_yticks(np.arange(n_feats + 1), minor=True)
        ax.tick_params(which='minor', length=0)
        ax.grid(False, which='major')
        ax.grid(True, which='minor', linestyle='-', alpha=grid_alpha, color='black')

    ax.set_xlabel('Feature (as secondary)', fontsize=fontsize + 2)
    ax.set_ylabel('Feature (as primary)', fontsize=fontsize + 2)

    fig.tight_layout()

    return fig, ax
def plot_disentanglement_summary(disent_matrix, count_matrix, feat_names,
                                 experiments=None, title_prefix='',
                                 figsize=(14, 10), dpi=100):
    """Plot comprehensive disentanglement analysis with multiple views.

    Creates a figure with multiple subplots showing:
    1. Disentanglement heatmap (percentage of neurons where the row feature
       was primary when paired with the column feature)
    2. Feature dominance scores (row sums of the primary fractions)
    3. Pairwise interaction counts (upper-triangle pairs with count > 0)

    Parameters
    ----------
    disent_matrix : ndarray or list/tuple of ndarray
        Disentanglement matrix(es). If a list or tuple, matrices are summed.
    count_matrix : ndarray or list/tuple of ndarray
        Count matrix(es). If ``disent_matrix`` is a list or tuple, these are
        summed as well.
    feat_names : list of str
        Feature names corresponding to matrix indices.
    experiments : list of str, optional
        Experiment names if multiple matrices provided. Currently unused.
    title_prefix : str, optional
        Prefix for the main title.
    figsize : tuple, optional
        Figure size. Default: (14, 10).
    dpi : int, optional
        Figure DPI. Default: 100.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing all subplots.
    """
    from matplotlib.colors import LinearSegmentedColormap
    import seaborn as sns
    import pandas as pd

    if isinstance(disent_matrix, (list, tuple)):
        total_disent = np.sum(disent_matrix, axis=0)
        total_count = np.sum(count_matrix, axis=0)
        n_exp = len(disent_matrix)
    else:
        total_disent = disent_matrix
        total_count = count_matrix
        n_exp = 1
    # Coerce so masking and 2-D indexing work regardless of input type.
    total_disent = np.asarray(total_disent, dtype=float)
    total_count = np.asarray(total_count, dtype=float)

    with np.errstate(divide='ignore', invalid='ignore'):
        rel_disent_matrix = total_disent / total_count * 100
    # Pairs without data become NaN (masked in the heatmap, ignored by nansum).
    rel_disent_matrix[total_count == 0] = np.nan

    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = fig.add_gridspec(2, 2, height_ratios=[3, 1], width_ratios=[3, 1])

    # Main heatmap. Gray midpoint matches plot_disentanglement_heatmap and
    # keeps 50% cells distinguishable from masked (blank) cells.
    ax_main = fig.add_subplot(gs[0, 0])
    colors = [(1, 0, 0), (0.7, 0.7, 0.7), (0, 1, 0)]
    cmap = LinearSegmentedColormap.from_list("disentanglement_cmap", colors, N=100)
    df_heatmap = pd.DataFrame(rel_disent_matrix, columns=feat_names, index=feat_names)
    sns.heatmap(df_heatmap, ax=ax_main, cmap=cmap, vmin=0, vmax=100,
                cbar_kws={'label': 'Disentanglement score (%)'},
                mask=np.isnan(rel_disent_matrix), square=True,
                linewidths=0.5, linecolor='gray')
    ax_main.set_title('Disentanglement Heatmap')

    # Feature dominance: sum over partners of the fraction of neurons where
    # this feature was primary (pairs without data contribute nothing).
    ax_dom = fig.add_subplot(gs[0, 1])
    dominance_scores = np.nansum(rel_disent_matrix / 100, axis=1)
    y_pos = np.arange(len(feat_names))
    ax_dom.barh(y_pos, dominance_scores, color='green', alpha=0.7)
    ax_dom.set_yticks(y_pos)
    ax_dom.set_yticklabels(feat_names)
    ax_dom.set_xlabel('Dominance Score')
    ax_dom.set_title('Feature Dominance')
    ax_dom.grid(True, alpha=0.3)

    # Interaction counts for each unordered pair (upper triangle) with data.
    ax_counts = fig.add_subplot(gs[1, :])
    pair_counts = []
    pair_labels = []
    n_feats = len(feat_names)
    for i in range(n_feats):
        for j in range(i + 1, n_feats):
            if total_count[i, j] > 0:
                pair_counts.append(total_count[i, j])
                pair_labels.append(f'{feat_names[i]}-{feat_names[j]}')

    x_pos = np.arange(len(pair_counts))
    ax_counts.bar(x_pos, pair_counts, color='blue', alpha=0.7)
    ax_counts.set_xticks(x_pos)
    ax_counts.set_xticklabels(pair_labels, rotation=45, ha='right')
    ax_counts.set_ylabel('Number of neurons')
    ax_counts.set_title('Pairwise interaction counts')
    ax_counts.grid(True, axis='y', alpha=0.3)

    if n_exp > 1:
        title = f'{title_prefix}Disentanglement Analysis ({n_exp} experiments)'
    else:
        title = f'{title_prefix}Disentanglement Analysis'
    fig.suptitle(title, fontsize=16, y=0.98)

    fig.tight_layout()
    return fig
def plot_selectivity_heatmap(exp, significant_neurons,
                             metric='mi', cmap='viridis', use_log_scale=False,
                             vmin=None, vmax=None, figsize=(10, 8),
                             significance_threshold=None, ax=None):
    """Create a heatmap showing metric values for selective neuron-feature pairs.

    Rows are neurons ``0 .. exp.n_cells - 1``; columns are the string-named
    dynamic features of ``exp`` in sorted order. A cell holds the ``'me'``
    value returned by ``exp.get_neuron_feature_pair_stats(cell, feat,
    mode='calcium')`` for pairs listed in ``significant_neurons`` (and
    passing ``significance_threshold`` if given); all other cells are 0.

    Parameters
    ----------
    exp : Experiment
        The experiment object containing all data and results.
    significant_neurons : dict
        Dictionary mapping neuron IDs to lists of significant features.
    metric : str, optional
        Which metric the values represent ('mi' for mutual information,
        'corr' for correlation). Only affects labels; the plotted value is
        always the ``'me'`` entry of the pair statistics. Default: 'mi'.
    cmap : str, optional
        Colormap to use. Default: 'viridis'.
    use_log_scale : bool, optional
        Whether to plot log10 of the metric values. Non-positive cells
        (including non-selective pairs) are masked. Default: False.
    vmin : float, optional
        Minimum value for colormap (log10 units when ``use_log_scale``). If
        None: ``min(0, smallest value)`` on the linear scale, log10 of the
        smallest positive value on the log scale, 0 when there is no data.
    vmax : float, optional
        Maximum value for colormap. If None: the largest value (or its
        log10); 1 when there is no data.
    figsize : tuple, optional
        Figure size (ignored if ax provided). Default: (10, 8).
    significance_threshold : float, optional
        If provided, only show pairs with p-value at or below this threshold;
        pairs whose p-value is missing are skipped.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the heatmap.
    ax : matplotlib.axes.Axes
        Axes containing the heatmap.
    stats : dict
        Dictionary containing statistics about the data:
        - n_selective: number of neurons listed in ``significant_neurons``
        - n_pairs: total number of features listed in ``significant_neurons``
        - selectivity_rate: n_selective as a percentage of ``exp.n_cells``
        - metric_values: metric values of all pairs placed in the matrix
        - sparsity: percentage of zero entries in the (linear-scale) matrix
    """
    all_features = sorted(f for f in exp.dynamic_features.keys() if isinstance(f, str))
    all_neurons = list(range(exp.n_cells))

    # 0 marks non-selective (or filtered-out) pairs.
    selectivity_matrix = np.zeros((len(all_neurons), len(all_features)))
    all_metric_values = []

    for neuron_idx, cell_id in enumerate(all_neurons):
        if cell_id not in significant_neurons:
            continue
        neuron_features = significant_neurons[cell_id]
        for feat_idx, feat_name in enumerate(all_features):
            if feat_name not in neuron_features:
                continue
            pair_stats = exp.get_neuron_feature_pair_stats(cell_id, feat_name, mode='calcium')

            if significance_threshold is not None:
                pval = pair_stats.get('pval', None)
                # Skip if pval is None (failed stage 1) or above threshold
                if pval is None or pval > significance_threshold:
                    continue

            # 'me' holds the value of whichever metric was used.
            value = pair_stats.get('me', 0)
            if value is None:
                # A missing value cannot be stored in the float matrix.
                continue
            selectivity_matrix[neuron_idx, feat_idx] = value
            all_metric_values.append(value)

    positive_values = [v for v in all_metric_values if v > 0]

    if use_log_scale and all_metric_values:
        # log10 is only defined for positive values; everything else
        # (including the zero placeholders) becomes NaN and is masked.
        plot_matrix = np.full(selectivity_matrix.shape, np.nan)
        positive = selectivity_matrix > 0
        plot_matrix[positive] = np.log10(selectivity_matrix[positive])
    else:
        plot_matrix = selectivity_matrix

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Automatic colour limits only fill in what the caller did not provide.
    if not all_metric_values:
        auto_vmin, auto_vmax = 0, 1
    elif use_log_scale:
        if positive_values:
            auto_vmin = np.log10(min(positive_values))
            auto_vmax = np.log10(max(positive_values))
        else:
            auto_vmin, auto_vmax = 0, 1
    else:
        # Include negative values (e.g. correlations) instead of clipping
        # them to the colour used for non-selective cells.
        auto_vmin = min(0, min(all_metric_values))
        auto_vmax = max(all_metric_values)
    if vmin is None:
        vmin = auto_vmin
    if vmax is None:
        vmax = auto_vmax

    masked_matrix = np.ma.masked_invalid(plot_matrix)

    im = ax.imshow(masked_matrix, cmap=cmap, aspect='auto', interpolation='nearest',
                   vmin=vmin, vmax=vmax)

    ax.set_xticks(range(len(all_features)))
    ax.set_xticklabels(all_features, rotation=45, ha='right')
    label_step = max(1, len(all_neurons) // 20)  # show ~20 neuron labels
    neuron_ticks = range(0, len(all_neurons), label_step)
    ax.set_yticks(neuron_ticks)
    ax.set_yticklabels(neuron_ticks)

    ax.set_xlabel('Features', fontsize=12)
    ax.set_ylabel('Neurons', fontsize=12)
    metric_name = 'Mutual Information' if metric == 'mi' else 'Correlation'
    scale_text = ' (log₁₀)' if use_log_scale else ''
    ax.set_title(f'Neuronal Selectivity: {metric_name}{scale_text}', fontsize=14, fontweight='bold')

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(f'{metric_name}{scale_text}', rotation=270, labelpad=20)

    n_selective = len(significant_neurons)
    n_pairs = sum(len(features) for features in significant_neurons.values())
    selectivity_rate = (n_selective / exp.n_cells) * 100 if exp.n_cells else 0.0
    # Measured on the matrix itself, so pairs removed by the threshold or
    # features that are not columns are not counted as filled.
    n_entries = selectivity_matrix.size
    if n_entries:
        sparsity = (1 - np.count_nonzero(selectivity_matrix) / n_entries) * 100
    else:
        sparsity = 100.0

    summary_lines = [
        f'Selective neurons: {n_selective}/{exp.n_cells} ({selectivity_rate:.1f}%)',
        f'Total selective pairs: {n_pairs}'
    ]
    if all_metric_values:
        summary_lines.extend([
            f'{metric.upper()} range: [{min(all_metric_values):.3f}, {max(all_metric_values):.3f}]',
            f'Mean {metric.upper()}: {np.mean(all_metric_values):.3f}'
        ])

    summary_text = '\n'.join(summary_lines)
    # Lower right corner of the figure, away from the colorbar.
    fig.text(0.98, 0.02, summary_text, transform=fig.transFigure,
             fontsize=10, verticalalignment='bottom', horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Minor ticks at cell boundaries give a grid between cells.
    ax.set_xticks(np.arange(len(all_features)) - 0.5, minor=True)
    ax.set_yticks(np.arange(len(all_neurons)) - 0.5, minor=True)
    ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    fig.tight_layout()

    stats = {
        'n_selective': n_selective,
        'n_pairs': n_pairs,
        'selectivity_rate': selectivity_rate,
        'metric_values': all_metric_values,
        'sparsity': sparsity
    }

    return fig, ax, stats