Coverage for src / tracekit / visualization / spectral.py: 99%

182 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Spectral visualization functions. 

2 

3This module provides spectrum and spectrogram plots for 

4frequency-domain analysis visualization. 

5 

6 

7Example: 

8 >>> from tracekit.visualization.spectral import plot_spectrum, plot_spectrogram 

9 >>> plot_spectrum(trace) 

10 >>> plot_spectrogram(trace) 

11 

12References: 

13 matplotlib best practices for scientific visualization 

14""" 

15 

16from __future__ import annotations 

17 

18from typing import TYPE_CHECKING, Any, Literal, cast 

19 

20import numpy as np 

21 

22try: 

23 import matplotlib.pyplot as plt 

24 from matplotlib.colors import Normalize # noqa: F401 

25 

26 HAS_MATPLOTLIB = True 

27except ImportError: 

28 HAS_MATPLOTLIB = False 

29 

30 

31if TYPE_CHECKING: 

32 from matplotlib.axes import Axes 

33 from matplotlib.figure import Figure 

34 

35 from tracekit.core.types import WaveformTrace 

36 

37 

def plot_spectrum(
    trace: WaveformTrace,
    *,
    ax: Axes | None = None,
    freq_unit: str = "auto",
    db_ref: float | None = None,
    freq_range: tuple[float, float] | None = None,
    show_grid: bool = True,
    color: str = "C0",
    title: str | None = None,
    window: str = "hann",
    xscale: Literal["linear", "log"] = "log",
    show: bool = True,
    save_path: str | None = None,
    figsize: tuple[float, float] = (10, 6),
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    fft_result: tuple[Any, Any] | None = None,
    log_scale: bool = True,
    db_scale: bool | None = None,
) -> Figure:
    """Plot magnitude spectrum.

    Computes the FFT of ``trace`` (unless ``fft_result`` is supplied) and draws
    the magnitude in dB against frequency.

    Args:
        trace: Waveform trace to analyze. Not used for computation when
            ``fft_result`` is given.
        ax: Matplotlib axes. If None, creates new figure.
        freq_unit: Frequency unit ("Hz", "kHz", "MHz", "GHz", "auto"). With
            "auto" the unit is chosen from the last frequency bin.
        db_ref: Value subtracted from the dB magnitudes before plotting. If
            None, magnitudes are plotted exactly as computed/supplied.
        freq_range: Frequency range (min, max) in Hz to display. Takes
            precedence over ``xlim`` when both are given.
        show_grid: Show grid lines.
        color: Line color.
        title: Plot title. Defaults to "Magnitude Spectrum".
        window: Window function for FFT.
        xscale: Deprecated, use ``log_scale`` instead. Passing "linear" still
            forces a linear frequency axis for backward compatibility.
        show: If True, call plt.show() to display the plot.
        save_path: Path to save the figure. If None, figure is not saved.
        figsize: Figure size (width, height) in inches. Only used if ax is None.
        xlim: X-axis limits (min, max) in selected frequency units.
        ylim: Y-axis limits (min, max) in dB. If None and finite data exist,
            the range is capped at 120 dB below the peak with 5 dB headroom.
        fft_result: Pre-computed FFT result (frequencies in Hz, magnitudes in
            dB) as any array-likes. If None, computes FFT.
        log_scale: Use logarithmic scale for frequency axis (default True).
        db_scale: Deprecated alias for log_scale. If provided, overrides log_scale.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``ax`` is not attached to a figure.

    Example:
        >>> import tracekit as tk
        >>> trace = tk.load("signal.wfm")
        >>> fig = tk.plot_spectrum(trace, freq_unit="MHz", log_scale=True)

        >>> # With pre-computed FFT
        >>> freq, mag = tk.fft(trace)
        >>> fig = tk.plot_spectrum(trace, fft_result=(freq, mag), show=False)
        >>> fig.savefig("spectrum.png")
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    # Handle deprecated db_scale parameter
    if db_scale is not None:
        log_scale = db_scale

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig_temp = ax.get_figure()
        if fig_temp is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", fig_temp)

    # Compute FFT if not provided (import lazily, only when actually needed)
    if fft_result is not None:
        freq, mag_db = fft_result
    else:
        from tracekit.analyzers.waveform.spectral import fft

        freq, mag_db = fft(trace, window=window)  # type: ignore[misc]

    # Accept any array-like (e.g. plain lists in fft_result) so the
    # element-wise arithmetic and masking below work.
    freq = np.asarray(freq)
    mag_db = np.asarray(mag_db)

    # Auto-select frequency unit; an empty spectrum falls back to Hz
    if freq_unit == "auto":
        max_freq = freq[-1] if freq.size > 0 else 0.0
        if max_freq >= 1e9:
            freq_unit = "GHz"
        elif max_freq >= 1e6:
            freq_unit = "MHz"
        elif max_freq >= 1e3:
            freq_unit = "kHz"
        else:
            freq_unit = "Hz"

    freq_divisors = {"Hz": 1.0, "kHz": 1e3, "MHz": 1e6, "GHz": 1e9}
    divisor = freq_divisors.get(freq_unit, 1.0)
    freq_scaled = freq / divisor

    # Adjust dB reference if specified
    if db_ref is not None:
        mag_db = mag_db - db_ref

    # Plot
    ax.plot(freq_scaled, mag_db, color=color, linewidth=0.8)

    ax.set_xlabel(f"Frequency ({freq_unit})")
    ax.set_ylabel("Magnitude (dB)")

    # log_scale decides the axis; the deprecated xscale="linear" is still
    # honoured so older callers that relied on it get a linear axis.
    use_log = log_scale and xscale != "linear"
    ax.set_xscale("log" if use_log else "linear")

    ax.set_title(title if title else "Magnitude Spectrum")

    if show_grid:
        ax.grid(True, alpha=0.3, which="both")

    # Set reasonable y-limits: cap the displayed dynamic range at 120 dB
    valid_db = mag_db[np.isfinite(mag_db)]
    if valid_db.size > 0:
        y_max = np.max(valid_db)
        y_min = max(np.min(valid_db), y_max - 120)
        ax.set_ylim(y_min, y_max + 5)

    # Apply custom limits if specified (freq_range is in Hz, xlim in display units)
    if freq_range is not None:
        ax.set_xlim(freq_range[0] / divisor, freq_range[1] / divisor)
    elif xlim is not None:
        ax.set_xlim(xlim)

    if ylim is not None:
        ax.set_ylim(ylim)

    fig.tight_layout()

    # Save if path provided
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    # Show if requested
    if show:
        plt.show()

    return fig

