Coverage for src / tracekit / visualization / interactive.py: 98%

303 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Interactive visualization features. 

2 

3This module provides interactive plotting capabilities including zoom, 

4pan, cursors, and specialized plot types. 

5 

6 

7Example: 

8 >>> from tracekit.visualization.interactive import ( 

9 ... plot_with_cursors, plot_phase, plot_bode, 

10 ... plot_waterfall, plot_histogram 

11 ... ) 

12 >>> fig, ax = plot_with_cursors(trace) 

13 >>> plot_bode(frequencies, magnitude, phase) 

14 

15References: 

16 matplotlib interactive features 

17""" 

18 

19from __future__ import annotations 

20 

21from dataclasses import dataclass, field 

22from typing import TYPE_CHECKING, Any, Literal, cast 

23 

24import numpy as np 

25from scipy import signal as scipy_signal 

26 

27if TYPE_CHECKING: 

28 from matplotlib.axes import Axes 

29 from matplotlib.backend_bases import MouseEvent 

30 from matplotlib.figure import Figure 

31 from numpy.typing import NDArray 

32 

33from tracekit.core.types import WaveformTrace 

34 

35# Optional matplotlib import 

36try: 

37 import matplotlib.pyplot as plt 

38 from matplotlib.widgets import Cursor, MultiCursor, SpanSelector # noqa: F401 

39 

40 MATPLOTLIB_AVAILABLE = True 

41except ImportError: 

42 MATPLOTLIB_AVAILABLE = False 

43 

44 

@dataclass
class CursorMeasurement:
    """Values read from a pair of measurement cursors.

    Attributes:
        x1: X position of the first cursor.
        x2: X position of the second cursor.
        y1: Y value associated with the first cursor.
        y2: Y value associated with the second cursor.
        delta_x: Horizontal distance, ``x2 - x1``.
        delta_y: Vertical distance, ``y2 - y1``.
        frequency: ``1 / delta_x`` when ``delta_x > 0``; None by default.
        slope: ``delta_y / delta_x`` when ``delta_x != 0``; None by default.

    References:
        VIS-008
    """

    # Cursor coordinates.
    x1: float
    x2: float
    y1: float
    y2: float
    # Differences between the two cursors.
    delta_x: float
    delta_y: float
    # Derived quantities; optional because they are undefined for some deltas.
    frequency: float | None = field(default=None)
    slope: float | None = field(default=None)

71 

72 

@dataclass
class ZoomState:
    """Mutable record of an axes' zoom/pan view.

    Attributes:
        xlim: X-axis limits currently applied.
        ylim: Y-axis limits currently applied.
        history: Stack of earlier ``(xlim, ylim)`` pairs; a fresh empty
            list per instance.
        home_xlim: X-axis limits to return "home" to; None by default.
        home_ylim: Y-axis limits to return "home" to; None by default.

    References:
        VIS-007
    """

    # Active view limits.
    xlim: tuple[float, float]
    ylim: tuple[float, float]
    # default_factory gives every instance its own list (no shared default).
    history: list[tuple[tuple[float, float], tuple[float, float]]] = field(
        default_factory=list
    )
    # Original view limits, if recorded.
    home_xlim: tuple[float, float] | None = field(default=None)
    home_ylim: tuple[float, float] | None = field(default=None)

93 

94 

def enable_zoom_pan(
    ax: Axes,
    *,
    enable_zoom: bool = True,
    enable_pan: bool = True,
    zoom_factor: float = 1.5,
) -> ZoomState:
    """Enable interactive zoom and pan on an axes.

    Scrolling up zooms in by ``zoom_factor`` and scrolling down zooms out,
    keeping the data point under the mouse fixed. Dragging with the left
    mouse button pans the view.

    Args:
        ax: Matplotlib axes to enable zoom/pan on.
        enable_zoom: Enable scroll wheel zoom.
        enable_pan: Enable click-drag pan.
        zoom_factor: Zoom factor per scroll step.

    Returns:
        ZoomState tracking the current limits. ``home_xlim``/``home_ylim``
        hold the limits at the time of the call, and ``history`` receives
        the axes limits that were in effect just before each scroll-zoom step.

    Raises:
        ImportError: If matplotlib is not available.

    Example:
        >>> fig, ax = plt.subplots()
        >>> ax.plot(trace.time_vector, trace.data)
        >>> state = enable_zoom_pan(ax)

    References:
        VIS-007
    """
    if not MATPLOTLIB_AVAILABLE:
        raise ImportError("matplotlib is required for interactive visualization")

    # Store initial state; it doubles as the "home" view.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    state = ZoomState(
        xlim=xlim,
        ylim=ylim,
        home_xlim=xlim,
        home_ylim=ylim,
    )

    def on_scroll(event):  # type: ignore[no-untyped-def]
        if event.inaxes != ax:
            return

        # Data coordinates under the mouse; None when they are unavailable.
        x_data = event.xdata
        y_data = event.ydata
        if x_data is None or y_data is None:
            return

        # Scroll up shrinks the visible span (zoom in), scroll down grows it.
        if event.button == "up":
            factor = 1 / zoom_factor
        elif event.button == "down":
            factor = zoom_factor
        else:
            return

        # Use the live limits, not state.xlim/state.ylim: the view may have
        # been changed outside these handlers (toolbar, set_xlim), in which
        # case the cached values are stale.
        cur_xlim = ax.get_xlim()
        cur_ylim = ax.get_ylim()
        width = cur_xlim[1] - cur_xlim[0]
        height = cur_ylim[1] - cur_ylim[0]
        if width == 0 or height == 0:
            # Zero-span view: the mouse's relative position is undefined.
            return

        # Push the limits being replaced onto the history stack.
        state.history.append((cur_xlim, cur_ylim))

        # Keep the fraction of the span on each side of the mouse constant
        # so the point under the mouse stays put while zooming.
        rel_x = (x_data - cur_xlim[0]) / width
        rel_y = (y_data - cur_ylim[0]) / height
        new_width = width * factor
        new_height = height * factor

        new_xlim = (
            x_data - new_width * rel_x,
            x_data + new_width * (1 - rel_x),
        )
        new_ylim = (
            y_data - new_height * rel_y,
            y_data + new_height * (1 - rel_y),
        )

        ax.set_xlim(new_xlim)
        ax.set_ylim(new_ylim)
        state.xlim = new_xlim
        state.ylim = new_ylim

        ax.figure.canvas.draw_idle()

    if enable_zoom:
        ax.figure.canvas.mpl_connect("scroll_event", on_scroll)

    # Pan state is held in one-element lists so the nested handlers can
    # mutate it without ``nonlocal``.
    pan_active = [False]
    pan_start: list[float | None] = [None, None]

    def on_press(event):  # type: ignore[no-untyped-def]
        if event.inaxes != ax:
            return
        if event.button == 1:  # Left click starts a drag.
            pan_active[0] = True
            pan_start[0] = event.xdata
            pan_start[1] = event.ydata

    def on_release(event: MouseEvent) -> None:
        pan_active[0] = False

    def on_motion(event: MouseEvent) -> None:
        if not pan_active[0]:
            return
        if event.inaxes != ax:
            return
        if event.xdata is None or event.ydata is None:
            return
        if pan_start[0] is None or pan_start[1] is None:
            return

        # Shift the view by the offset between the grab point and the mouse
        # so the grabbed data point follows the cursor.
        dx = pan_start[0] - event.xdata
        dy = pan_start[1] - event.ydata

        cur_xlim = ax.get_xlim()
        cur_ylim = ax.get_ylim()

        new_xlim = (cur_xlim[0] + dx, cur_xlim[1] + dx)
        new_ylim = (cur_ylim[0] + dy, cur_ylim[1] + dy)

        ax.set_xlim(new_xlim)
        ax.set_ylim(new_ylim)
        state.xlim = new_xlim
        state.ylim = new_ylim

        ax.figure.canvas.draw_idle()

    if enable_pan:
        ax.figure.canvas.mpl_connect("button_press_event", on_press)
        ax.figure.canvas.mpl_connect("button_release_event", on_release)  # type: ignore[arg-type]
        ax.figure.canvas.mpl_connect("motion_notify_event", on_motion)  # type: ignore[arg-type]

    return state

237 

238 

def plot_with_cursors(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    sample_rate: float | None = None,
    cursor_type: Literal["vertical", "horizontal", "cross"] = "cross",
    ax: Axes | None = None,
    **plot_kwargs: Any,
) -> tuple[Figure, Axes, Cursor]:
    """Plot a waveform and attach a crosshair-style cursor widget.

    Args:
        trace: WaveformTrace, or a raw numpy array of samples.
        sample_rate: Sample rate used to build the time axis for array
            input; treated as 1.0 when omitted. Ignored for WaveformTrace
            input, which supplies its own time vector.
        cursor_type: ``"vertical"`` shows only the vertical line,
            ``"horizontal"`` only the horizontal line; any other value
            (the default ``"cross"``) shows both.
        ax: Axes to draw into; a new 10x6 figure is created when omitted.
        **plot_kwargs: Forwarded unchanged to ``ax.plot``.

    Returns:
        ``(figure, axes, cursor)``.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If ``ax`` is given but has no figure.

    Example:
        >>> fig, ax, cursor = plot_with_cursors(trace)
        >>> plt.show()

    References:
        VIS-008
    """
    if not MATPLOTLIB_AVAILABLE:
        raise ImportError("matplotlib is required for interactive visualization")

    # Resolve samples and their timestamps.
    if isinstance(trace, WaveformTrace):
        samples, t = trace.data, trace.time_vector
    else:
        samples = np.asarray(trace)
        rate = 1.0 if sample_rate is None else sample_rate
        t = np.arange(len(samples)) / rate

    # Reuse the caller's axes when possible, otherwise make a new figure.
    if ax is not None:
        owner = ax.figure
        if owner is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", owner)
    else:
        fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(t, samples, **plot_kwargs)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    ax.grid(True, alpha=0.3)

    # Each line is hidden only by the cursor type that names the other one.
    show_vertical = cursor_type != "horizontal"
    show_horizontal = cursor_type != "vertical"
    cursor = Cursor(
        ax,
        useblit=True,
        color="red",
        linewidth=1,
        vertOn=show_vertical,
        horizOn=show_horizontal,
    )

    return fig, ax, cursor

307 

308 

def add_measurement_cursors(
    ax: Axes,
    *,
    color: str = "red",
    linestyle: str = "--",
) -> dict:  # type: ignore[type-arg]
    """Add dual measurement cursors to an axes.

    Click and drag horizontally to select a region; its two edges are the
    cursors. The selected span is shaded with ``color`` and, since the
    selector is interactive, its edge handle lines are drawn with ``color``
    and ``linestyle``. Read the result by calling the returned
    ``get_measurement`` function.

    Args:
        ax: Axes to add cursors to.
        color: Cursor line color (also used for the span shading).
        linestyle: Cursor line style.

    Returns:
        Dictionary with ``"span"`` (the SpanSelector widget), ``"state"``
        (raw cursor values keyed ``x1``, ``x2``, ``y1``, ``y2``, ``line1``,
        ``line2``) and ``"get_measurement"`` (returns a CursorMeasurement,
        or None before any selection has been made).

    Raises:
        ImportError: If matplotlib is not available.

    Example:
        >>> cursors = add_measurement_cursors(ax)
        >>> measurement = cursors['get_measurement']()
        >>> print(f"Delta X: {measurement.delta_x}")

    References:
        VIS-008
    """
    if not MATPLOTLIB_AVAILABLE:
        raise ImportError("matplotlib is required for interactive visualization")

    # "line1"/"line2" are part of the returned state but are not set here.
    state: dict[str, float | None | Any] = {
        "x1": None,
        "x2": None,
        "y1": None,
        "y2": None,
        "line1": None,
        "line2": None,
    }

    def onselect(xmin: float, xmax: float) -> None:
        # Store plain floats so get_measurement's isinstance check accepts
        # any numeric type the widget reports.
        state["x1"] = float(xmin)
        state["x2"] = float(xmax)

        # Forget Y values from a previous selection; they are refilled below
        # only if a line with data exists.
        state["y1"] = None
        state["y2"] = None

        # Interpolate Y at the cursor positions on the first non-empty line.
        for line in ax.get_lines():
            # Type narrowing: these return ArrayLike from Line2D
            xdata_arr = np.asarray(line.get_xdata())
            ydata_arr = np.asarray(line.get_ydata())
            if len(xdata_arr) == 0:
                continue
            state["y1"] = float(np.interp(xmin, xdata_arr, ydata_arr))
            state["y2"] = float(np.interp(xmax, xdata_arr, ydata_arr))
            break

    span = SpanSelector(
        ax,
        onselect,
        "horizontal",
        useblit=True,
        props={"alpha": 0.3, "facecolor": color},
        interactive=True,
        # Style the edge handle lines (the visible cursors) as requested.
        handle_props={"color": color, "linestyle": linestyle},
    )

    def get_measurement() -> CursorMeasurement | None:
        x1 = state["x1"]
        x2 = state["x2"]
        y1 = state["y1"]
        y2 = state["y2"]

        if (
            x1 is None
            or x2 is None
            or not isinstance(x1, int | float)
            or not isinstance(x2, int | float)
        ):
            return None

        delta_x = x2 - x1
        # Missing Y values (no line with data) are reported as 0.0.
        y1_val = float(y1) if y1 is not None else 0.0
        y2_val = float(y2) if y2 is not None else 0.0
        delta_y = y2_val - y1_val

        return CursorMeasurement(
            x1=x1,
            x2=x2,
            y1=y1_val,
            y2=y2_val,
            delta_x=delta_x,
            delta_y=delta_y,
            frequency=1 / delta_x if delta_x > 0 else None,
            slope=delta_y / delta_x if delta_x != 0 else None,
        )

    return {
        "span": span,
        "state": state,
        "get_measurement": get_measurement,
    }

414 

415 

def plot_phase(
    trace1: WaveformTrace | NDArray[np.floating[Any]],
    trace2: WaveformTrace | NDArray[np.floating[Any]] | None = None,
    *,
    delay: int = 1,
    delay_samples: int | None = None,
    ax: Axes | None = None,
    **plot_kwargs: Any,
) -> tuple[Figure, Axes]:
    """Create phase plot (X-Y plot) of two signals.

    Plots trace1 on X-axis vs trace2 on Y-axis, useful for
    visualizing phase relationships and Lissajous figures.
    If trace2 is not provided, creates a self-phase plot pairing each
    sample ``x[i]`` with ``x[i + delay]``. Only pairs that exist inside the
    record are plotted, so ``abs(delay)`` samples are dropped rather than
    wrapped around from the other end.

    Args:
        trace1: Signal for X-axis.
        trace2: Signal for Y-axis. If None, uses delayed trace1.
        delay: Sample delay for self-phase plot (when trace2=None).
            Negative values pair ``x[i]`` with ``x[i - abs(delay)]``.
        delay_samples: Alias for delay parameter; takes precedence when set.
        ax: Existing axes to plot on.
        **plot_kwargs: Additional arguments to plot(); override the defaults.

    Returns:
        Tuple of (figure, axes).

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If axes has no associated figure.

    Example:
        >>> fig, ax = plot_phase(signal_x, signal_y)
        >>> plt.show()
        >>> # Self-phase plot
        >>> fig, ax = plot_phase(signal, delay_samples=10)

    References:
        VIS-009
    """
    if not MATPLOTLIB_AVAILABLE:
        raise ImportError("matplotlib is required for interactive visualization")

    # Handle delay_samples alias
    if delay_samples is not None:
        delay = delay_samples

    # Get data
    data1 = trace1.data if isinstance(trace1, WaveformTrace) else np.asarray(trace1)

    if trace2 is None:
        # Delay embedding by slicing. np.roll would wrap the tail of the
        # record onto its head and add spurious points and a segment that
        # jumps across the plot.
        n_total = len(data1)
        shift = min(abs(delay), n_total)
        if delay >= 0:
            data2 = data1[shift:]
            data1 = data1[: n_total - shift]
        else:
            data2 = data1[: n_total - shift]
            data1 = data1[shift:]
    else:
        data2 = trace2.data if isinstance(trace2, WaveformTrace) else np.asarray(trace2)

    # Ensure same length
    n = min(len(data1), len(data2))
    data1 = data1[:n]
    data2 = data2[:n]

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 8))
    else:
        fig_temp = ax.figure
        if fig_temp is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", fig_temp)

    # Plot
    defaults: dict[str, Any] = {"alpha": 0.5, "marker": ".", "linestyle": "-", "markersize": 2}
    defaults.update(plot_kwargs)
    ax.plot(data1, data2, **defaults)

    # Equal aspect ratio for proper phase visualization
    ax.set_aspect("equal", adjustable="datalim")
    ax.set_xlabel("Signal 1")
    ax.set_ylabel("Signal 2")
    ax.set_title("Phase Plot (X-Y)")
    ax.grid(True, alpha=0.3)

    return fig, ax

499 

500 

def plot_bode(
    frequencies: NDArray[np.floating[Any]],
    magnitude: NDArray[np.floating[Any]] | NDArray[np.complexfloating[Any, Any]],
    phase: NDArray[np.floating[Any]] | None = None,
    *,
    magnitude_db: bool = True,
    phase_degrees: bool = True,
    fig: Figure | None = None,
    **plot_kwargs: Any,
) -> Figure:
    """Draw a Bode plot: magnitude (dB) and optional phase on a log-frequency axis.

    Args:
        frequencies: Frequency values in Hz.
        magnitude: Real magnitude (dB or linear), or a complex frequency
            response. Complex input supplies both magnitude and phase.
        phase: Phase in radians; replaced by ``angle(magnitude)`` for
            complex input. When None (and input is real) only the
            magnitude panel is drawn.
        magnitude_db: True if real ``magnitude`` is already in dB; when
            False it is converted with ``20*log10(|m|)``. Ignored for
            complex input.
        phase_degrees: Plot phase in degrees instead of radians.
        fig: Figure to add the axes to; a new one is created when omitted.
        **plot_kwargs: Forwarded to each ``semilogx`` call.

    Returns:
        The figure holding the magnitude axes and, if drawn, the phase axes.

    Raises:
        ImportError: If matplotlib is not available.

    Example:
        >>> # With complex transfer function
        >>> H = 1 / (1 + 1j * freqs / 1000)
        >>> fig = plot_bode(freqs, H)
        >>> ax_mag, ax_phase = fig.axes[:2]  # Access axes from figure
        >>> plt.show()

    References:
        VIS-010
    """
    if not MATPLOTLIB_AVAILABLE:
        raise ImportError("matplotlib is required for interactive visualization")

    def to_db(values):  # type: ignore[no-untyped-def]
        # 20*log10|v|; non-finite results go through nan_to_num with
        # -inf (from zeros) mapped to -200 dB.
        with np.errstate(divide="ignore"):
            level = 20 * np.log10(np.abs(values))
            return np.nan_to_num(level, neginf=-200)

    freqs = np.asarray(frequencies)
    response = np.asarray(magnitude)

    if np.iscomplexobj(response):
        # Complex response: phase from its argument, magnitude from |H|.
        phase = np.angle(response)
        mag_db = to_db(response)
    elif magnitude_db:
        mag_db = response
    else:
        mag_db = to_db(response)

    # One panel for magnitude only, two stacked panels with shared X otherwise.
    has_phase = phase is not None
    if fig is None:
        if has_phase:
            fig, (ax_mag, ax_phase) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
        else:
            fig, ax_mag = plt.subplots(figsize=(10, 5))
            ax_phase = None
    elif has_phase:
        ax_mag, ax_phase = fig.subplots(2, 1, sharex=True)
    else:
        ax_mag = fig.subplots()
        ax_phase = None

    ax_mag.semilogx(freqs, mag_db, **plot_kwargs)
    ax_mag.set_ylabel("Magnitude (dB)")
    ax_mag.grid(True, which="both", alpha=0.3)
    ax_mag.set_title("Bode Plot")

    if ax_phase is None or phase is None:
        ax_mag.set_xlabel("Frequency (Hz)")
    else:
        phase_values = np.asarray(phase)
        if phase_degrees:
            phase_values = np.degrees(phase_values)
        ylabel = "Phase (degrees)" if phase_degrees else "Phase (radians)"

        ax_phase.semilogx(freqs, phase_values, **plot_kwargs)
        ax_phase.set_ylabel(ylabel)
        ax_phase.set_xlabel("Frequency (Hz)")
        ax_phase.grid(True, which="both", alpha=0.3)

    fig.tight_layout()
    return fig

601 

602 

def plot_waterfall(
    data: NDArray[np.floating[Any]],
    *,
    time_axis: NDArray[np.floating[Any]] | None = None,
    freq_axis: NDArray[np.floating[Any]] | None = None,
    sample_rate: float = 1.0,
    nperseg: int = 256,
    noverlap: int | None = None,
    cmap: str = "viridis",
    ax: Axes | None = None,
    **kwargs: Any,
) -> tuple[Figure, Axes]:
    """Create 3D waterfall plot (spectrogram with depth).

    Shows spectrum evolution over time as a surface of power over the
    time/frequency plane.

    Args:
        data: Input signal array (1D) or pre-computed spectrogram (2D).
            If 2D, treated as (n_traces, n_points) spectrogram data; a
            non-square array already laid out as (n_freq, n_times) to match
            explicit ``freq_axis``/``time_axis`` is also accepted. If 1D and
            ``freq_axis`` is given, treated as one pre-computed spectrum.
        time_axis: Time axis. For a computed spectrogram this may hold one
            value per spectrogram segment, or the signal's own time axis
            (one value per sample), in which case each segment is placed at
            the time of its centre sample. When omitted, the segment times
            returned by ``scipy.signal.spectrogram`` (seconds) are used.
        freq_axis: Frequency axis (if pre-computed).
        sample_rate: Sample rate in Hz.
        nperseg: Segment length for FFT.
        noverlap: Overlap between segments; defaults to ``nperseg // 2``.
        cmap: Colormap for amplitude coloring.
        ax: Existing 3D axes to plot on.
        **kwargs: Forwarded to ``plot_surface``; override the defaults used
            here (``cmap``, ``linewidth``, ``antialiased``, ``alpha``).

    Returns:
        Tuple of (figure, axes).

    Raises:
        ImportError: If matplotlib is not available.
        TypeError: If axes is not a 3D axes.
        ValueError: If axes has no associated figure.

    Example:
        >>> fig, ax = plot_waterfall(signal, sample_rate=1e6)
        >>> plt.show()
        >>> # With 2D precomputed data
        >>> fig, ax = plot_waterfall(spectrogram_data)

    References:
        VIS-011
    """
    if not MATPLOTLIB_AVAILABLE:
        raise ImportError("matplotlib is required for interactive visualization")

    data = np.asarray(data)

    if data.ndim == 2:
        # Precomputed spectrogram, documented as (n_traces, n_points).
        Sxx_db = data
        n_traces, n_points = data.shape
        frequencies = freq_axis if freq_axis is not None else np.arange(n_points)
        times = time_axis if time_axis is not None else np.arange(n_traces)
    elif freq_axis is not None:
        # 1D data with explicit freq_axis is a single precomputed spectrum:
        # lay it out as one time column, shape (n_freq, 1).
        Sxx_db = data.reshape(-1, 1)
        frequencies = freq_axis
        times = time_axis if time_axis is not None else np.arange(1)
    else:
        # Compute spectrogram from 1D signal
        if noverlap is None:
            noverlap = nperseg // 2

        frequencies, segment_times, Sxx = scipy_signal.spectrogram(
            data, fs=sample_rate, nperseg=nperseg, noverlap=noverlap
        )
        # Small offset keeps log10 finite for empty bins.
        Sxx_db = 10 * np.log10(Sxx + 1e-10)

        if time_axis is None:
            # scipy returns segment times in seconds, matching the axis label.
            times = segment_times
        else:
            time_axis = np.asarray(time_axis)
            if len(time_axis) != len(segment_times) and len(time_axis) == len(data):
                # Per-sample time axis: map each segment's centre sample
                # (segment time * sample_rate) onto the caller's timestamps.
                times = np.interp(
                    segment_times * sample_rate, np.arange(len(data)), time_axis
                )
            else:
                times = time_axis

    # Create 3D figure
    if ax is None:
        fig = plt.figure(figsize=(12, 8))
        ax = fig.add_subplot(111, projection="3d")
    else:
        fig_temp = ax.figure
        if fig_temp is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", fig_temp)

    # Grid of shape (n_frequencies, n_times).
    T, F = np.meshgrid(times, frequencies)

    # Bring Sxx_db onto the grid. Documented 2D input is (time, freq), so it
    # is transposed whenever that fits (this keeps square input from being
    # drawn with its axes swapped); otherwise transpose only on a mismatch.
    # Any remaining mismatch is left for plot_surface to report.
    if Sxx_db.T.shape == T.shape and (data.ndim == 2 or Sxx_db.shape != T.shape):
        Sxx_db = Sxx_db.T

    # Type checking: ax must be a 3D axes at this point
    if not hasattr(ax, "plot_surface"):
        raise TypeError("Axes must be a 3D axes for waterfall plot")

    surface_kwargs: dict[str, Any] = {
        "cmap": cmap,
        "linewidth": 0,
        "antialiased": True,
        "alpha": 0.8,
    }
    surface_kwargs.update(kwargs)
    surf = ax.plot_surface(T, F, Sxx_db, **surface_kwargs)  # type: ignore[attr-defined,union-attr]

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    if hasattr(ax, "set_zlabel"):
        ax.set_zlabel("Power (dB)")  # type: ignore[attr-defined]
    ax.set_title("Waterfall Plot (Spectrogram)")

    fig.colorbar(surf, ax=ax, label="Power (dB)", shrink=0.5)

    return fig, ax

723 

724 

def plot_histogram(
    trace: WaveformTrace | NDArray[np.floating[Any]],
    *,
    bins: int | str | NDArray[np.floating[Any]] = "auto",
    density: bool = True,
    show_stats: bool = True,
    show_kde: bool = False,
    ax: Axes | None = None,
    **hist_kwargs: Any,
) -> tuple[Figure, Axes, dict[str, Any]]:
    """Create histogram plot of signal amplitude distribution.

    Optionally overlays mean/standard-deviation markers and a kernel
    density estimate (KDE).

    Args:
        trace: Input trace or numpy array.
        bins: Number of bins or binning strategy.
        density: If True, normalize to probability density.
        show_stats: Show mean and mean +/- standard deviation lines.
        show_kde: Overlay kernel density estimate. Skipped when the data has
            fewer than two samples or zero spread, where a Gaussian KDE
            cannot be fitted.
        ax: Existing axes to plot on.
        **hist_kwargs: Additional arguments to hist(); override the defaults.

    Returns:
        Tuple of (figure, axes, stats_dict). ``stats_dict`` has keys
        ``mean``, ``std``, ``median``, ``min``, ``max``, ``count`` and
        ``bins`` (number of histogram bins).

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If axes has no associated figure.

    Example:
        >>> fig, ax, stats = plot_histogram(trace, bins=50, show_kde=True)
        >>> print(f"Mean: {stats['mean']:.3f}")

    References:
        VIS-012
    """
    if not MATPLOTLIB_AVAILABLE:
        raise ImportError("matplotlib is required for interactive visualization")

    # Get data
    data = trace.data if isinstance(trace, WaveformTrace) else np.asarray(trace)

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 6))
    else:
        fig_temp = ax.figure
        if fig_temp is None:
            raise ValueError("Axes must have an associated figure")
        fig = cast("Figure", fig_temp)

    # Calculate statistics
    mean = float(np.mean(data))
    std = float(np.std(data))
    median = float(np.median(data))
    min_val = float(np.min(data))
    max_val = float(np.max(data))

    stats: dict[str, Any] = {
        "mean": mean,
        "std": std,
        "median": median,
        "min": min_val,
        "max": max_val,
        "count": len(data),
    }

    # Plot histogram
    defaults: dict[str, Any] = {"alpha": 0.7, "edgecolor": "black", "linewidth": 0.5}
    defaults.update(hist_kwargs)
    _counts, bin_edges, _patches = ax.hist(data, bins=bins, density=density, **defaults)  # type: ignore[arg-type]

    stats["bins"] = len(bin_edges) - 1

    # Show statistics lines
    if show_stats:
        ax.axvline(mean, color="red", linestyle="--", linewidth=2, label=f"Mean: {mean:.3g}")
        ax.axvline(mean - std, color="orange", linestyle=":", linewidth=1.5, label="Mean - Std")
        ax.axvline(mean + std, color="orange", linestyle=":", linewidth=1.5, label="Mean + Std")

    # gaussian_kde needs a non-singular covariance: with fewer than two
    # samples or zero spread it raises, so the overlay is skipped then.
    kde_drawn = False
    if show_kde and len(data) > 1 and std > 0:
        from scipy.stats import gaussian_kde

        kde = gaussian_kde(data)
        x_kde = np.linspace(min_val, max_val, 200)
        y_kde = kde(x_kde)

        if not density:
            # Scale the unit-area KDE to counts using the first bin's width.
            bin_width = bin_edges[1] - bin_edges[0]
            y_kde = y_kde * len(data) * bin_width
        ax.plot(x_kde, y_kde, "r-", linewidth=2, label="KDE")
        kde_drawn = True

    ax.set_xlabel("Amplitude")
    ax.set_ylabel("Density" if density else "Count")
    ax.set_title("Amplitude Distribution")
    # Only show legend if there are labeled artists
    if show_stats or kde_drawn:
        ax.legend(loc="upper right")
    ax.grid(True, alpha=0.3)

    return fig, ax, stats

831 

832 

# Public API: the two result dataclasses, then the plotting helpers.
__all__ = [
    "CursorMeasurement", "ZoomState",
    "add_measurement_cursors", "enable_zoom_pan", "plot_bode",
    "plot_histogram", "plot_phase", "plot_waterfall", "plot_with_cursors",
]

843]