Coverage for src / tracekit / visualization / waveform.py: 100%

131 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Waveform visualization functions. 

2 

3This module provides time-domain waveform and multi-channel plots 

4with measurement annotations. 

5 

6 

7Example: 

8 >>> from tracekit.visualization.waveform import plot_waveform, plot_multi_channel 

9 >>> plot_waveform(trace) 

10 >>> plot_multi_channel([ch1, ch2, ch3]) 

11 

12References: 

13 matplotlib best practices for scientific visualization 

14""" 

15 

16from __future__ import annotations 

17 

18from typing import TYPE_CHECKING, Any, cast 

19 

20import numpy as np 

21 

22try: 

23 import matplotlib.pyplot as plt 

24 

25 HAS_MATPLOTLIB = True 

26except ImportError: 

27 HAS_MATPLOTLIB = False 

28 

29from tracekit.core.types import DigitalTrace, WaveformTrace 

30 

31if TYPE_CHECKING: 

32 from matplotlib.axes import Axes 

33 from matplotlib.figure import Figure 

34 from numpy.typing import NDArray 

35 

36 

def plot_waveform(
    trace: WaveformTrace,
    *,
    ax: Axes | None = None,
    time_unit: str = "auto",
    time_range: tuple[float, float] | None = None,
    show_grid: bool = True,
    color: str = "C0",
    label: str | None = None,
    show_measurements: dict[str, Any] | None = None,
    title: str | None = None,
    xlabel: str = "Time",
    ylabel: str = "Amplitude",
    show: bool = True,
    save_path: str | None = None,
    figsize: tuple[float, float] = (10, 6),
) -> Figure:
    """Plot time-domain waveform.

    Args:
        trace: Waveform trace to plot.
        ax: Matplotlib axes. If None, creates new figure.
        time_unit: Time unit ("s", "ms", "us", "ns", "auto"). "auto" picks the
            unit from the largest absolute time value on the trace's time
            axis. An unrecognized unit string is shown verbatim in the x-axis
            label but applies no scaling (times stay in seconds).
        time_range: Optional (start, end) time range in seconds to display.
        show_grid: Show grid lines.
        color: Line color.
        label: Legend label. A legend is drawn only when this is truthy.
        show_measurements: Dictionary of measurements to annotate.
        title: Plot title. If not given, the trace's channel name (when set)
            is used to build a default title.
        xlabel: X-axis label (appended with time unit).
        ylabel: Y-axis label.
        show: If True, call plt.show() to display the plot.
        save_path: Path to save the figure (300 dpi, tight bounding box).
            If None, figure is not saved.
        figsize: Figure size (width, height) in inches. Only used if ax is None.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If axes has no associated figure.

    Example:
        >>> import tracekit as tk
        >>> trace = tk.load("signal.wfm")
        >>> fig = tk.plot_waveform(trace, time_unit="us", show=False)
        >>> fig.savefig("waveform.png")

        >>> # With custom styling
        >>> fig = tk.plot_waveform(trace,
        ...                        title="Captured Signal",
        ...                        xlabel="Time",
        ...                        ylabel="Voltage",
        ...                        color="blue")
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig_temp = ax.get_figure()
        if fig_temp is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", fig_temp)

    # Time axis in seconds, as provided by the trace.
    time = trace.time_vector

    if time_unit == "auto":
        # Use the largest |t| on the axis rather than only the last sample, so
        # that an axis with a negative start (e.g. if the time vector carries a
        # pre-trigger offset — TODO confirm this can occur) still gets a
        # sensible unit. For an axis starting at 0 this equals the last value.
        extent = float(np.max(np.abs(time))) if len(time) > 0 else 0.0
        if extent < 1e-6:
            time_unit = "ns"
        elif extent < 1e-3:
            time_unit = "us"
        elif extent < 1:
            time_unit = "ms"
        else:
            time_unit = "s"

    time_multipliers = {"s": 1.0, "ms": 1e3, "us": 1e6, "ns": 1e9}
    # Unknown units fall back to no scaling (seconds).
    multiplier = time_multipliers.get(time_unit, 1.0)
    time_scaled = time * multiplier

    ax.plot(time_scaled, trace.data, color=color, label=label, linewidth=0.8)

    # time_range is given in seconds; convert to the displayed unit.
    if time_range is not None:
        ax.set_xlim(time_range[0] * multiplier, time_range[1] * multiplier)

    ax.set_xlabel(f"{xlabel} ({time_unit})")
    ax.set_ylabel(ylabel)

    if title:
        ax.set_title(title)
    elif trace.metadata.channel_name:
        ax.set_title(f"Waveform - {trace.metadata.channel_name}")

    if show_grid:
        ax.grid(True, alpha=0.3)

    if label:
        ax.legend()

    if show_measurements:
        _add_measurement_annotations(ax, trace, show_measurements, time_unit, multiplier)

    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    if show:
        plt.show()

    return fig