186 

187 

def plot_spectrogram(
    trace: WaveformTrace,
    *,
    ax: Axes | None = None,
    time_unit: str = "auto",
    freq_unit: str = "auto",
    cmap: str = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    title: str | None = None,
    window: str = "hann",
    nperseg: int | None = None,
    nfft: int | None = None,
    overlap: float | None = None,
    figsize: tuple[float, float] = (10, 4),
) -> Figure:
    """Plot spectrogram (time-frequency representation).

    Args:
        trace: Waveform trace to analyze.
        ax: Matplotlib axes. If None, creates new figure.
        time_unit: Time unit ("s", "ms", "us", "ns", "auto").
        freq_unit: Frequency unit ("Hz", "kHz", "MHz", "GHz", "auto").
        cmap: Colormap name.
        vmin: Minimum dB value for color scaling. If None, derived from the
            data but never more than 80 dB below ``vmax``.
        vmax: Maximum dB value for color scaling. If None, the data maximum.
        title: Plot title. Defaults to "Spectrogram".
        window: Window function.
        nperseg: Segment length for STFT.
        nfft: FFT length. If specified, overrides nperseg.
        overlap: Overlap fraction in [0.0, 1.0). Only applied when a segment
            length (``nperseg`` or ``nfft``) is given; otherwise the overlap
            is left to the analyzer's ``spectrogram`` default.
        figsize: Figure size (width, height) in inches. Only used if ax is None.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``ax`` is not attached to a figure, or if ``overlap``
            is outside [0.0, 1.0).

    Example:
        >>> fig = plot_spectrogram(trace)
        >>> plt.show()
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    # Validate before creating a figure so a bad argument leaves no stray figure.
    if overlap is not None and not 0.0 <= overlap < 1.0:
        raise ValueError(f"overlap must be in [0.0, 1.0), got {overlap}")

    from tracekit.analyzers.waveform.spectral import spectrogram

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig_temp = ax.get_figure()
        if fig_temp is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", fig_temp)

    # Handle nfft as alias for nperseg
    if nfft is not None:
        nperseg = nfft

    # Overlap is expressed as a fraction, so it needs a segment length to
    # become a sample count.
    noverlap = None
    if overlap is not None and nperseg is not None:
        noverlap = int(nperseg * overlap)
    times, freq, Sxx_db = spectrogram(trace, window=window, nperseg=nperseg, noverlap=noverlap)

    # Auto-select units
    if time_unit == "auto":
        max_time = times[-1] if len(times) > 0 else 0
        if max_time < 1e-6:
            time_unit = "ns"
        elif max_time < 1e-3:
            time_unit = "us"
        elif max_time < 1:
            time_unit = "ms"
        else:
            time_unit = "s"

    if freq_unit == "auto":
        max_freq = freq[-1] if len(freq) > 0 else 0
        if max_freq >= 1e9:
            freq_unit = "GHz"
        elif max_freq >= 1e6:
            freq_unit = "MHz"
        elif max_freq >= 1e3:
            freq_unit = "kHz"
        else:
            freq_unit = "Hz"

    time_multipliers = {"s": 1.0, "ms": 1e3, "us": 1e6, "ns": 1e9}
    freq_divisors = {"Hz": 1.0, "kHz": 1e3, "MHz": 1e6, "GHz": 1e9}

    time_mult = time_multipliers.get(time_unit, 1.0)
    freq_div = freq_divisors.get(freq_unit, 1.0)

    times_scaled = times * time_mult
    freq_scaled = freq / freq_div

    # Auto color limits: cap the displayed dynamic range at 80 dB
    if vmin is None or vmax is None:
        valid_db = Sxx_db[np.isfinite(Sxx_db)]
        if len(valid_db) > 0:
            if vmax is None:
                vmax = np.max(valid_db)
            if vmin is None:
                vmin = max(np.min(valid_db), vmax - 80)

    # Plot
    pcm = ax.pcolormesh(
        times_scaled,
        freq_scaled,
        Sxx_db,
        shading="auto",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
    )

    ax.set_xlabel(f"Time ({time_unit})")
    ax.set_ylabel(f"Frequency ({freq_unit})")

    ax.set_title(title if title else "Spectrogram")

    # Colorbar
    cbar = fig.colorbar(pcm, ax=ax)
    cbar.set_label("Magnitude (dB)")

    fig.tight_layout()
    return fig

319 

320 

def plot_psd(
    trace: WaveformTrace,
    *,
    ax: Axes | None = None,
    freq_unit: str = "auto",
    show_grid: bool = True,
    color: str = "C0",
    title: str | None = None,
    window: str = "hann",
    xscale: Literal["linear", "log"] = "log",
) -> Figure:
    """Plot Power Spectral Density.

    The PSD is computed with ``tracekit.analyzers.waveform.spectral.psd`` and
    drawn as a single line in dB/Hz against frequency.

    Args:
        trace: Waveform trace to analyze.
        ax: Matplotlib axes. If None, a new 10x4 inch figure is created.
        freq_unit: Frequency unit ("Hz", "kHz", "MHz", "GHz", "auto"). With
            "auto" the unit is chosen from the last frequency bin.
        show_grid: Show grid lines.
        color: Line color.
        title: Plot title. Defaults to "Power Spectral Density".
        window: Window function.
        xscale: X-axis scale ("linear" or "log").

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``ax`` is not attached to a figure.

    Example:
        >>> fig = plot_psd(trace)
        >>> plt.show()
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    from tracekit.analyzers.waveform.spectral import psd

    if ax is not None:
        owner = ax.get_figure()
        if owner is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", owner)
    else:
        fig, ax = plt.subplots(figsize=(10, 4))

    freq, psd_db = psd(trace, window=window)

    # Pick the largest unit whose scale the top bin reaches; default to Hz.
    if freq_unit == "auto":
        top_freq = freq[-1]
        freq_unit = "Hz"
        for threshold, unit_name in ((1e9, "GHz"), (1e6, "MHz"), (1e3, "kHz")):
            if top_freq >= threshold:
                freq_unit = unit_name
                break

    scale = {"Hz": 1.0, "kHz": 1e3, "MHz": 1e6, "GHz": 1e9}.get(freq_unit, 1.0)

    ax.plot(freq / scale, psd_db, color=color, linewidth=0.8)
    ax.set_xlabel(f"Frequency ({freq_unit})")
    ax.set_ylabel("PSD (dB/Hz)")
    ax.set_xscale(xscale)
    ax.set_title(title if title else "Power Spectral Density")

    if show_grid:
        ax.grid(True, alpha=0.3, which="both")

    fig.tight_layout()
    return fig

