Coverage for src / tracekit / visualization / specialized.py: 97%

177 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Specialized plot types for protocol analysis and state visualization. 

2 

3This module provides specialized visualizations including protocol timing 

4diagrams, state machine views, and bus transaction timelines. 

5 

6 

7Example: 

8 >>> from tracekit.visualization.specialized import plot_protocol_timing 

9 >>> fig = plot_protocol_timing(decoded_packets, sample_rate=1e6) 

10 

11References: 

12 - Wavedrom-style digital waveform rendering 

13 - State machine diagram standards 

14 - Bus protocol visualization best practices 

15""" 

16 

17from __future__ import annotations 

18 

19from dataclasses import dataclass 

20from typing import TYPE_CHECKING, Literal 

21 

22import numpy as np 

23 

24if TYPE_CHECKING: 

25 from matplotlib.axes import Axes 

26 from matplotlib.figure import Figure 

27 from numpy.typing import NDArray 

28 

29try: 

30 import matplotlib.pyplot as plt 

31 from matplotlib import patches 

32 

33 HAS_MATPLOTLIB = True 

34except ImportError: 

35 HAS_MATPLOTLIB = False 

36 

37 

@dataclass
class ProtocolSignal:
    """One named trace of a protocol timing diagram.

    Attributes:
        name: Label drawn beside the trace.
        data: Sample values. Clock and digital traces are thresholded at 0.5
            when rendered in wavedrom style; analog traces are plotted as-is.
        type: Rendering mode: "digital", "clock", "bus" or "analog". The
            name shadows the ``type`` builtin inside the class body, but it
            is part of the constructor signature and must stay as it is.
        transitions: Optional list of transition times. The plotting
            functions in this module do not read it.
        annotations: Optional mapping of time (seconds) to the text shown
            above the trace at that time.
    """

    name: str
    data: NDArray[np.float64]
    type: Literal["digital", "clock", "bus", "analog"] = "digital"
    transitions: list[float] | None = None
    annotations: dict[float, str] | None = None

55 

56 

@dataclass
class StateTransition:
    """Directed edge of a state machine diagram.

    Attributes:
        from_state: Name of the state the edge leaves.
        to_state: Name of the state the edge enters. Equal to ``from_state``
            for a self-loop.
        condition: Label drawn on the edge; an empty string draws no label.
        style: Line style of the arrow: "solid", "dashed" or "dotted".
    """

    from_state: str
    to_state: str
    condition: str = ""
    style: Literal["solid", "dashed", "dotted"] = "solid"

72 

73 

def _resolve_time_scale(time_unit: str, span: float) -> tuple[str, float]:
    """Return the x-axis unit label and the seconds-to-unit multiplier."""
    if time_unit == "auto":
        # Choose the unit from the magnitude of the plotted span (seconds).
        if span < 1e-6:
            return "ns", 1e9
        if span < 1e-3:
            return "us", 1e6
        if span < 1:
            return "ms", 1e3
        return "s", 1.0

    # An unrecognised unit keeps the caller's label and is scaled as seconds.
    multiplier = {"s": 1.0, "ms": 1e3, "us": 1e6, "ns": 1e9}.get(time_unit, 1.0)
    return time_unit, multiplier


def plot_protocol_timing(
    signals: list[ProtocolSignal],
    sample_rate: float,
    *,
    time_range: tuple[float, float] | None = None,
    time_unit: str = "auto",
    style: Literal["wavedrom", "classic"] = "wavedrom",
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
) -> Figure:
    """Plot protocol timing diagram in wavedrom style.

    Creates a timing diagram with one stacked subplot per signal, sharing a
    common time axis, and draws each signal's annotations above its trace.

    Args:
        signals: List of ProtocolSignal objects to plot, top to bottom.
        sample_rate: Sample rate in Hz. Must be positive.
        time_range: Time range to plot (t_min, t_max) in seconds. None plots
            from 0 to the end of the longest signal.
        time_unit: Time unit for x-axis ("s", "ms", "us", "ns", "auto").
            "auto" picks the unit from the plotted span.
        style: Diagram style ("wavedrom" = clean digital, "classic" = traditional).
        figsize: Figure size (width, height). Auto-calculated if None.
        title: Plot title.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If signals list is empty or sample_rate is not positive.

    Example:
        >>> sda = ProtocolSignal("SDA", sda_data, type="digital")
        >>> scl = ProtocolSignal("SCL", scl_data, type="clock")
        >>> fig = plot_protocol_timing(
        ...     [scl, sda],
        ...     sample_rate=1e6,
        ...     style="wavedrom",
        ...     title="I2C Transaction"
        ... )

    References:
        VIS-021: Specialized - Protocol Timing Diagram
        Wavedrom digital waveform rendering
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    if not signals:
        raise ValueError("signals list cannot be empty")

    # A zero or negative rate has no meaningful time axis; fail early with a
    # clear message instead of a ZeroDivisionError or inf/nan sample times.
    if sample_rate <= 0:
        raise ValueError("sample_rate must be positive")

    n_signals = len(signals)

    # Resolve the time window and axis scaling before creating the figure so
    # that a failure here does not leave an open figure behind.
    if time_range is None:
        t_min = 0.0
        t_max = max(len(sig.data) for sig in signals) / sample_rate
    else:
        t_min, t_max = time_range
    time_unit, time_mult = _resolve_time_scale(time_unit, t_max - t_min)

    # Auto-calculate figure size: fixed width, height grows with signal count.
    if figsize is None:
        figsize = (12, max(4, n_signals * 0.8 + 1))

    fig, axes = plt.subplots(
        n_signals,
        1,
        figsize=figsize,
        sharex=True,
        gridspec_kw={"hspace": 0.1},
    )

    # plt.subplots returns a bare Axes (not a sequence) for a single row.
    if n_signals == 1:
        axes = [axes]

    draw_trace = _plot_wavedrom_signal if style == "wavedrom" else _plot_classic_signal

    for signal, ax in zip(signals, axes, strict=True):
        # asarray lets plain sequences work with the boolean mask below.
        samples = np.asarray(signal.data)

        # Sample times, already scaled to the display unit.
        time = np.arange(len(samples)) / sample_rate * time_mult

        # Keep only the samples inside the requested window.
        mask = (time >= t_min * time_mult) & (time <= t_max * time_mult)
        draw_trace(ax, time[mask], samples[mask], signal)

        # Signal name as a horizontal label on the left of the trace.
        ax.set_ylabel(signal.name, rotation=0, ha="right", va="center", fontsize=10)
        # Fixed vertical range: logic levels 0..1 plus headroom for annotations.
        ax.set_ylim(-0.2, 1.3)
        ax.set_yticks([])

        # Vertical grid lines only, as timing guides.
        ax.grid(True, axis="x", alpha=0.3, linestyle=":")

        # Annotation keys are in seconds; only those inside the window are drawn.
        if signal.annotations:
            for t, text in signal.annotations.items():
                if t_min <= t <= t_max:
                    ax.annotate(
                        text,
                        xy=(t * time_mult, 1.2),
                        fontsize=8,
                        ha="center",
                        bbox={"boxstyle": "round,pad=0.3", "facecolor": "yellow", "alpha": 0.7},
                    )

    # X-axis label only on bottom plot
    axes[-1].set_xlabel(f"Time ({time_unit})", fontsize=11)

    if title:
        fig.suptitle(title, fontsize=14, y=0.98)

    fig.tight_layout()
    return fig

