Coverage for src / tracekit / analyzers / protocols / base.py: 73%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Protocol decoder base class (sigrok-inspired). 

2 

3This module provides the base class for protocol decoders, 

4following a sigrok-inspired API for consistency with the 

5open-source protocol decoding ecosystem. 

6 

7 

8Example: 

9 >>> from tracekit.analyzers.protocols.base import ChannelDef, ProtocolDecoder 

10 >>> class UARTDecoder(ProtocolDecoder): 

11 ... id = "uart" 

12 ... name = "UART" 

13 ... channels = [ChannelDef("rx", "RX", "Receive data")] 

14 ... def decode(self, trace): 

15 ... # Implementation 

16 ... pass 

17 

18References: 

19 sigrok Protocol Decoder API: https://sigrok.org/wiki/Protocol_decoder_API 

20""" 

21 

22from __future__ import annotations 

23 

24from abc import ABC, abstractmethod 

25from dataclasses import dataclass, field 

26from enum import IntEnum 

27from typing import TYPE_CHECKING, Any 

28 

29import numpy as np 

30 

31from tracekit.core.types import DigitalTrace, ProtocolPacket 

32 

33if TYPE_CHECKING: 

34 from collections.abc import Iterator 

35 

36 from numpy.typing import NDArray 

37 

38 

class AnnotationLevel(IntEnum):
    """Granularity of a decoder annotation, from finest to coarsest.

    Decoders emit annotations at several levels at once, ranging from raw
    bits up to high-level interpretations. Being an ``IntEnum``, levels
    compare as plain integers (``BITS < PACKETS``).
    """

    # Individual raw bit values.
    BITS = 0
    # Whole byte values.
    BYTES = 1
    # Words or frames.
    WORDS = 2
    # Named protocol fields.
    FIELDS = 3
    # Complete packets.
    PACKETS = 4
    # High-level messages.
    MESSAGES = 5

52 

53 

@dataclass
class Annotation:
    """A labelled time span produced by a protocol decoder.

    Attributes:
        start_time: Time, in seconds, at which the annotated span begins.
        end_time: Time, in seconds, at which the annotated span ends.
        level: Granularity of the annotation (bits, bytes, packets, ...).
        text: Human-readable label for the span.
        data: Optional raw bytes carried by the annotation.
        metadata: Free-form extra information about the annotation.
    """

    start_time: float
    end_time: float
    level: AnnotationLevel
    text: str
    data: bytes | None = field(default=None)
    # default_factory gives each instance its own dict rather than a shared one.
    metadata: dict[str, Any] = field(default_factory=dict)

73 

74 

@dataclass
class ChannelDef:
    """Declares one input channel a protocol decoder consumes.

    Attributes:
        id: Short channel identifier such as "tx", "rx" or "clk".
        name: Human-readable channel name.
        desc: Free-text explanation of what the channel carries.
        required: True when the decoder cannot run without this channel.
    """

    id: str
    name: str
    desc: str = field(default="")
    required: bool = field(default=True)

90 

91 

@dataclass
class OptionDef:
    """Declares one configurable option of a protocol decoder.

    Attributes:
        id: Key used to pass and look up the option.
        name: Human-readable option name.
        desc: Free-text explanation of the option.
        default: Value used when the caller does not supply one.
        values: Permitted values when the option is an enumeration,
            otherwise None.
    """

    id: str
    name: str
    desc: str = field(default="")
    default: Any = field(default=None)
    values: list[Any] | None = field(default=None)

109 

110 

class DecoderState:
    """Base class for decoder state machines.

    Protocol decoders subclass this to track their internal state while
    decoding frames or packets. Subclasses set their initial values in
    ``reset``, and the constructor simply delegates to it. As a result, a
    freshly built state and a reset state start out the same.
    """

    def __init__(self) -> None:
        """Build the state by applying ``reset``."""
        self.reset()

    def reset(self) -> None:
        """Restore initial values; the base class holds no state of its own."""

125 

126 

class ProtocolDecoder(ABC):
    """Base class for protocol decoders.

    Provides sigrok-inspired API for implementing protocol decoders
    that convert digital traces to decoded protocol packets.

    Class Attributes:
        api_version: Protocol decoder API version.
        id: Unique decoder identifier.
        name: Human-readable decoder name.
        longname: Full name with description.
        desc: Short description.
        license: License identifier.
        inputs: Required input types (e.g., ["logic"]).
        outputs: Output types produced.
        channels: Required channel definitions.
        optional_channels: Optional channel definitions.
        options: Configurable options.
        annotations: Annotation type definitions.

    Note:
        The list-valued class attributes are class-level objects. They are
        shared by all instances, and by subclasses that do not override
        them. Subclasses should assign new lists rather than mutate the
        inherited ones.

    Example:
        >>> class SPIDecoder(ProtocolDecoder):
        ...     id = "spi"
        ...     name = "SPI"
        ...     channels = [
        ...         ChannelDef("clk", "CLK", "Clock"),
        ...         ChannelDef("mosi", "MOSI", "Master Out Slave In"),
        ...         ChannelDef("miso", "MISO", "Master In Slave Out"),
        ...     ]
        ...     optional_channels = [
        ...         ChannelDef("cs", "CS#", "Chip Select", required=False),
        ...     ]
        ...     options = [
        ...         OptionDef("cpol", "Clock Polarity", default=0, values=[0, 1]),
        ...         OptionDef("cpha", "Clock Phase", default=0, values=[0, 1]),
        ...     ]
    """

    # API version
    api_version: int = 3

    # Decoder identification
    id: str = "unknown"
    name: str = "Unknown"
    longname: str = ""
    desc: str = ""
    license: str = "MIT"

    # Input/output types
    inputs: list[str] = ["logic"]  # noqa: RUF012
    outputs: list[str] = ["packets"]  # noqa: RUF012

    # Channel definitions
    channels: list[ChannelDef] = []  # noqa: RUF012
    optional_channels: list[ChannelDef] = []  # noqa: RUF012

    # Options
    options: list[OptionDef] = []  # noqa: RUF012

    # Annotation definitions (override in subclass)
    annotations: list[tuple[str, str]] = []  # noqa: RUF012

    def __init__(self, **options: Any) -> None:
        """Initialize decoder with options.

        Every option declared in the class-level ``options`` list starts at
        its ``default``. Keyword arguments then override individual values.

        Args:
            **options: Decoder-specific options. Each key must match the
                ``id`` of a declared option.

        Raises:
            ValueError: If unknown option is provided
        """
        self._options: dict[str, Any] = {}
        self._annotations: list[Annotation] = []
        self._packets: list[ProtocolPacket] = []
        self._state = DecoderState()

        # Set default options. After this loop the keys of ``_options`` are
        # exactly the declared option ids, so they double as the valid set.
        for opt in self.options:
            self._options[opt.id] = opt.default

        # Override with provided options
        for key, value in options.items():
            if key not in self._options:
                raise ValueError(f"Unknown option: {key}")
            self._options[key] = value

    def get_option(self, name: str) -> Any:
        """Get option value.

        Args:
            name: Option name.

        Returns:
            Option value, or None if no option with that name is set.
        """
        return self._options.get(name)

    def set_option(self, name: str, value: Any) -> None:
        """Set option value.

        Unlike ``__init__``, this does not check ``name`` against the
        declared options. Any name is stored.

        Args:
            name: Option name.
            value: New value.
        """
        self._options[name] = value

    def reset(self) -> None:
        """Reset decoder state.

        Clears all accumulated annotations and packets, and resets
        the internal state machine to initial state.
        """
        self._annotations.clear()
        self._packets.clear()
        self._state.reset()

    def put_annotation(
        self,
        start_time: float,
        end_time: float,
        level: AnnotationLevel,
        text: str,
        data: bytes | None = None,
        **metadata: Any,
    ) -> None:
        """Add an annotation.

        Args:
            start_time: Start time in seconds.
            end_time: End time in seconds.
            level: Annotation level.
            text: Annotation text.
            data: Associated binary data.
            **metadata: Additional metadata, stored as the annotation's
                ``metadata`` dict.
        """
        self._annotations.append(
            Annotation(
                start_time=start_time,
                end_time=end_time,
                level=level,
                text=text,
                data=data,
                metadata=metadata,
            )
        )

    def put_packet(
        self,
        timestamp: float,
        data: bytes,
        annotations: dict[str, Any] | None = None,
        errors: list[str] | None = None,
    ) -> None:
        """Add a decoded packet.

        The packet's ``protocol`` field is set to this decoder's ``id``.

        Args:
            timestamp: Packet start time.
            data: Decoded data bytes.
            annotations: Packet annotations (an empty dict if omitted).
            errors: Detected errors (an empty list if omitted).
        """
        self._packets.append(
            ProtocolPacket(
                timestamp=timestamp,
                protocol=self.id,
                data=data,
                annotations=annotations or {},
                errors=errors or [],
            )
        )

    @abstractmethod
    def decode(
        self,
        trace: DigitalTrace,
        **channels: NDArray[np.bool_],
    ) -> Iterator[ProtocolPacket]:
        """Decode a digital trace.

        This is the main entry point for decoding. Implementations should
        yield ProtocolPacket objects as they are decoded.

        Args:
            trace: Primary input trace.
            **channels: Additional channel data by name.

        Yields:
            Decoded protocol packets.

        Example:
            >>> decoder = UARTDecoder(baudrate=115200)
            >>> for packet in decoder.decode(trace):
            ...     print(f"Data: {packet.data.hex()}")
        """
        pass

    def get_annotations(
        self,
        *,
        level: AnnotationLevel | None = None,
        start_time: float | None = None,
        end_time: float | None = None,
    ) -> list[Annotation]:
        """Get accumulated annotations.

        The time filters select annotations that overlap the window
        ``[start_time, end_time]``. An annotation is kept if it ends at or
        after ``start_time`` and starts at or before ``end_time``.

        Args:
            level: Keep only annotations at exactly this level.
            start_time: Keep only annotations ending at or after this time.
            end_time: Keep only annotations starting at or before this time.

        Returns:
            A new list of matching annotations, in insertion order. It never
            aliases the decoder's internal storage, so callers may mutate it
            and it is unaffected by a later ``reset()``.
        """
        return [
            a
            for a in self._annotations
            if (level is None or a.level == level)
            and (start_time is None or a.end_time >= start_time)
            and (end_time is None or a.start_time <= end_time)
        ]

    def get_packets(self) -> list[ProtocolPacket]:
        """Get all decoded packets.

        Returns:
            A copy of the list of decoded packets.
        """
        return list(self._packets)

    @classmethod
    def get_channel_ids(cls, include_optional: bool = False) -> list[str]:
        """Get list of channel IDs.

        Args:
            include_optional: Include optional channels, appended after
                the required ones.

        Returns:
            List of channel ID strings.
        """
        ids = [ch.id for ch in cls.channels]
        if include_optional:
            ids.extend(ch.id for ch in cls.optional_channels)
        return ids

    @classmethod
    def get_option_ids(cls) -> list[str]:
        """Get list of option IDs.

        Returns:
            List of option ID strings, in declaration order.
        """
        return [opt.id for opt in cls.options]

385 

386 

class SyncDecoder(ProtocolDecoder):
    """Base class for synchronous protocol decoders.

    Synchronous protocols use a clock signal for timing. This base class
    provides helpers for clock edge detection and data sampling.
    """

    def sample_on_edge(
        self,
        clock: NDArray[np.bool_],
        data: NDArray[np.bool_],
        edge: str = "rising",
    ) -> NDArray[np.bool_]:
        """Sample data on clock edges.

        An edge at index ``i`` is a level change between ``clock[i]`` and
        ``clock[i + 1]``. ``data`` is sampled at ``i + 1``, the first sample
        after the transition. Edges whose sample index falls past the end of
        ``data`` are dropped.

        Args:
            clock: Clock signal. Values are interpreted by truthiness, so
                integer or list input is accepted as well as bool arrays.
            data: Data signal.
            edge: "rising" or "falling".

        Returns:
            Data values at clock edges.

        Raises:
            ValueError: If ``edge`` is neither "rising" nor "falling".
        """
        if edge not in ("rising", "falling"):
            raise ValueError(f"edge must be 'rising' or 'falling', got {edge!r}")

        # Normalise to bool so ``~`` is a logical NOT. On integer arrays ``~``
        # is bitwise and can report spurious edges (e.g. between 1 and 2).
        # For bool input, np.asarray returns the same array without copying.
        clk = np.asarray(clock, dtype=bool)
        samples = np.asarray(data)

        if edge == "rising":
            edges = np.where(~clk[:-1] & clk[1:])[0]
        else:
            edges = np.where(clk[:-1] & ~clk[1:])[0]

        # Sample data after edge (shifted by 1)
        sample_indices = edges + 1
        sample_indices = sample_indices[sample_indices < len(samples)]

        result: NDArray[np.bool_] = samples[sample_indices]
        return result

421 

422 

class AsyncDecoder(ProtocolDecoder):
    """Base class for asynchronous protocol decoders.

    Asynchronous protocols (like UART) use timing-based sampling without
    a separate clock signal. This base class provides helpers for
    bit-timing and symbol detection.
    """

    def __init__(self, baudrate: int = 9600, **options: Any) -> None:
        """Initialize async decoder.

        Args:
            baudrate: Bit rate in bps. Must be positive.
            **options: Additional options, passed to the base-class
                constructor.

        Raises:
            ValueError: If ``baudrate`` is not positive, or if an unknown
                option is provided.
        """
        super().__init__(**options)
        # NOTE: ``baudrate`` is consumed by this signature and not forwarded,
        # so it lives only in ``self._baudrate``. It is not stored in
        # ``self._options``, and ``get_option("baudrate")`` does not see it.
        self._baudrate = self._check_baudrate(baudrate)

    @staticmethod
    def _check_baudrate(value: int) -> int:
        """Return ``value`` unchanged, or raise ValueError if it is not positive."""
        # ``not value > 0`` also rejects NaN, which ``value <= 0`` would accept.
        if not value > 0:
            raise ValueError(f"baudrate must be positive, got {value!r}")
        return value

    @property
    def baudrate(self) -> int:
        """Get baud rate."""
        return self._baudrate

    @baudrate.setter
    def baudrate(self, value: int) -> None:
        """Set baud rate; raises ValueError if ``value`` is not positive."""
        self._baudrate = self._check_baudrate(value)

    def bit_time(self, sample_rate: float) -> float:
        """Get bit time in samples.

        Args:
            sample_rate: Sample rate in Hz.

        Returns:
            Number of samples per bit (not necessarily an integer).
        """
        return sample_rate / self._baudrate

    def find_start_bit(
        self,
        data: NDArray[np.bool_],
        start_idx: int = 0,
        idle_high: bool = True,
    ) -> int | None:
        """Find start bit transition.

        Args:
            data: Digital signal.
            start_idx: Start search index.
            idle_high: True if idle is high (standard UART).

        Returns:
            Index into ``data`` of the last sample before the first matching
            transition at or after ``start_idx``, or None if not found.
        """
        search_region = data[start_idx:]

        if idle_high:
            # Look for falling edge (high to low)
            transitions = np.where(search_region[:-1] & ~search_region[1:])[0]
        else:
            # Look for rising edge (low to high)
            transitions = np.where(~search_region[:-1] & search_region[1:])[0]

        if len(transitions) == 0:
            return None

        # Transition indices are relative to ``search_region``; shift back
        # into ``data`` coordinates.
        return int(start_idx + transitions[0])

491 

492 

# Names exported by ``from tracekit.analyzers.protocols.base import *``.
__all__ = [
    "Annotation", "AnnotationLevel", "AsyncDecoder", "ChannelDef",
    "DecoderState", "OptionDef", "ProtocolDecoder", "SyncDecoder",
]