Coverage for src / tracekit / comparison / visualization.py: 86%
177 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""Visualization utilities for trace comparison.

This module provides visualization functions for comparing traces, including
overlay plots, difference plots, comparison heat maps, and a combined summary
figure for a comparison result.

Example:
    >>> from tracekit.comparison.visualization import (
    ...     plot_overlay,
    ...     plot_difference,
    ...     plot_comparison_heatmap,
    ... )
    >>> fig = plot_overlay(trace1, trace2)
    >>> fig = plot_difference(trace1, trace2)

References:
    - Tufte, E. R. (2001). The Visual Display of Quantitative Information
"""
20from __future__ import annotations
22from typing import TYPE_CHECKING, Any
24import matplotlib.pyplot as plt
25import numpy as np
26from matplotlib.gridspec import GridSpec
28if TYPE_CHECKING:
29 from matplotlib.figure import Figure
31 from tracekit.comparison.compare import ComparisonResult
32 from tracekit.core.types import WaveformTrace
def _contiguous_true_runs(mask: np.ndarray) -> list[tuple[int, int]]:
    """Return inclusive ``(first, last)`` index pairs for each run of True in *mask*."""
    # Pad with False on both ends so every run produces a rising (+1) and a
    # falling (-1) edge in the first difference, even at the array boundaries.
    padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False])).astype(np.int8)
    edges = np.diff(padded)
    firsts = np.flatnonzero(edges == 1)
    lasts = np.flatnonzero(edges == -1) - 1
    return list(zip(firsts.tolist(), lasts.tolist()))


def plot_overlay(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    labels: tuple[str, str] = ("Trace 1", "Trace 2"),
    title: str = "Trace Comparison - Overlay",
    highlight_differences: bool = True,
    difference_threshold: float | None = None,
    figsize: tuple[float, float] = (10, 6),
    **kwargs: Any,
) -> Figure:
    """Create overlay plot showing both traces.

    Overlay plot with difference highlighting. Both traces are truncated to
    the length of the shorter one. When ``highlight_differences`` is enabled,
    every contiguous run of samples whose absolute difference exceeds the
    threshold is shaded red. Only the first shaded span carries the
    "Difference" legend label, so the legend shows a single entry.

    Args:
        trace1: First trace
        trace2: Second trace
        labels: Labels for the two traces
        title: Plot title
        highlight_differences: Highlight regions where traces differ
        difference_threshold: Absolute-difference threshold for highlighting.
            ``None`` (default) uses ``mean + 2 * std`` of the absolute
            difference.
        figsize: Figure size (width, height)
        **kwargs: Additional arguments passed to ``Axes.plot()``. They
            override the default ``alpha=0.7`` and ``linewidth=1.5``.

    Returns:
        Matplotlib Figure object

    Example:
        >>> from tracekit.comparison.visualization import plot_overlay
        >>> fig = plot_overlay(measured, reference,
        ...                    labels=("Measured", "Reference"))
        >>> plt.show()

    References:
        CMP-003: Overlay plot with difference highlighting
    """
    fig, ax = plt.subplots(figsize=figsize)

    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    # Align lengths so the traces can be compared sample by sample.
    min_len = min(len(data1), len(data2))
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    # Use a time axis when the first trace knows its sample rate.
    if hasattr(trace1, "metadata") and trace1.metadata.sample_rate is not None:
        time = np.arange(min_len) / trace1.metadata.sample_rate
        xlabel = "Time (s)"
    else:
        time = np.arange(min_len, dtype=np.float64)
        xlabel = "Sample"

    # Caller kwargs override the styling defaults. Previously, passing alpha
    # or linewidth raised "got multiple values for keyword argument".
    plot_kwargs: dict[str, Any] = {"alpha": 0.7, "linewidth": 1.5, **kwargs}
    ax.plot(time, data1, label=labels[0], **plot_kwargs)
    ax.plot(time, data2, label=labels[1], **plot_kwargs)

    # min_len > 0 guard: statistics of an empty array only produce NaN and
    # runtime warnings.
    if highlight_differences and min_len > 0:
        diff = np.abs(data1 - data2)
        if difference_threshold is None:
            # Auto threshold: mean + 2*std of the absolute difference.
            difference_threshold = float(np.mean(diff) + 2 * np.std(diff))

        regions = _contiguous_true_runs(diff > difference_threshold)
        for region_no, (first, last) in enumerate(regions):
            ax.axvspan(
                time[first],
                time[last],
                alpha=0.2,
                color="red",
                # Label only the first span so the legend gets exactly one
                # "Difference" entry, wherever that region starts.
                label="Difference" if region_no == 0 else "_nolegend_",
            )

    ax.set_xlabel(xlabel)
    ax.set_ylabel("Amplitude")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    return fig
def plot_difference(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    title: str = "Trace Comparison - Difference",
    normalize: bool = False,
    show_statistics: bool = True,
    figsize: tuple[float, float] = (10, 6),
    **kwargs: Any,
) -> Figure:
    """Create difference plot (trace1 - trace2).

    Difference trace visualization. Both traces are truncated to the length
    of the shorter one before subtracting.

    Args:
        trace1: First trace
        trace2: Second trace
        title: Plot title
        normalize: Express the difference as a percentage of ``trace2``'s
            peak-to-peak range. This is skipped when that range is zero (or
            the traces are empty); the y-label then reads "Difference", not
            "Difference (%)".
        show_statistics: Show a text box with max/RMS/mean/std of the plotted
            difference. It is omitted for empty traces.
        figsize: Figure size
        **kwargs: Additional arguments passed to ``Axes.plot()``. These may
            override the default ``label="Difference"``.

    Returns:
        Matplotlib Figure object

    Example:
        >>> from tracekit.comparison.visualization import plot_difference
        >>> fig = plot_difference(measured, reference, normalize=True)
        >>> plt.show()

    References:
        CMP-003: Comparison Visualization
    """
    fig, ax = plt.subplots(figsize=figsize)

    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    # Align lengths so the traces can be subtracted sample by sample.
    min_len = min(len(data1), len(data2))
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    diff = data1 - data2

    # The label only claims percent units when the data really was scaled.
    # np.ptp raises on empty input, hence the size guard.
    ylabel = "Difference"
    if normalize and diff.size > 0:
        ref_range = np.ptp(data2)
        if ref_range > 0:
            diff = (diff / ref_range) * 100.0
            ylabel = "Difference (%)"

    # Use a time axis when the first trace knows its sample rate.
    if hasattr(trace1, "metadata") and trace1.metadata.sample_rate is not None:
        time = np.arange(min_len) / trace1.metadata.sample_rate
        xlabel = "Time (s)"
    else:
        time = np.arange(min_len, dtype=np.float64)
        xlabel = "Sample"

    # Caller kwargs may override the default label without a TypeError.
    ax.plot(time, diff, **{"label": "Difference", **kwargs})
    ax.axhline(y=0, color="k", linestyle="--", alpha=0.5, linewidth=1)

    # np.max on an empty array raises, so statistics need at least one sample.
    if show_statistics and diff.size > 0:
        max_diff = float(np.max(np.abs(diff)))
        rms_diff = float(np.sqrt(np.mean(diff**2)))
        mean_diff = float(np.mean(diff))
        std_diff = float(np.std(diff))

        stats_text = (
            f"Max: {max_diff:.3f}\nRMS: {rms_diff:.3f}\nMean: {mean_diff:.3f}\nStd: {std_diff:.3f}"
        )

        ax.text(
            0.02,
            0.98,
            stats_text,
            transform=ax.transAxes,
            verticalalignment="top",
            bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.8},
            fontsize=9,
            family="monospace",
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    return fig
def _windowed_difference_map(
    data1: np.ndarray,
    data2: np.ndarray,
    diff: np.ndarray,
    window_size: int,
    n_bins: int,
) -> np.ndarray:
    """Return an ``(n_bins, n_windows)`` map of average absolute difference.

    Column ``j`` covers samples ``[j * window_size, (j + 1) * window_size)``;
    the final column may be shorter. Within each window, samples are assigned
    to ``n_bins`` equal-width amplitude bins spanning the combined min..max of
    both traces. The bin is chosen by ``data1``'s value. Each sample's ``diff``
    is accumulated in its bin, and the column is divided by its sample count.
    """
    n_samples = len(diff)
    window_starts = range(0, n_samples, window_size)
    heatmap = np.zeros((n_bins, len(window_starts)))
    for col, start in enumerate(window_starts):
        stop = min(start + window_size, n_samples)
        win1 = data1[start:stop]
        win2 = data2[start:stop]
        y_min = min(np.min(win1), np.min(win2))
        y_max = max(np.max(win1), np.max(win2))

        # A zero amplitude range means both windows are the same constant,
        # so the difference is zero and the column stays empty.
        if y_max - y_min > 0:
            edges = np.linspace(y_min, y_max, n_bins + 1)
            bin_idx = np.clip(np.digitize(win1, edges) - 1, 0, n_bins - 1)
            heatmap[:, col] = np.bincount(bin_idx, weights=diff[start:stop], minlength=n_bins)

        heatmap[:, col] /= stop - start
    return heatmap


def plot_comparison_heatmap(
    trace1: WaveformTrace,
    trace2: WaveformTrace,
    *,
    title: str = "Trace Comparison - Difference Heatmap",
    window_size: int = 100,
    figsize: tuple[float, float] = (10, 8),
    n_bins: int = 10,
    **kwargs: Any,
) -> Figure:
    """Create difference heatmap showing where changes occur.

    The top panel bins ``|trace1 - trace2|`` by time window (columns) and by
    the amplitude of ``trace1`` within that window (rows). The bottom panel
    shows the raw absolute difference. The heatmap columns are laid out in the
    same x units as the bottom panel (seconds when ``trace1`` has a sample
    rate, otherwise samples), so the shared x-axis lines the two panels up.
    The final window may be shorter than ``window_size``; it is still shown.

    Args:
        trace1: First trace
        trace2: Second trace
        title: Plot title
        window_size: Number of samples per heatmap column. When the traces
            are shorter than this, a single column covers the whole trace.
        figsize: Figure size
        n_bins: Number of amplitude bins (heatmap rows). Defaults to 10.
        **kwargs: Additional arguments passed to ``imshow()``. They override
            the defaults (``aspect``, ``cmap``, ``origin``, ``interpolation``,
            ``extent``).

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If ``window_size`` or ``n_bins`` is less than 1, or if
            either trace is empty.

    Example:
        >>> from tracekit.comparison.visualization import plot_comparison_heatmap
        >>> fig = plot_comparison_heatmap(trace1, trace2, window_size=50)
        >>> plt.show()

    References:
        CMP-003: Difference heat map showing where changes occur
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    # Align lengths so the traces can be compared sample by sample.
    min_len = min(len(data1), len(data2))
    if min_len == 0:
        raise ValueError("Cannot build a difference heatmap from empty traces")
    data1 = data1[:min_len]
    data2 = data2[:min_len]

    diff = np.abs(data1 - data2)
    heatmap_data = _windowed_difference_map(data1, data2, diff, window_size, n_bins)

    # imshow draws equal-width columns, so the heatmap spans
    # n_columns * samples_per_column samples. This can slightly exceed
    # min_len when the last window is partial.
    samples_per_column = min(window_size, min_len)
    span_samples = heatmap_data.shape[1] * samples_per_column

    if hasattr(trace1, "metadata") and trace1.metadata.sample_rate is not None:
        sample_rate = trace1.metadata.sample_rate
        time = np.arange(min_len) / sample_rate
        x_span = span_samples / sample_rate
        xlabel = "Time (s)"
    else:
        time = np.arange(min_len, dtype=np.float64)
        x_span = float(span_samples)
        xlabel = "Sample"

    # Validation happens first, so no figure is left open on bad input.
    fig = plt.figure(figsize=figsize)
    gs = GridSpec(2, 1, height_ratios=[3, 1], hspace=0.3)
    ax_heat = fig.add_subplot(gs[0])
    ax_trace = fig.add_subplot(gs[1], sharex=ax_heat)

    imshow_kwargs: dict[str, Any] = {
        "aspect": "auto",
        "cmap": "hot",
        "origin": "lower",
        "interpolation": "nearest",
        # Place the columns in the same x units as the difference trace so
        # the shared x-axis aligns both panels. Keep rows centred on the
        # integer bin indices 0..n_bins-1.
        "extent": (0.0, x_span, -0.5, n_bins - 0.5),
        **kwargs,
    }
    im = ax_heat.imshow(heatmap_data, **imshow_kwargs)
    plt.colorbar(im, ax=ax_heat, label="Average Difference")

    ax_heat.set_ylabel("Amplitude Bin")
    ax_heat.set_title(title)

    # Raw absolute difference below the heatmap.
    ax_trace.plot(time, diff, linewidth=0.5)
    ax_trace.set_xlabel(xlabel)
    ax_trace.set_ylabel("Difference")
    ax_trace.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig
def plot_comparison_summary(
    result: ComparisonResult,
    *,
    title: str = "Trace Comparison Summary",
    figsize: tuple[float, float] = (12, 8),
) -> Figure:
    """Create comprehensive comparison summary figure.

    The figure is a 3x2 grid:

    * Top row: a table of key metrics. The match-status cell is coloured
      green on PASS and red on FAIL.
    * Middle row: the difference trace against sample index (only when
      ``result.difference_trace`` is not None).
    * Bottom-left: a histogram of the differences (same condition).
    * Bottom-right: violation sample indices, or a "No Violations" note.

    Rows for ``result.statistics`` entries are included only for keys that
    are present, so a partially populated statistics mapping no longer raises
    ``KeyError``.

    Args:
        result: ComparisonResult from compare_traces()
        title: Plot title
        figsize: Figure size

    Returns:
        Matplotlib Figure object

    Example:
        >>> from tracekit.comparison import compare_traces
        >>> from tracekit.comparison.visualization import plot_comparison_summary
        >>> result = compare_traces(trace1, trace2)
        >>> fig = plot_comparison_summary(result)
        >>> plt.show()

    References:
        CMP-003: Summary table of key differences
    """
    fig = plt.figure(figsize=figsize)
    gs = GridSpec(3, 2, hspace=0.4, wspace=0.3)

    # --- Statistics table -------------------------------------------------
    ax_stats = fig.add_subplot(gs[0, :])
    ax_stats.axis("off")

    stats_data = [
        ["Match Status", "PASS ✓" if result.match else "FAIL ✗"],
        ["Similarity Score", f"{result.similarity:.4f}"],
        ["Correlation", f"{result.correlation:.4f}"],
        ["Max Difference", f"{result.max_difference:.6f}"],
        ["RMS Difference", f"{result.rms_difference:.6f}"],
    ]

    # (row label, statistics key, formatter) for the optional rows.
    optional_rows = (
        ("Mean Difference", "mean_difference", lambda v: f"{v:.6f}"),
        ("Violations", "num_violations", lambda v: f"{v}"),
        ("Violation Rate", "violation_rate", lambda v: f"{v * 100:.2f}%"),
    )
    statistics = result.statistics or {}
    stats_data.extend(
        [label, fmt(statistics[key])] for label, key, fmt in optional_rows if key in statistics
    )

    table = ax_stats.table(
        cellText=stats_data,
        colLabels=["Metric", "Value"],
        cellLoc="left",
        loc="center",
        bbox=[0, 0, 1, 1],  # type: ignore[arg-type]
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Row 0 is the header, so (1, 1) is the "Match Status" value cell.
    table[(1, 1)].set_facecolor("#90EE90" if result.match else "#FFB6C1")

    ax_stats.set_title(title, fontsize=14, fontweight="bold", pad=20)

    # --- Difference trace and its distribution ----------------------------
    diff_trace = result.difference_trace
    if diff_trace is not None:
        diff_data = diff_trace.data

        ax_overlay = fig.add_subplot(gs[1, :])
        ax_overlay.plot(np.arange(len(diff_data)), diff_data, label="Difference")
        ax_overlay.axhline(y=0, color="k", linestyle="--", alpha=0.5)
        ax_overlay.set_xlabel("Sample")
        ax_overlay.set_ylabel("Difference")
        ax_overlay.set_title("Difference Trace")
        ax_overlay.legend()
        ax_overlay.grid(True, alpha=0.3)

        ax_hist = fig.add_subplot(gs[2, 0])
        ax_hist.hist(diff_data, bins=50, edgecolor="black", alpha=0.7)
        ax_hist.axvline(x=0, color="r", linestyle="--", linewidth=2, label="Zero difference")
        ax_hist.set_xlabel("Difference")
        ax_hist.set_ylabel("Count")
        ax_hist.set_title("Difference Distribution")
        ax_hist.legend()
        ax_hist.grid(True, alpha=0.3)

    # --- Violation locations -----------------------------------------------
    ax_viol = fig.add_subplot(gs[2, 1])
    violations = result.violations
    if violations is not None and len(violations) > 0:
        ax_viol.scatter(
            violations,
            np.ones_like(violations),
            marker="|",
            s=100,
            color="red",
            alpha=0.5,
        )
        # Explicit None check, consistent with the branch above. Relying on
        # the trace object's truthiness could misfire if it defines __len__.
        ax_viol.set_xlim(0, len(diff_trace.data) if diff_trace is not None else 1000)
        ax_viol.set_ylim(0.5, 1.5)
        ax_viol.set_xlabel("Sample Index")
        ax_viol.set_title(f"Violation Locations ({len(violations)} total)")
        ax_viol.set_yticks([])
    else:
        ax_viol.text(
            0.5,
            0.5,
            "No Violations",
            ha="center",
            va="center",
            fontsize=14,
            color="green",
        )
        ax_viol.axis("off")

    plt.tight_layout()
    return fig
# Public API of this module: the plotting entry points. Private helpers
# (leading underscore) are intentionally not exported.
__all__ = [
    "plot_comparison_heatmap",
    "plot_comparison_summary",
    "plot_difference",
    "plot_overlay",
]