215 

216 

217def _plot_wavedrom_signal( 

218 ax: Axes, 

219 time: NDArray[np.float64], 

220 data: NDArray[np.float64], 

221 signal: ProtocolSignal, 

222) -> None: 

223 """Plot signal in wavedrom style (clean digital waveform).""" 

224 if signal.type == "clock": 

225 # Clock signal: square wave 

226 for i in range(len(time) - 1): 

227 level = 1 if data[i] > 0.5 else 0 

228 ax.plot( 

229 [time[i], time[i + 1]], 

230 [level, level], 

231 "b-", 

232 linewidth=1.5, 

233 ) 

234 # Vertical transition 

235 if i < len(time) - 1: 235 ↛ 226line 235 didn't jump to line 226 because the condition on line 235 was always true

236 next_level = 1 if data[i + 1] > 0.5 else 0 

237 if level != next_level: 237 ↛ 226line 237 didn't jump to line 226 because the condition on line 237 was always true

238 ax.plot( 

239 [time[i + 1], time[i + 1]], 

240 [level, next_level], 

241 "b-", 

242 linewidth=1.5, 

243 ) 

244 

245 elif signal.type == "digital": 

246 # Digital signal: step function with transitions 

247 for i in range(len(time) - 1): 

248 level = 1 if data[i] > 0.5 else 0 