404 

405 

def plot_fft(
    trace: WaveformTrace,
    *,
    ax: Axes | None = None,
    show: bool = True,
    save_path: str | None = None,
    title: str | None = None,
    xlabel: str = "Frequency",
    ylabel: str = "Magnitude (dB)",
    figsize: tuple[float, float] = (10, 6),
    freq_unit: str = "auto",
    log_scale: bool = True,
    show_grid: bool = True,
    color: str = "C0",
    window: str = "hann",
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
) -> Figure:
    """Plot FFT magnitude spectrum.

    Computes and plots the FFT magnitude spectrum of a waveform trace.
    This is a convenience function that combines FFT computation and visualization.

    Args:
        trace: Waveform trace to analyze and plot.
        ax: Matplotlib axes. If None, creates new figure.
        show: If True, call plt.show() to display the plot.
        save_path: Path to save the figure. If None, figure is not saved.
        title: Plot title. If None, uses default "FFT Magnitude Spectrum".
        xlabel: X-axis label (appended with frequency unit).
        ylabel: Y-axis label.
        figsize: Figure size (width, height) in inches. Only used if ax is None.
        freq_unit: Frequency unit ("Hz", "kHz", "MHz", "GHz", "auto").
        log_scale: Use logarithmic scale for frequency axis.
        show_grid: Show grid lines.
        color: Line color.
        window: Window function for FFT computation.
        xlim: X-axis limits (min, max) in selected frequency units.
        ylim: Y-axis limits (min, max) in dB.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``ax`` is not attached to a figure.

    Example:
        >>> import tracekit as tk
        >>> trace = tk.load("signal.wfm")
        >>> fig = tk.plot_fft(trace, freq_unit="MHz", show=False)
        >>> fig.savefig("spectrum.png")

        >>> # With custom styling
        >>> fig = tk.plot_fft(trace,
        ...     title="Signal FFT",
        ...     log_scale=True,
        ...     xlim=(1e3, 1e6),
        ...     ylim=(-100, 0))

    References:
        IEEE 1241-2010: Standard for Terminology and Test Methods for
        Analog-to-Digital Converters
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig_temp = ax.get_figure()
        if fig_temp is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", fig_temp)

    # Delegate the drawing to plot_spectrum.
    # - show=False: plot_spectrum defaults to show=True, which would call
    #   plt.show() before our labels/limits are applied and regardless of
    #   this function's own ``show`` argument.
    # - log_scale is forwarded directly: plot_spectrum decides the axis
    #   scale from log_scale, so passing only the deprecated xscale would
    #   leave log_scale=False ineffective.
    plot_spectrum(
        trace,
        ax=ax,
        freq_unit=freq_unit,
        show_grid=show_grid,
        color=color,
        title=title if title else "FFT Magnitude Spectrum",
        window=window,
        log_scale=log_scale,
        show=False,
    )

    # Apply custom labels if different from defaults
    if xlabel != "Frequency":
        # Reuse the "(unit)" suffix plot_spectrum wrote into the label
        current_label = ax.get_xlabel()
        if "(" in current_label and ")" in current_label:
            unit = current_label[current_label.find("(") : current_label.find(")") + 1]
            ax.set_xlabel(f"{xlabel} {unit}")
        else:
            ax.set_xlabel(xlabel)

    if ylabel != "Magnitude (dB)":
        ax.set_ylabel(ylabel)

    # Apply custom limits if specified
    if xlim is not None:
        ax.set_xlim(xlim)

    if ylim is not None:
        ax.set_ylim(ylim)

    # Save if path provided
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    # Show if requested
    if show:
        plt.show()

    return fig

524 

525 

# Public plotting API of this module (kept alphabetical).
__all__ = ["plot_fft", "plot_psd", "plot_spectrogram", "plot_spectrum"]