Coverage for src / tracekit / visualization / waveform.py: 100%
131 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Waveform visualization functions.
3This module provides time-domain waveform and multi-channel plots
4with measurement annotations.
7Example:
8 >>> from tracekit.visualization.waveform import plot_waveform, plot_multi_channel
9 >>> plot_waveform(trace)
10 >>> plot_multi_channel([ch1, ch2, ch3])
12References:
13 matplotlib best practices for scientific visualization
14"""
16from __future__ import annotations
18from typing import TYPE_CHECKING, Any, cast
20import numpy as np
22try:
23 import matplotlib.pyplot as plt
25 HAS_MATPLOTLIB = True
26except ImportError:
27 HAS_MATPLOTLIB = False
29from tracekit.core.types import DigitalTrace, WaveformTrace
31if TYPE_CHECKING:
32 from matplotlib.axes import Axes
33 from matplotlib.figure import Figure
34 from numpy.typing import NDArray
def plot_waveform(
    trace: WaveformTrace,
    *,
    ax: Axes | None = None,
    time_unit: str = "auto",
    time_range: tuple[float, float] | None = None,
    show_grid: bool = True,
    color: str = "C0",
    label: str | None = None,
    show_measurements: dict[str, Any] | None = None,
    title: str | None = None,
    xlabel: str = "Time",
    ylabel: str = "Amplitude",
    show: bool = True,
    save_path: str | None = None,
    figsize: tuple[float, float] = (10, 6),
) -> Figure:
    """Plot time-domain waveform.

    Args:
        trace: Waveform trace to plot.
        ax: Matplotlib axes. If None, creates new figure.
        time_unit: Time unit ("s", "ms", "us", "ns", "auto"). With "auto",
            the unit is chosen from the largest absolute time in the trace so
            that plotted time values stay below 1000 of the chosen unit.
        time_range: Optional (start, end) time range in seconds to display.
        show_grid: Show grid lines.
        color: Line color.
        label: Legend label.
        show_measurements: Dictionary of measurements to annotate.
        title: Plot title. If omitted, the trace's channel name (when set) is
            used to build a default title.
        xlabel: X-axis label (appended with time unit).
        ylabel: Y-axis label.
        show: If True, call plt.show() to display the plot.
        save_path: Path to save the figure. If None, figure is not saved.
        figsize: Figure size (width, height) in inches. Only used if ax is None.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``time_unit`` is not one of the supported values, or
            if ``ax`` has no associated figure.

    Example:
        >>> import tracekit as tk
        >>> trace = tk.load("signal.wfm")
        >>> fig = tk.plot_waveform(trace, time_unit="us", show=False)
        >>> fig.savefig("waveform.png")

        >>> # With custom styling
        >>> fig = tk.plot_waveform(trace,
        ...     title="Captured Signal",
        ...     xlabel="Time",
        ...     ylabel="Voltage",
        ...     color="blue")
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    time_multipliers = {"s": 1.0, "ms": 1e3, "us": 1e6, "ns": 1e9}
    # Reject unknown units up front: previously they silently plotted in
    # seconds while the axis label claimed the requested unit. Checking before
    # creating a figure avoids leaving an orphaned figure behind.
    if time_unit != "auto" and time_unit not in time_multipliers:
        raise ValueError(
            f"Unknown time_unit {time_unit!r}; expected 's', 'ms', 'us', 'ns', or 'auto'"
        )

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig_temp = ax.get_figure()
        if fig_temp is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", fig_temp)

    # Calculate time axis
    time = trace.time_vector

    # Auto-select time unit. The largest absolute time (rather than the last
    # sample) governs how large the tick values get; it equals time[-1] for a
    # 0-based increasing axis and stays sensible if the axis has negative or
    # offset times.
    if time_unit == "auto":
        extent = float(np.max(np.abs(time))) if len(time) > 0 else 0.0
        if extent < 1e-6:
            time_unit = "ns"
        elif extent < 1e-3:
            time_unit = "us"
        elif extent < 1:
            time_unit = "ms"
        else:
            time_unit = "s"

    multiplier = time_multipliers[time_unit]
    time_scaled = time * multiplier

    # Plot waveform
    ax.plot(time_scaled, trace.data, color=color, label=label, linewidth=0.8)

    # Apply time range if specified (given in seconds, so scale to display unit)
    if time_range is not None:
        ax.set_xlim(time_range[0] * multiplier, time_range[1] * multiplier)

    # Labels
    ax.set_xlabel(f"{xlabel} ({time_unit})")
    ax.set_ylabel(ylabel)

    if title:
        ax.set_title(title)
    elif trace.metadata.channel_name:
        ax.set_title(f"Waveform - {trace.metadata.channel_name}")

    if show_grid:
        ax.grid(True, alpha=0.3)

    if label:
        ax.legend()

    # Add measurement annotations
    if show_measurements:
        _add_measurement_annotations(ax, trace, show_measurements, time_unit, multiplier)

    fig.tight_layout()

    # Save if path provided
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    # Show if requested
    if show:
        plt.show()

    return fig