249 ax.plot( 

250 [time[i], time[i + 1]], 

251 [level, level], 

252 "k-", 

253 linewidth=1.5, 

254 ) 

255 # Vertical transition with slight slant for visual clarity 

256 if i < len(time) - 1: 256 ↛ 247line 256 didn't jump to line 247 because the condition on line 256 was always true

257 next_level = 1 if data[i + 1] > 0.5 else 0 

258 if level != next_level: 

259 transition_width = (time[i + 1] - time[i]) * 0.1 

260 ax.plot( 

261 [time[i + 1] - transition_width, time[i + 1]], 

262 [level, next_level], 

263 "k-", 

264 linewidth=1.5, 

265 ) 

266 

267 elif signal.type == "bus": 

268 # Bus signal: show as high-impedance or data values 

269 ax.fill_between(time, 0.3, 0.7, alpha=0.3, color="gray") 

270 ax.plot(time, np.full_like(time, 0.5), "k-", linewidth=0.5) 

271 

272 else: 

273 # Analog signal 

274 ax.plot(time, data, "r-", linewidth=1.2) 

275 

276 

def _plot_classic_signal(
    ax: Axes,
    time: NDArray[np.float64],
    data: NDArray[np.float64],
    signal: ProtocolSignal,
) -> None:
    """Plot signal in classic style (traditional oscilloscope-like).

    Draws the samples as one continuous line and adds a faint dashed
    horizontal reference at 0.5.

    Args:
        ax: Axes to draw on.
        time: Sample times for the x-axis.
        data: Sample values aligned with ``time``.
        signal: Accepted for signature parity with the wavedrom renderer;
            not read here.
    """
    ax.plot(time, data, "b-", linewidth=1.2)
    # Mid-level reference line between logic 0 and logic 1.
    ax.axhline(0.5, color="gray", linestyle="--", linewidth=0.5, alpha=0.5)

286 

287 

