Coverage for src / tracekit / analyzers / protocols / base.py: 73%
126 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Protocol decoder base class (sigrok-inspired).
3This module provides the base class for protocol decoders,
4following a sigrok-inspired API for consistency with the
5open-source protocol decoding ecosystem.
8Example:
9 >>> from tracekit.analyzers.protocols.base import ProtocolDecoder
10 >>> class UARTDecoder(ProtocolDecoder):
11 ... id = "uart"
12 ... name = "UART"
13 ... channels = [{"id": "rx", "name": "RX", "desc": "Receive data"}]
14 ... def decode(self, trace):
15 ... # Implementation
16 ... pass
18References:
19 sigrok Protocol Decoder API: https://sigrok.org/wiki/Protocol_decoder_API
20"""
22from __future__ import annotations
24from abc import ABC, abstractmethod
25from dataclasses import dataclass, field
26from enum import IntEnum
27from typing import TYPE_CHECKING, Any
29import numpy as np
31from tracekit.core.types import DigitalTrace, ProtocolPacket
33if TYPE_CHECKING:
34 from collections.abc import Iterator
36 from numpy.typing import NDArray
class AnnotationLevel(IntEnum):
    """Granularity tiers for decoder annotations.

    Members run from the finest level of detail (individual bits) up to
    the coarsest (high-level messages). Because this is an ``IntEnum``,
    members compare and sort as the plain integers 0 through 5.
    """

    BITS = 0  # individual raw bit values
    BYTES = 1  # byte values
    WORDS = 2  # words / frames
    FIELDS = 3  # named fields
    PACKETS = 4  # complete packets
    MESSAGES = 5  # high-level messages
@dataclass
class Annotation:
    """A decoded annotation covering one span of time.

    Attributes:
        start_time: Where the annotated span begins, in seconds.
        end_time: Where the annotated span ends, in seconds.
        level: Level of detail this annotation belongs to
            (bits, bytes, packets, etc.).
        text: Human-readable description.
        data: Raw bytes associated with the annotation, or None.
        metadata: Free-form extra information; every instance gets
            its own empty dict by default.
    """

    # Required fields (positional order is part of the constructor API).
    start_time: float
    end_time: float
    level: AnnotationLevel
    text: str

    # Optional payload.
    data: bytes | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ChannelDef:
    """Describes one input channel of a protocol decoder.

    Attributes:
        id: Short channel identifier (e.g., "tx", "rx", "clk").
        name: Human-readable channel name.
        desc: What the channel is used for; empty by default.
        required: Whether the channel must be supplied; True by default.
    """

    # Identification (required, positional).
    id: str
    name: str

    # Optional details.
    desc: str = ""
    required: bool = True
@dataclass
class OptionDef:
    """Describes one configurable option of a protocol decoder.

    Attributes:
        id: Option identifier.
        name: Human-readable option name.
        desc: Description of the option; empty by default.
        default: Value used when the option is not given explicitly.
        values: The permitted values when the option is enumerated,
            otherwise None.
    """

    # Identification (required, positional).
    id: str
    name: str

    # Optional details.
    desc: str = ""
    default: Any = None
    values: list[Any] | None = None
class DecoderState:
    """Minimal base for decoder state machines.

    Subclasses keep whatever per-frame/per-packet bookkeeping a decoder
    needs and override :meth:`reset` to (re)initialize it. Construction
    goes through ``reset()``, so a subclass normally only has to
    implement that one method.
    """

    def __init__(self) -> None:
        """Create the state object and put it into its initial state."""
        self.reset()

    def reset(self) -> None:
        """Return the state to its initial values (no-op in the base class)."""
class ProtocolDecoder(ABC):
    """Base class for protocol decoders.

    Provides sigrok-inspired API for implementing protocol decoders
    that convert digital traces to decoded protocol packets.

    Class Attributes:
        api_version: Protocol decoder API version.
        id: Unique decoder identifier.
        name: Human-readable decoder name.
        longname: Full name with description.
        desc: Short description.
        license: License identifier.
        inputs: Required input types (e.g., ["logic"]).
        outputs: Output types produced.
        channels: Required channel definitions.
        optional_channels: Optional channel definitions.
        options: Configurable options.
        annotations: Annotation type definitions.

    Example:
        >>> class SPIDecoder(ProtocolDecoder):
        ...     id = "spi"
        ...     name = "SPI"
        ...     channels = [
        ...         ChannelDef("clk", "CLK", "Clock"),
        ...         ChannelDef("mosi", "MOSI", "Master Out Slave In"),
        ...         ChannelDef("miso", "MISO", "Master In Slave Out"),
        ...     ]
        ...     optional_channels = [
        ...         ChannelDef("cs", "CS#", "Chip Select", required=False),
        ...     ]
        ...     options = [
        ...         OptionDef("cpol", "Clock Polarity", default=0, values=[0, 1]),
        ...         OptionDef("cpha", "Clock Phase", default=0, values=[0, 1]),
        ...     ]
    """

    # API version
    api_version: int = 3

    # Decoder identification
    id: str = "unknown"
    name: str = "Unknown"
    longname: str = ""
    desc: str = ""
    license: str = "MIT"

    # Input/output types
    inputs: list[str] = ["logic"]  # noqa: RUF012
    outputs: list[str] = ["packets"]  # noqa: RUF012

    # Channel definitions
    channels: list[ChannelDef] = []  # noqa: RUF012
    optional_channels: list[ChannelDef] = []  # noqa: RUF012

    # Options
    options: list[OptionDef] = []  # noqa: RUF012

    # Annotation definitions (override in subclass)
    annotations: list[tuple[str, str]] = []  # noqa: RUF012

    def __init__(self, **options: Any) -> None:
        """Initialize decoder with options.

        Every option declared in ``options`` starts at its declared
        default; keyword arguments then override those defaults.

        Args:
            **options: Decoder-specific options, keyed by option id.

        Raises:
            ValueError: If unknown option is provided
        """
        self._options: dict[str, Any] = {}
        self._annotations: list[Annotation] = []
        self._packets: list[ProtocolPacket] = []
        self._state = DecoderState()

        # Set default options
        for opt in self.options:
            self._options[opt.id] = opt.default

        # Override with provided options. At this point the keys of
        # ``_options`` are exactly the declared option ids, so a dict
        # membership test replaces a scan of ``self.options`` per key.
        for key, value in options.items():
            if key not in self._options:
                raise ValueError(f"Unknown option: {key}")
            self._options[key] = value

    def get_option(self, name: str) -> Any:
        """Get option value.

        Args:
            name: Option name.

        Returns:
            Option value, or None if no option of that name is stored.
        """
        return self._options.get(name)

    def set_option(self, name: str, value: Any) -> None:
        """Set option value.

        Unlike the constructor, this does not check ``name`` against the
        declared options; the value is stored under whatever name is given.

        Args:
            name: Option name.
            value: New value.
        """
        self._options[name] = value

    def reset(self) -> None:
        """Reset decoder state.

        Clears all accumulated annotations and packets, and resets
        the internal state machine to initial state.
        """
        self._annotations.clear()
        self._packets.clear()
        self._state.reset()

    def put_annotation(
        self,
        start_time: float,
        end_time: float,
        level: AnnotationLevel,
        text: str,
        data: bytes | None = None,
        **metadata: Any,
    ) -> None:
        """Add an annotation.

        Args:
            start_time: Start time in seconds.
            end_time: End time in seconds.
            level: Annotation level.
            text: Annotation text.
            data: Associated binary data.
            **metadata: Additional metadata, stored as the annotation's
                ``metadata`` dict.
        """
        self._annotations.append(
            Annotation(
                start_time=start_time,
                end_time=end_time,
                level=level,
                text=text,
                data=data,
                metadata=metadata,
            )
        )

    def put_packet(
        self,
        timestamp: float,
        data: bytes,
        annotations: dict[str, Any] | None = None,
        errors: list[str] | None = None,
    ) -> None:
        """Add a decoded packet.

        The packet's ``protocol`` field is set to this decoder's ``id``.

        Args:
            timestamp: Packet start time.
            data: Decoded data bytes.
            annotations: Packet annotations (a new empty dict if omitted
                or empty).
            errors: Detected errors (a new empty list if omitted or empty).
        """
        self._packets.append(
            ProtocolPacket(
                timestamp=timestamp,
                protocol=self.id,
                data=data,
                annotations=annotations or {},
                errors=errors or [],
            )
        )

    @abstractmethod
    def decode(
        self,
        trace: DigitalTrace,
        **channels: NDArray[np.bool_],
    ) -> Iterator[ProtocolPacket]:
        """Decode a digital trace.

        This is the main entry point for decoding. Implementations should
        yield ProtocolPacket objects as they are decoded.

        Args:
            trace: Primary input trace.
            **channels: Additional channel data by name.

        Yields:
            Decoded protocol packets.

        Example:
            >>> decoder = UARTDecoder(baudrate=115200)
            >>> for packet in decoder.decode(trace):
            ...     print(f"Data: {packet.data.hex()}")
        """

    def get_annotations(
        self,
        *,
        level: AnnotationLevel | None = None,
        start_time: float | None = None,
        end_time: float | None = None,
    ) -> list[Annotation]:
        """Get accumulated annotations.

        The time filters select by overlap: an annotation is kept when it
        ends at or after ``start_time`` and starts at or before
        ``end_time``. Filters left as None are not applied.

        Args:
            level: Filter by annotation level.
            start_time: Filter by start time (inclusive).
            end_time: Filter by end time (inclusive).

        Returns:
            A new list of matching annotations, in insertion order. The
            list is always a copy, so mutating it (or calling ``reset()``
            afterwards) never affects what the caller holds or the
            decoder's internal storage.
        """
        return [
            ann
            for ann in self._annotations
            if (level is None or ann.level == level)
            and (start_time is None or ann.end_time >= start_time)
            and (end_time is None or ann.start_time <= end_time)
        ]

    def get_packets(self) -> list[ProtocolPacket]:
        """Get all decoded packets.

        Returns:
            List of decoded packets (a copy of the internal list).
        """
        return list(self._packets)

    @classmethod
    def get_channel_ids(cls, include_optional: bool = False) -> list[str]:
        """Get list of channel IDs.

        Args:
            include_optional: Include optional channels.

        Returns:
            List of channel ID strings, required channels first.
        """
        ids = [ch.id for ch in cls.channels]
        if include_optional:
            ids.extend(ch.id for ch in cls.optional_channels)
        return ids

    @classmethod
    def get_option_ids(cls) -> list[str]:
        """Get list of option IDs.

        Returns:
            List of option ID strings.
        """
        return [opt.id for opt in cls.options]
class SyncDecoder(ProtocolDecoder):
    """Base class for clocked (synchronous) protocol decoders.

    Synchronous protocols carry a dedicated clock line that defines when
    data is valid. This class adds a helper that locates clock edges and
    picks the data values at those edges.
    """

    def sample_on_edge(
        self,
        clock: NDArray[np.bool_],
        data: NDArray[np.bool_],
        edge: str = "rising",
    ) -> NDArray[np.bool_]:
        """Pick the data values that coincide with clock edges.

        Args:
            clock: Clock signal.
            data: Data signal.
            edge: "rising" selects low-to-high transitions; any other
                value selects high-to-low ("falling") transitions.

        Returns:
            Data values taken at the first sample after each edge. Edges
            whose sample position falls beyond the end of ``data`` are
            dropped.
        """
        before = clock[:-1]
        after = clock[1:]

        if edge == "rising":
            edge_mask = ~before & after
        else:
            edge_mask = before & ~after

        # The edge sits between sample i and i + 1; read the data at i + 1.
        positions = np.nonzero(edge_mask)[0] + 1
        within_data = positions < len(data)

        sampled: NDArray[np.bool_] = data[positions[within_data]]
        return sampled
class AsyncDecoder(ProtocolDecoder):
    """Base class for clockless (asynchronous) protocol decoders.

    Asynchronous protocols such as UART have no separate clock line and
    are sampled purely by timing. This class stores the baud rate and
    offers helpers for bit timing and start-bit detection.
    """

    def __init__(self, baudrate: int = 9600, **options: Any) -> None:
        """Set up the decoder options and remember the baud rate.

        Args:
            baudrate: Bit rate in bps.
            **options: Remaining options, forwarded to the base class.
        """
        super().__init__(**options)
        self._baudrate = baudrate

    @property
    def baudrate(self) -> int:
        """Baud rate currently configured, in bps."""
        return self._baudrate

    @baudrate.setter
    def baudrate(self, value: int) -> None:
        """Replace the configured baud rate."""
        self._baudrate = value

    def bit_time(self, sample_rate: float) -> float:
        """Compute the duration of one bit, expressed in samples.

        Args:
            sample_rate: Sample rate in Hz.

        Returns:
            ``sample_rate`` divided by the configured baud rate, i.e.
            the (possibly fractional) number of samples per bit.
        """
        samples_per_bit = sample_rate / self._baudrate
        return samples_per_bit

    def find_start_bit(
        self,
        data: NDArray[np.bool_],
        start_idx: int = 0,
        idle_high: bool = True,
    ) -> int | None:
        """Locate the first idle-to-active transition.

        Args:
            data: Digital signal.
            start_idx: Index at which the search begins.
            idle_high: True if the line idles high (standard UART), so
                the start bit is a falling edge; otherwise a rising
                edge is searched for.

        Returns:
            Index into ``data`` of the sample immediately before the
            first matching transition (the edge lies between that index
            and the next one), or None if no transition is found.
        """
        window = data[start_idx:]
        prev_samples = window[:-1]
        next_samples = window[1:]

        # Falling edge (high -> low) when idle is high, rising otherwise.
        if idle_high:
            edge_mask = prev_samples & ~next_samples
        else:
            edge_mask = ~prev_samples & next_samples

        hits = np.nonzero(edge_mask)[0]
        if hits.size == 0:
            return None

        # Hits are relative to the window; shift back to ``data`` indices.
        return int(start_idx + hits[0])
# Public API of this module (kept in alphabetical order).
__all__ = [
    "Annotation", "AnnotationLevel", "AsyncDecoder", "ChannelDef",
    "DecoderState", "OptionDef", "ProtocolDecoder", "SyncDecoder",
]