def plot_multi_channel(
    traces: list[WaveformTrace | DigitalTrace],
    *,
    names: list[str] | None = None,
    shared_x: bool = True,
    share_x: bool | None = None,
    colors: list[str] | None = None,
    time_unit: str = "auto",
    show_grid: bool = True,
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
) -> Figure:
    """Plot multiple channels in stacked subplots.

    Analog (``WaveformTrace``) channels are drawn as lines; every other trace
    is drawn as a digital step plot with "L"/"H" y-tick labels.

    Args:
        traces: List of traces to plot (must not be empty).
        names: Channel names for labels. Missing entries (when shorter than
            ``traces``) default to "CH1", "CH2", ...
        shared_x: Share x-axis across subplots.
        share_x: Alias for shared_x (for compatibility); overrides it when set.
        colors: List of colors for each trace. If None (or too short), the
            default "C<i>" cycle color is used.
        time_unit: Time unit ("s", "ms", "us", "ns", "auto"). With "auto", the
            unit is chosen from the first trace's largest absolute time.
        show_grid: Show grid lines.
        figsize: Figure size (width, height) in inches. Defaults to
            (10, 2 * number of channels).
        title: Overall figure title.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If ``traces`` is empty or ``time_unit`` is not supported.

    Example:
        >>> fig = plot_multi_channel([ch1, ch2, ch3], names=["CLK", "DATA", "CS"])
        >>> plt.show()
    """
    # Handle share_x alias
    if share_x is not None:
        shared_x = share_x
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    n_channels = len(traces)
    if n_channels == 0:
        raise ValueError("traces must contain at least one trace")

    time_multipliers = {"s": 1.0, "ms": 1e3, "us": 1e6, "ns": 1e9}
    # Unknown units used to plot in seconds under a mislabeled axis.
    if time_unit != "auto" and time_unit not in time_multipliers:
        raise ValueError(
            f"Unknown time_unit {time_unit!r}; expected 's', 'ms', 'us', 'ns', or 'auto'"
        )

    # Pad missing names with defaults instead of letting zip() silently drop
    # the extra traces. Build a new list so the caller's list is not mutated.
    labels = list(names) if names is not None else []
    labels += [f"CH{i + 1}" for i in range(len(labels), n_channels)]

    if figsize is None:
        figsize = (10, 2 * n_channels)

    # squeeze=False always returns a 2-D array, so one channel needs no
    # special-casing; take the single column of axes.
    fig, axes_grid = plt.subplots(
        n_channels,
        1,
        figsize=figsize,
        sharex=shared_x,
        squeeze=False,
    )
    axes = axes_grid[:, 0]

    # Auto-select time unit from the first trace, using the same rule as
    # plot_waveform (largest absolute time on the plotted time axis).
    if time_unit == "auto":
        ref_time = traces[0].time_vector
        extent = float(np.max(np.abs(ref_time))) if len(ref_time) > 0 else 0.0
        if extent < 1e-6:
            time_unit = "ns"
        elif extent < 1e-3:
            time_unit = "us"
        elif extent < 1:
            time_unit = "ms"
        else:
            time_unit = "s"

    multiplier = time_multipliers[time_unit]

    for i, (trace, name, ax) in enumerate(zip(traces, labels, axes, strict=False)):
        time = trace.time_vector * multiplier
        color = colors[i] if colors is not None and i < len(colors) else f"C{i}"

        if isinstance(trace, WaveformTrace):
            ax.plot(time, trace.data, color=color, linewidth=0.8)
        else:
            # Digital trace - step plot with fixed logic-level ticks
            ax.step(time, trace.data.astype(int), color=color, where="post", linewidth=1.0)
            ax.set_ylim(-0.1, 1.1)
            ax.set_yticks([0, 1])
            ax.set_yticklabels(["L", "H"])

        # The channel name is the y-label for both analog and digital rows.
        ax.set_ylabel(name, rotation=0, ha="right", va="center")

        if show_grid:
            ax.grid(True, alpha=0.3)

        # Only show x-label on bottom plot
        if i == n_channels - 1:
            ax.set_xlabel(f"Time ({time_unit})")

    if title:
        fig.suptitle(title)

    fig.tight_layout()
    return fig
