Coverage for src/tracekit/visualization/specialized.py: 97%
177 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Specialized plot types for protocol analysis and state visualization.
3This module provides specialized visualizations including protocol timing
4diagrams, state machine views, and bus transaction timelines.
7Example:
8 >>> from tracekit.visualization.specialized import plot_protocol_timing
9 >>> fig = plot_protocol_timing(decoded_packets, sample_rate=1e6)
11References:
12 - Wavedrom-style digital waveform rendering
13 - State machine diagram standards
14 - Bus protocol visualization best practices
15"""
17from __future__ import annotations
19from dataclasses import dataclass
20from typing import TYPE_CHECKING, Literal
22import numpy as np
24if TYPE_CHECKING:
25 from matplotlib.axes import Axes
26 from matplotlib.figure import Figure
27 from numpy.typing import NDArray
29try:
30 import matplotlib.pyplot as plt
31 from matplotlib import patches
33 HAS_MATPLOTLIB = True
34except ImportError:
35 HAS_MATPLOTLIB = False
@dataclass
class ProtocolSignal:
    """One named trace rendered as a row of a protocol timing diagram.

    Attributes:
        name: Label shown next to the trace.
        data: Sample values (0/1 for digital-style traces, raw values for
            analog ones).
        type: How the trace is rendered: "digital", "clock", "bus" or
            "analog". Defaults to "digital".
        transitions: Optional list of transition times.
        annotations: Optional mapping of time (in seconds) to the text
            displayed at that time.
    """

    # Required fields.
    name: str
    data: NDArray[np.float64]

    # Rendering mode; plain digital trace unless stated otherwise.
    type: Literal["digital", "clock", "bus", "analog"] = "digital"

    # Optional extras; None means "not provided".
    transitions: list[float] | None = None
    annotations: dict[float, str] | None = None
@dataclass
class StateTransition:
    """A directed edge between two states of a state machine diagram.

    Attributes:
        from_state: Name of the state the edge leaves.
        to_state: Name of the state the edge enters.
        condition: Label drawn on the edge; an empty string means no label.
        style: Line style of the edge: "solid", "dashed" or "dotted".
    """

    # Edge endpoints, given by state name.
    from_state: str
    to_state: str

    # Presentation details.
    condition: str = ""
    style: Literal["solid", "dashed", "dotted"] = "solid"
def plot_protocol_timing(
    signals: list[ProtocolSignal],
    sample_rate: float,
    *,
    time_range: tuple[float, float] | None = None,
    time_unit: str = "auto",
    style: Literal["wavedrom", "classic"] = "wavedrom",
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
) -> Figure:
    """Draw a stacked timing diagram with one row per signal.

    Every signal gets its own subplot sharing a common time axis. Samples
    are placed at ``index / sample_rate``, restricted to the requested time
    window, and rendered either wavedrom-style or as a plain trace.

    Args:
        signals: Signals to draw, top to bottom.
        sample_rate: Sample rate in Hz used to build the time axis.
        time_range: ``(t_min, t_max)`` window in seconds. ``None`` spans from
            0 to the length of the longest signal.
        time_unit: Axis unit: "s", "ms", "us", "ns" or "auto". "auto" picks
            the unit from the span of the window. Any other string is used
            verbatim as the axis label with a scale factor of 1.0.
        style: "wavedrom" for the clean digital rendering; any other value
            selects the classic trace rendering.
        figsize: Figure size ``(width, height)``. When ``None`` the width is
            12 and the height grows with the number of signals (minimum 4).
        title: Optional figure title.

    Returns:
        The Matplotlib Figure containing the diagram.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If ``signals`` is empty.

    Example:
        >>> sda = ProtocolSignal("SDA", sda_data, type="digital")
        >>> scl = ProtocolSignal("SCL", scl_data, type="clock")
        >>> fig = plot_protocol_timing([scl, sda], sample_rate=1e6)

    References:
        VIS-021: Specialized - Protocol Timing Diagram
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    signal_count = len(signals)
    if signal_count == 0:
        raise ValueError("signals list cannot be empty")

    # Default canvas: fixed width, height proportional to the row count.
    if figsize is None:
        figsize = (12, max(4, signal_count * 0.8 + 1))

    fig, axes = plt.subplots(
        signal_count,
        1,
        figsize=figsize,
        sharex=True,
        gridspec_kw={"hspace": 0.1},
    )
    # plt.subplots hands back a bare Axes (not a sequence) for a single row.
    rows = [axes] if signal_count == 1 else axes

    # Time window in seconds.
    if time_range is not None:
        t_start, t_stop = time_range
    else:
        t_start = 0.0
        t_stop = max(len(sig.data) for sig in signals) / sample_rate

    # Resolve "auto" to a concrete unit, then look up its scale factor.
    if time_unit == "auto":
        span = t_stop - t_start
        if span < 1e-6:
            time_unit = "ns"
        elif span < 1e-3:
            time_unit = "us"
        elif span < 1:
            time_unit = "ms"
        else:
            time_unit = "s"
    scale = {"s": 1.0, "ms": 1e3, "us": 1e6, "ns": 1e9}.get(time_unit, 1.0)

    draw_trace = _plot_wavedrom_signal if style == "wavedrom" else _plot_classic_signal
    window_lo = t_start * scale
    window_hi = t_stop * scale

    for signal, ax in zip(signals, rows, strict=False):
        # Sample timestamps expressed in the display unit.
        stamps = np.arange(len(signal.data)) / sample_rate * scale
        in_window = (stamps >= window_lo) & (stamps <= window_hi)
        draw_trace(ax, stamps[in_window], signal.data[in_window], signal)

        # Row cosmetics: horizontal name label, fixed y-range, no y ticks,
        # dotted vertical grid lines for reading timing.
        ax.set_ylabel(signal.name, rotation=0, ha="right", va="center", fontsize=10)
        ax.set_ylim(-0.2, 1.3)
        ax.set_yticks([])
        ax.grid(True, axis="x", alpha=0.3, linestyle=":")

        # Annotation keys are times in seconds; only those inside the
        # window are drawn, just above the trace.
        for t, text in (signal.annotations or {}).items():
            if not (t_start <= t <= t_stop):
                continue
            ax.annotate(
                text,
                xy=(t * scale, 1.2),
                fontsize=8,
                ha="center",
                bbox={"boxstyle": "round,pad=0.3", "facecolor": "yellow", "alpha": 0.7},
            )

    # Only the bottom row carries the shared x-axis label.
    rows[-1].set_xlabel(f"Time ({time_unit})", fontsize=11)

    if title:
        fig.suptitle(title, fontsize=14, y=0.98)

    fig.tight_layout()
    return fig