159 

160 

def plot_multi_channel(
    traces: list[WaveformTrace | DigitalTrace],
    *,
    names: list[str] | None = None,
    shared_x: bool = True,
    share_x: bool | None = None,
    colors: list[str] | None = None,
    time_unit: str = "auto",
    show_grid: bool = True,
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
) -> Figure:
    """Plot multiple channels in stacked subplots.

    ``WaveformTrace`` channels are drawn as lines; any other trace is drawn as
    a post-step plot on a 0/1 axis with "L"/"H" tick labels.

    Args:
        traces: List of traces to plot. Must not be empty.
        names: Channel names for labels. If None, or shorter than ``traces``,
            the missing entries default to "CH<n>" (1-based) so that every
            trace is still plotted. Extra names are ignored.
        shared_x: Share x-axis across subplots.
        share_x: Alias for shared_x (for compatibility). Takes precedence
            when not None.
        colors: List of colors for each trace. If None (or too short), uses
            the default "C<i>" cycle for the remaining traces.
        time_unit: Time unit ("s", "ms", "us", "ns", "auto"). "auto" picks the
            unit from the largest absolute time value of the first trace.
        show_grid: Show grid lines.
        figsize: Figure size (width, height) in inches. Defaults to
            (10, 2 * number of traces).
        title: Overall figure title.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If ``traces`` is empty.

    Example:
        >>> fig = plot_multi_channel([ch1, ch2, ch3], names=["CLK", "DATA", "CS"])
        >>> plt.show()
    """
    # Handle share_x alias
    if share_x is not None:
        shared_x = share_x
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    n_channels = len(traces)
    if n_channels == 0:
        # Fail before creating a figure; matplotlib would otherwise raise a
        # less descriptive ValueError for a 0-row subplot grid.
        raise ValueError("plot_multi_channel requires at least one trace")

    # Pad (rather than truncate) the names so the zip below never drops traces.
    given_names = list(names) if names is not None else []
    labels = [
        given_names[i] if i < len(given_names) else f"CH{i + 1}"
        for i in range(n_channels)
    ]

    if figsize is None:
        figsize = (10, 2 * n_channels)

    # squeeze=False always yields a 2-D array, so one channel needs no special case.
    fig, axes_grid = plt.subplots(
        n_channels,
        1,
        figsize=figsize,
        sharex=shared_x,
        squeeze=False,
    )
    axes = axes_grid[:, 0]

    # Auto-select time unit from the first trace's actual time axis.
    if time_unit == "auto":
        ref_time = traces[0].time_vector
        extent = float(np.max(np.abs(ref_time))) if len(ref_time) > 0 else 0.0
        if extent < 1e-6:
            time_unit = "ns"
        elif extent < 1e-3:
            time_unit = "us"
        elif extent < 1:
            time_unit = "ms"
        else:
            time_unit = "s"

    time_multipliers = {"s": 1.0, "ms": 1e3, "us": 1e6, "ns": 1e9}
    # Unknown units fall back to no scaling (seconds).
    multiplier = time_multipliers.get(time_unit, 1.0)

    for i, (trace, name, ax) in enumerate(zip(traces, labels, axes, strict=True)):
        time = trace.time_vector * multiplier
        color = colors[i] if colors is not None and i < len(colors) else f"C{i}"

        if isinstance(trace, WaveformTrace):
            ax.plot(time, trace.data, color=color, linewidth=0.8)
        else:
            # Digital trace - step plot
            ax.step(time, trace.data.astype(int), color=color, where="post", linewidth=1.0)
            ax.set_ylim(-0.1, 1.1)
            ax.set_yticks([0, 1])
            ax.set_yticklabels(["L", "H"])

        ax.set_ylabel(name, rotation=0, ha="right", va="center")

        if show_grid:
            ax.grid(True, alpha=0.3)

        # Only show x-label on bottom plot
        if i == n_channels - 1:
            ax.set_xlabel(f"Time ({time_unit})")

    if title:
        fig.suptitle(title)

    fig.tight_layout()
    return fig