def plot_xy(
    x_trace: WaveformTrace | NDArray[np.float64],
    y_trace: WaveformTrace | NDArray[np.float64],
    *,
    ax: Axes | None = None,
    color: str = "C0",
    marker: str = "",
    alpha: float = 0.7,
    title: str | None = None,
) -> Figure:
    """Draw an X-Y (Lissajous) plot of one signal against another.

    Each input may be a ``WaveformTrace`` (its ``.data`` samples are used) or
    a raw sample array. When the two inputs differ in length, both are cut to
    the shorter length so samples pair up one-to-one.

    Args:
        x_trace: Signal plotted on the horizontal axis.
        y_trace: Signal plotted on the vertical axis.
        ax: Existing axes to draw into; a new 6x6-inch figure is created if None.
        color: Color for the line and any markers.
        marker: Matplotlib marker style ("" draws no markers).
        alpha: Line opacity.
        title: Optional axes title.

    Returns:
        The Matplotlib Figure that owns the axes.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If the supplied axes is not attached to a figure.

    Example:
        >>> fig = plot_xy(ch1, ch2)  # Phase relationship
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    if ax is not None:
        owner = ax.get_figure()
        if owner is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", owner)
    else:
        fig, ax = plt.subplots(figsize=(6, 6))

    def _samples(source: Any) -> Any:
        # Unwrap a WaveformTrace to its sample array; pass arrays through.
        return source.data if isinstance(source, WaveformTrace) else source

    xs = _samples(x_trace)
    ys = _samples(y_trace)
    n_points = min(len(xs), len(ys))

    ax.plot(
        xs[:n_points],
        ys[:n_points],
        color=color,
        marker=marker,
        alpha=alpha,
        linewidth=0.5,
    )

    ax.set_xlabel("X (V)")
    ax.set_ylabel("Y (V)")
    # Equal aspect so a circle-shaped Lissajous figure is not distorted.
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)

    if title:
        ax.set_title(title)

    fig.tight_layout()
    return fig
330def _add_measurement_annotations(
331 ax: Axes,
332 trace: WaveformTrace,
333 measurements: dict[str, Any],
334 time_unit: str,
335 multiplier: float,
336) -> None:
337 """Add measurement annotations to plot."""
338 # Create annotation text
339 text_lines = []
341 for name, value in measurements.items():
342 if isinstance(value, dict):
343 val = value.get("value", value)
344 unit = value.get("unit", "")
345 if isinstance(val, float) and not np.isnan(val):
346 text_lines.append(f"{name}: {val:.4g} {unit}")
347 elif isinstance(value, float) and not np.isnan(value):
348 text_lines.append(f"{name}: {value:.4g}")
350 if text_lines:
351 text = "\n".join(text_lines)
352 ax.annotate(
353 text,
354 xy=(0.02, 0.98),
355 xycoords="axes fraction",
356 verticalalignment="top",
357 fontfamily="monospace",
358 fontsize=8,
359 bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.8},
360 )
# Public plotting API of this module. The helper
# ``_add_measurement_annotations`` is not listed, so it is not exported by
# ``from ... import *``.
__all__ = [
    "plot_multi_channel",
    "plot_waveform",
    "plot_xy",
]