def _plot_wavedrom_signal(
    ax: Axes,
    time: NDArray[np.float64],
    data: NDArray[np.float64],
    signal: ProtocolSignal,
) -> None:
    """Draw one signal on ``ax`` in wavedrom style.

    Rendering depends on ``signal.type``:

    * "clock": blue square wave with vertical edges.
    * "digital": black step trace whose edges are slanted over the last
      10% of the preceding sample interval.
    * "bus": gray band between 0.3 and 0.7 with a thin centre line; the
      sample values are not drawn.
    * anything else: the raw samples as a red line.

    Args:
        ax: Axes to draw on.
        time: Sample timestamps, already in display units.
        data: Sample values aligned with ``time``.
        signal: Signal descriptor; only its ``type`` is consulted.
    """
    kind = signal.type

    if kind == "bus":
        ax.fill_between(time, 0.3, 0.7, alpha=0.3, color="gray")
        ax.plot(time, np.full_like(time, 0.5), "k-", linewidth=0.5)
        return

    if kind not in ("clock", "digital"):
        # Analog fallback: plot the samples untouched.
        ax.plot(time, data, "r-", linewidth=1.2)
        return

    is_clock = kind == "clock"
    pen = "b-" if is_clock else "k-"

    for i in range(len(time) - 1):
        t_here, t_next = time[i], time[i + 1]

        # Samples are thresholded at 0.5 to obtain a logic level.
        level = 1 if data[i] > 0.5 else 0
        ax.plot([t_here, t_next], [level, level], pen, linewidth=1.5)

        next_level = 1 if data[i + 1] > 0.5 else 0
        if next_level == level:
            continue

        # Clock edges are vertical; digital edges start slightly early so
        # the transition appears slanted.
        if is_clock:
            edge_from = t_next
        else:
            edge_from = t_next - (t_next - t_here) * 0.1
        ax.plot([edge_from, t_next], [level, next_level], pen, linewidth=1.5)