264 

265 

def plot_xy(
    x_trace: WaveformTrace | NDArray[np.float64],
    y_trace: WaveformTrace | NDArray[np.float64],
    *,
    ax: Axes | None = None,
    color: str = "C0",
    marker: str = "",
    alpha: float = 0.7,
    title: str | None = None,
) -> Figure:
    """Draw one signal against another as an X-Y (Lissajous) plot.

    Each input may be a ``WaveformTrace`` (its ``.data`` samples are used) or
    a raw sample array. If the two inputs differ in length, both are cut to
    the shorter length before plotting. The axes get an equal aspect ratio,
    a light grid, and "X (V)" / "Y (V)" labels.

    Args:
        x_trace: Samples for the horizontal axis.
        y_trace: Samples for the vertical axis.
        ax: Axes to draw into. When omitted, a new 6x6 inch figure is made.
        color: Line/marker color.
        marker: Marker style ("" draws no markers).
        alpha: Transparency of the curve.
        title: Optional axes title.

    Returns:
        The Matplotlib Figure that owns the axes.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If the supplied axes are not attached to a figure.

    Example:
        >>> fig = plot_xy(ch1, ch2)  # Phase relationship
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    if ax is not None:
        owner = ax.get_figure()
        if owner is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", owner)
    else:
        fig, ax = plt.subplots(figsize=(6, 6))

    def _samples_of(source):
        # Unwrap traces; pass raw arrays through untouched.
        if isinstance(source, WaveformTrace):
            return source.data
        return source

    xs = _samples_of(x_trace)
    ys = _samples_of(y_trace)

    # Trim both series to a common length so they can be paired point-wise.
    common = min(len(xs), len(ys))
    ax.plot(xs[:common], ys[:common], color=color, marker=marker, alpha=alpha, linewidth=0.5)

    ax.set_xlabel("X (V)")
    ax.set_ylabel("Y (V)")
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)

    if title:
        ax.set_title(title)

    fig.tight_layout()
    return fig

328 

329 

330def _add_measurement_annotations( 

331 ax: Axes, 

332 trace: WaveformTrace, 

333 measurements: dict[str, Any], 

334 time_unit: str, 

335 multiplier: float, 

336) -> None: 

337 """Add measurement annotations to plot.""" 

338 # Create annotation text 

339 text_lines = [] 

340 

341 for name, value in measurements.items(): 

342 if isinstance(value, dict): 

343 val = value.get("value", value) 

344 unit = value.get("unit", "") 

345 if isinstance(val, float) and not np.isnan(val): 

346 text_lines.append(f"{name}: {val:.4g} {unit}") 

347 elif isinstance(value, float) and not np.isnan(value): 

348 text_lines.append(f"{name}: {value:.4g}") 

349 

350 if text_lines: 

351 text = "\n".join(text_lines) 

352 ax.annotate( 

353 text, 

354 xy=(0.02, 0.98), 

355 xycoords="axes fraction", 

356 verticalalignment="top", 

357 fontfamily="monospace", 

358 fontsize=8, 

359 bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.8}, 

360 ) 

361 

362 

# Public plotting API; these are the names exported by ``from <module> import *``.
__all__ = ["plot_multi_channel", "plot_waveform", "plot_xy"]