def plot_state_machine(
    states: list[str],
    transitions: list[StateTransition],
    *,
    initial_state: str | None = None,
    final_states: list[str] | None = None,
    layout: Literal["circular", "hierarchical", "force"] = "circular",
    figsize: tuple[float, float] = (10, 8),
    title: str | None = None,
) -> Figure:
    """Plot state machine diagram.

    Creates a state diagram showing states as nodes and transitions as
    directed edges with condition labels. When two states are linked in both
    directions, the two arrows are bowed apart and their labels are offset to
    opposite sides so they do not overlap.

    Args:
        states: List of state names. An empty list yields an empty diagram.
        transitions: List of StateTransition objects. Transitions that name a
            state not present in ``states`` are skipped.
        initial_state: Initial state (marked with an extra outer ring).
        final_states: List of final states (marked with an extra inner ring).
        layout: Layout algorithm for state positioning.
        figsize: Figure size (width, height).
        title: Plot title.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not available.

    Example:
        >>> states = ["IDLE", "ACTIVE", "WAIT", "DONE"]
        >>> transitions = [
        ...     StateTransition("IDLE", "ACTIVE", "START"),
        ...     StateTransition("ACTIVE", "WAIT", "BUSY"),
        ...     StateTransition("WAIT", "ACTIVE", "RETRY"),
        ...     StateTransition("ACTIVE", "DONE", "COMPLETE"),
        ... ]
        >>> fig = plot_state_machine(
        ...     states, transitions, initial_state="IDLE", final_states=["DONE"]
        ... )

    References:
        VIS-022: Specialized - State Machine View
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    # Positions are computed before the figure is created so a layout failure
    # leaves no figure open. An empty state list skips the layout entirely,
    # because the circular layout divides by the number of states.
    positions = _calculate_state_positions(states, layout) if states else {}

    fig, ax = plt.subplots(figsize=figsize)

    state_radius = 0.15

    # Draw states as circles
    for state, (x, y) in positions.items():
        ax.add_patch(
            patches.Circle(
                (x, y),
                state_radius,
                fill=True,
                facecolor="lightblue",
                edgecolor="black",
                linewidth=2.0,
            )
        )

        # Initial state: additional ring outside the node.
        if state == initial_state:
            ax.add_patch(
                patches.Circle(
                    (x, y),
                    state_radius * 1.2,
                    fill=False,
                    edgecolor="black",
                    linewidth=2.0,
                )
            )

        # Final states: additional ring inside the node.
        if final_states and state in final_states:
            ax.add_patch(
                patches.Circle(
                    (x, y),
                    state_radius * 0.8,
                    fill=False,
                    edgecolor="black",
                    linewidth=2.0,
                )
            )

        ax.text(
            x,
            y,
            state,
            ha="center",
            va="center",
            fontsize=10,
            fontweight="bold",
        )

    # Loop-invariant lookups for the transition pass.
    line_styles = {"solid": "-", "dashed": "--", "dotted": ":"}
    directed_edges = {(t.from_state, t.to_state) for t in transitions}
    pair_curvature = 0.2  # "arc3" rad applied to arrows that have a reverse twin
    pair_label_offset = 0.06  # sideways label shift (data units) for such arrows

    # Draw transitions as arrows
    for trans in transitions:
        if trans.from_state not in positions or trans.to_state not in positions:
            continue

        x1, y1 = positions[trans.from_state]
        x2, y2 = positions[trans.to_state]

        dx = x2 - x1
        dy = y2 - y1
        dist = np.sqrt(dx**2 + dy**2)

        if dist < 1e-6:
            # Source and target coincide: self-loop
            _draw_self_loop(ax, x1, y1, state_radius, trans.condition)
            continue

        # Unit vector from source to target
        dx_norm = dx / dist
        dy_norm = dy / dist

        # Start and end the arrow on the circle perimeters, not the centres.
        arrow_start = (x1 + dx_norm * state_radius, y1 + dy_norm * state_radius)
        arrow_end = (x2 - dx_norm * state_radius, y2 - dy_norm * state_radius)

        arrowprops = {
            "arrowstyle": "->",
            "lw": 1.5,
            "linestyle": line_styles.get(trans.style, "-"),
            "color": "black",
        }

        # A straight arrow would coincide with its reverse twin. Bowing both
        # with the same positive curvature separates them, because each bends
        # relative to its own direction of travel.
        has_reverse = (trans.to_state, trans.from_state) in directed_edges
        if has_reverse:
            arrowprops["connectionstyle"] = f"arc3,rad={pair_curvature}"

        ax.annotate(
            "",
            xy=arrow_end,
            xytext=arrow_start,
            arrowprops=arrowprops,
        )

        # Add transition label
        if trans.condition:
            label_x = (x1 + x2) / 2
            label_y = (y1 + y2) / 2
            if has_reverse:
                # Shift towards the side the arc bows to: (dy, -dx) relative
                # to the direction of travel, matching arc3 with positive rad.
                label_x += dy_norm * pair_label_offset
                label_y -= dx_norm * pair_label_offset
            ax.text(
                label_x,
                label_y,
                trans.condition,
                fontsize=8,
                ha="center",
                bbox={
                    "boxstyle": "round,pad=0.3",
                    "facecolor": "white",
                    "edgecolor": "gray",
                    "alpha": 0.9,
                },
            )

    # Set axis properties
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_xlim(-0.2, 1.2)
    ax.set_ylim(-0.2, 1.2)

    if title:
        ax.set_title(title, fontsize=14, pad=20)

    fig.tight_layout()
    return fig

466 

467 

def _calculate_state_positions(
    states: list[str],
    layout: str,
) -> dict[str, tuple[float, float]]:
    """Calculate state positions using layout algorithm.

    Args:
        states: State names, in placement order.
        layout: "circular" places states evenly on a circle of radius 0.4
            centred at (0.5, 0.5); "hierarchical" fills a near-square grid row
            by row from the top; any other value uses fixed-seed pseudo-random
            positions inside [0.2, 0.8] as a stand-in for a force-directed
            layout.

    Returns:
        Mapping of state name to (x, y) coordinates. Empty when ``states`` is
        empty.
    """
    positions: dict[str, tuple[float, float]] = {}
    n_states = len(states)

    # Nothing to place; also avoids dividing by zero in the circular layout.
    if n_states == 0:
        return positions

    if layout == "circular":
        # Arrange states in a circle
        angle_step = 2 * np.pi / n_states
        for i, state in enumerate(states):
            angle = i * angle_step
            positions[state] = (
                float(0.5 + 0.4 * np.cos(angle)),
                float(0.5 + 0.4 * np.sin(angle)),
            )

    elif layout == "hierarchical":
        # Arrange in rows (simplified hierarchical)
        states_per_row = int(np.ceil(np.sqrt(n_states)))
        n_rows = int(np.ceil(n_states / states_per_row))
        for i, state in enumerate(states):
            row, col = divmod(i, states_per_row)
            # Cell centres; row 0 is at the top of the unit square.
            positions[state] = (
                (col + 0.5) / states_per_row,
                1.0 - (row + 0.5) / n_rows,
            )

    else:  # force-directed (simplified)
        # Random positions as a placeholder for a true force-directed layout.
        # A private RandomState(42) gives the same sequence as seeding the
        # global generator with 42, without resetting the caller's RNG state.
        rng = np.random.RandomState(42)
        for state in states:
            x = 0.2 + 0.6 * rng.rand()
            y = 0.2 + 0.6 * rng.rand()
            positions[state] = (float(x), float(y))

    return positions

504 

505 

def _draw_self_loop(
    ax: Axes,
    x: float,
    y: float,
    radius: float,
    label: str,
) -> None:
    """Draw self-loop transition on state.

    Adds a half-circle arc centred on the top of the state circle, a small
    arrow head to its left, and the label above the arc.

    Args:
        ax: Axes to draw on.
        x: X coordinate of the state centre.
        y: Y coordinate of the state centre.
        radius: Radius of the state circle; all loop dimensions scale with it.
        label: Text placed above the loop. An empty string draws no label.
    """
    # Upper half (0 to 180 degrees) of a circle centred on the node's top point.
    loop_diameter = radius * 1.5
    loop_arc = patches.Arc(
        (x, y + radius),
        width=loop_diameter,
        height=loop_diameter,
        angle=0,
        theta1=0,
        theta2=180,
        linewidth=1.5,
        edgecolor="black",
        fill=False,
    )
    ax.add_patch(loop_arc)

    # Add arrow head: an empty annotation whose arrow runs from xytext to xy.
    ax.annotate(
        "",
        xy=(x - radius * 0.7, y + radius * 0.3),
        xytext=(x - radius * 0.5, y + radius * 0.5),
        arrowprops={"arrowstyle": "->", "lw": 1.5, "color": "black"},
    )

    # Add label
    if label:
        ax.text(
            x,
            y + radius * 2.2,
            label,
            fontsize=8,
            ha="center",
            bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "alpha": 0.9},
        )

546 

547 

548__all__ = [ 

549 "ProtocolSignal", 

550 "StateTransition", 

551 "plot_protocol_timing", 

552 "plot_state_machine", 

553]