def _plot_classic_signal(
    ax: Axes,
    time: NDArray[np.float64],
    data: NDArray[np.float64],
    signal: ProtocolSignal,
) -> None:
    """Draw one signal on ``ax`` as a plain oscilloscope-like trace.

    The samples are plotted as a blue line, followed by a faint dashed
    horizontal reference line at 0.5.

    Args:
        ax: Axes to draw on.
        time: Sample timestamps, already in display units.
        data: Sample values aligned with ``time``.
        signal: Signal descriptor; not consulted by this renderer.
    """
    trace_kwargs = {"linewidth": 1.2}
    reference_kwargs = {
        "color": "gray",
        "linestyle": "--",
        "linewidth": 0.5,
        "alpha": 0.5,
    }

    ax.plot(time, data, "b-", **trace_kwargs)
    ax.axhline(0.5, **reference_kwargs)
def plot_state_machine(
    states: list[str],
    transitions: list[StateTransition],
    *,
    initial_state: str | None = None,
    final_states: list[str] | None = None,
    layout: Literal["circular", "hierarchical", "force"] = "circular",
    figsize: tuple[float, float] = (10, 8),
    title: str | None = None,
) -> Figure:
    """Draw a state machine as labelled circles joined by arrows.

    States are positioned by the chosen layout and drawn as filled circles
    with their name in the centre. Transitions become arrows running between
    circle perimeters, with the condition text at the midpoint. A transition
    whose endpoints coincide is drawn as a self-loop. Transitions naming a
    state that is not in ``states`` are skipped.

    Args:
        states: State names to draw.
        transitions: Edges between states.
        initial_state: State that receives an additional outer ring.
        final_states: States that receive an additional inner ring.
        layout: Positioning strategy for the states.
        figsize: Figure size ``(width, height)``.
        title: Optional plot title.

    Returns:
        The Matplotlib Figure containing the diagram.

    Raises:
        ImportError: If matplotlib is not available.

    Example:
        >>> states = ["IDLE", "ACTIVE", "DONE"]
        >>> transitions = [
        ...     StateTransition("IDLE", "ACTIVE", "START"),
        ...     StateTransition("ACTIVE", "DONE", "COMPLETE"),
        ... ]
        >>> fig = plot_state_machine(
        ...     states, transitions, initial_state="IDLE", final_states=["DONE"]
        ... )

    References:
        VIS-022: Specialized - State Machine View
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    fig, ax = plt.subplots(figsize=figsize)
    positions = _calculate_state_positions(states, layout)
    node_radius = 0.15

    # --- Nodes -----------------------------------------------------------
    for name, centre in positions.items():
        ax.add_patch(
            patches.Circle(
                centre,
                node_radius,
                fill=True,
                facecolor="lightblue",
                edgecolor="black",
                linewidth=2.0,
            )
        )

        # Initial state: extra ring just outside the node.
        if name == initial_state:
            ax.add_patch(
                patches.Circle(
                    centre,
                    node_radius * 1.2,
                    fill=False,
                    edgecolor="black",
                    linewidth=2.0,
                )
            )

        # Final states: extra ring just inside the node.
        if final_states and name in final_states:
            ax.add_patch(
                patches.Circle(
                    centre,
                    node_radius * 0.8,
                    fill=False,
                    edgecolor="black",
                    linewidth=2.0,
                )
            )

        ax.text(
            centre[0],
            centre[1],
            name,
            ha="center",
            va="center",
            fontsize=10,
            fontweight="bold",
        )

    # --- Edges -----------------------------------------------------------
    line_styles = {"solid": "-", "dashed": "--", "dotted": ":"}

    for edge in transitions:
        # Ignore edges that reference a state with no position.
        if not (edge.from_state in positions and edge.to_state in positions):
            continue

        x1, y1 = positions[edge.from_state]
        x2, y2 = positions[edge.to_state]

        dx = x2 - x1
        dy = y2 - y1
        dist = np.sqrt(dx**2 + dy**2)

        # Coincident endpoints cannot carry a straight arrow.
        if dist < 1e-6:
            _draw_self_loop(ax, x1, y1, node_radius, edge.condition)
            continue

        # Unit vector from source to target; the arrow is trimmed by one
        # radius at each end so it touches the circle perimeters.
        ux = dx / dist
        uy = dy / dist
        tail = (x1 + ux * node_radius, y1 + uy * node_radius)
        head = (x2 - ux * node_radius, y2 - uy * node_radius)

        ax.annotate(
            "",
            xy=head,
            xytext=tail,
            arrowprops={
                "arrowstyle": "->",
                "lw": 1.5,
                "linestyle": line_styles.get(edge.style, "-"),
                "color": "black",
            },
        )

        if not edge.condition:
            continue

        # Condition label sits at the midpoint between the two centres.
        ax.text(
            (x1 + x2) / 2,
            (y1 + y2) / 2,
            edge.condition,
            fontsize=8,
            ha="center",
            bbox={
                "boxstyle": "round,pad=0.3",
                "facecolor": "white",
                "edgecolor": "gray",
                "alpha": 0.9,
            },
        )

    # --- Canvas ----------------------------------------------------------
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_xlim(-0.2, 1.2)
    ax.set_ylim(-0.2, 1.2)

    if title:
        ax.set_title(title, fontsize=14, pad=20)

    fig.tight_layout()
    return fig
def _calculate_state_positions(
    states: list[str],
    layout: str,
) -> dict[str, tuple[float, float]]:
    """Compute an (x, y) position for every state name.

    Args:
        states: State names to place. A name that appears more than once
            keeps the position computed for its last occurrence.
        layout: "circular" places states evenly on a circle of radius 0.4
            centred at (0.5, 0.5). "hierarchical" fills a near-square grid
            row by row from the top. Any other value uses fixed-seed
            pseudo-random positions in [0.2, 0.8) as a stand-in for a real
            force-directed layout.

    Returns:
        Mapping from state name to its (x, y) position. Empty when
        ``states`` is empty.
    """
    n_states = len(states)
    positions: dict[str, tuple[float, float]] = {}

    # Nothing to place; also avoids dividing by zero in the circular layout.
    if n_states == 0:
        return positions

    if layout == "circular":
        angle_step = 2 * np.pi / n_states
        for i, state in enumerate(states):
            angle = i * angle_step
            positions[state] = (
                0.5 + 0.4 * np.cos(angle),
                0.5 + 0.4 * np.sin(angle),
            )

    elif layout == "hierarchical":
        # Simplified hierarchy: a grid with ceil(sqrt(n)) columns.
        states_per_row = int(np.ceil(np.sqrt(n_states)))
        n_rows = np.ceil(n_states / states_per_row)
        for i, state in enumerate(states):
            row, col = divmod(i, states_per_row)
            positions[state] = (
                (col + 0.5) / states_per_row,
                1.0 - (row + 0.5) / n_rows,
            )

    else:
        # Placeholder for a true force-directed layout. A private generator
        # seeded with 42 yields the same coordinates as seeding the global
        # NumPy RNG would, without disturbing the caller's random state.
        rng = np.random.RandomState(42)
        for state in states:
            x = 0.2 + 0.6 * rng.rand()
            y = 0.2 + 0.6 * rng.rand()
            positions[state] = (x, y)

    return positions
def _draw_self_loop(
    ax: Axes,
    x: float,
    y: float,
    radius: float,
    label: str,
) -> None:
    """Draw a transition that leaves and re-enters the same state.

    The loop is a half-circle arc sitting on top of the state circle, with
    a small arrow head on its left side and, if given, a label above it.

    Args:
        ax: Axes to draw on.
        x: X coordinate of the state centre.
        y: Y coordinate of the state centre.
        radius: Radius of the state circle; all loop geometry scales with it.
        label: Text shown above the loop; an empty string draws no label.
    """
    # Upper half (0 to 180 degrees) of an ellipse centred on the top of the
    # state circle.
    loop_extent = radius * 1.5
    ax.add_patch(
        patches.Arc(
            (x, y + radius),
            width=loop_extent,
            height=loop_extent,
            angle=0,
            theta1=0,
            theta2=180,
            linewidth=1.5,
            edgecolor="black",
            fill=False,
        )
    )

    # Arrow head: a short, text-less annotation arrow pointing down-left.
    head_tip = (x - radius * 0.7, y + radius * 0.3)
    head_base = (x - radius * 0.5, y + radius * 0.5)
    ax.annotate(
        "",
        xy=head_tip,
        xytext=head_base,
        arrowprops={"arrowstyle": "->", "lw": 1.5, "color": "black"},
    )

    if not label:
        return

    ax.text(
        x,
        y + radius * 2.2,
        label,
        fontsize=8,
        ha="center",
        bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "alpha": 0.9},
    )
# Names exported as the public API of this module; the underscore-prefixed
# helper functions are intentionally left out.
__all__ = [
    "ProtocolSignal",
    "StateTransition",
    "plot_protocol_timing",
    "plot_state_machine",
]