Coverage for src / tracekit / workflows / protocol.py: 97%
146 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Protocol debug workflow.
3This module implements auto-detect protocol decoding with error context.
6Example:
7 >>> import tracekit as tk
8 >>> trace = tk.load('serial_capture.wfm')
9 >>> result = tk.debug_protocol(trace)
10 >>> print(f"Protocol: {result['protocol']}")
11 >>> print(f"Errors: {len(result['errors'])}")
13References:
14 UART: TIA-232-F
15 I2C: NXP UM10204
16 SPI: Motorola SPI Block Guide
17 CAN: ISO 11898
18"""
20from __future__ import annotations
22from typing import Any
24import numpy as np
26from tracekit.core.exceptions import AnalysisError
27from tracekit.core.types import DigitalTrace, ProtocolPacket, WaveformTrace
def debug_protocol(
    trace: WaveformTrace | DigitalTrace,
    *,
    protocol: str | None = None,
    context_samples: int = 100,
    error_types: list[str] | None = None,
    decode_all: bool = False,
) -> dict[str, Any]:
    """Auto-detect and decode protocol with error context.

    Automatically detects protocol type (UART, SPI, I2C, CAN) if not specified,
    decodes packets, and highlights errors with surrounding context samples.

    Args:
        trace: Signal to decode. Analog traces are thresholded to logic levels
            before being handed to the decoder.
        protocol: Protocol type override ('UART', 'SPI', 'I2C', 'CAN', 'auto'),
            case-insensitive. If None or 'auto', auto-detected.
        context_samples: Number of samples to include before/after errors.
        error_types: Case-insensitive substrings selecting which decoder errors
            are reported. If None or empty, all errors are reported.
        decode_all: If True, return all decoded packets. If False, return only
            the packets that produced at least one reported error.

    Returns:
        Dictionary containing:
        - protocol: Detected or specified protocol type (upper-case)
        - baud_rate: Detected baud/clock rate (if applicable)
        - packets: List of decoded ProtocolPacket objects
        - errors: List of error dictionaries with context
        - config: Protocol configuration used
        - statistics: ``total_packets`` (every decoded packet, independent of
          ``decode_all``), ``error_count`` (reported errors), ``error_rate``
          (``error_count / total_packets``, 0 when nothing was decoded) and
          ``confidence`` (detection confidence; 1.0 when ``protocol`` was given)

    Raises:
        AnalysisError: If protocol cannot be detected or is unsupported.

    Example:
        >>> trace = tk.load('uart_data.wfm')
        >>> result = tk.debug_protocol(trace)
        >>> print(f"Protocol: {result['protocol']}")
        >>> print(f"Baud Rate: {result['baud_rate']}")
        >>> for error in result['errors']:
        ...     print(f"Error at {error['timestamp']}: {error['type']}")

    References:
        sigrok Protocol Decoder API
        UART: TIA-232-F (Serial communication)
        I2C: NXP UM10204 (I2C-bus specification)
    """
    from tracekit.inference.protocol import detect_protocol

    # Decoders work on logic levels; detection below still sees the raw trace.
    digital_trace = _to_digital(trace)

    if protocol is None or protocol.lower() == "auto":
        detection = detect_protocol(trace, min_confidence=0.5, return_candidates=True)  # type: ignore[arg-type]
        protocol = detection["protocol"]
        config = detection["config"]
        confidence = detection["confidence"]
    else:
        confidence = 1.0
        config = _get_default_protocol_config(protocol.upper())

    if not isinstance(protocol, str):
        # Without this guard a missing detection result would surface as an
        # AttributeError from .upper() instead of the documented AnalysisError.
        raise AnalysisError(f"Could not detect protocol (got {protocol!r})")

    protocol = protocol.upper()

    decoders = {
        "UART": _decode_uart,
        "SPI": _decode_spi,
        "I2C": _decode_i2c,
        "CAN": _decode_can,
    }
    decode = decoders.get(protocol)
    if decode is None:
        raise AnalysisError(f"Unsupported protocol: {protocol}")

    packets, errors = decode(digital_trace, config, context_samples, error_types, trace)

    # Statistics describe the whole capture, so count packets BEFORE the
    # optional error-only filtering (otherwise error_rate is always >= 1).
    total_packets = len(packets)
    error_count = len(errors)

    if not decode_all:
        # Keep exactly the packets that contributed a reported error, so an
        # error_types filter narrows the packet list as well as the errors.
        reported = {err["packet_index"] for err in errors}
        packets = [pkt for idx, pkt in enumerate(packets) if idx in reported]

    statistics = {
        "total_packets": total_packets,
        "error_count": error_count,
        "error_rate": error_count / total_packets if total_packets > 0 else 0,
        "confidence": confidence,
    }

    return {
        "protocol": protocol,
        "baud_rate": config.get("baud_rate") or config.get("clock_rate"),
        "packets": packets,
        "errors": errors,
        "config": config,
        "statistics": statistics,
    }
def _to_digital(trace: WaveformTrace | DigitalTrace) -> DigitalTrace:
    """Convert a trace to a digital trace by mid-level thresholding.

    Digital traces are returned unchanged. For waveform traces the threshold
    is the midpoint between the smallest and largest sample, and samples
    strictly above it become ``True``.

    NaN samples are ignored when computing the threshold; they compare as
    ``False`` and therefore map to logic low. An empty trace yields an empty
    digital trace rather than failing inside the min/max reduction.

    Args:
        trace: Input trace.

    Returns:
        Digital trace sharing the input trace's metadata.
    """
    if isinstance(trace, DigitalTrace):
        return trace

    data = np.asarray(trace.data)
    if data.size == 0:
        digital_data = np.zeros(data.shape, dtype=bool)
    else:
        # Convert to float before adding: summing two int16 (etc.) scalars
        # in their native dtype can overflow for raw ADC captures.
        threshold = (float(np.nanmax(data)) + float(np.nanmin(data))) / 2
        digital_data = data > threshold

    return DigitalTrace(
        data=digital_data,
        metadata=trace.metadata,
    )
156def _get_default_protocol_config(protocol: str) -> dict[str, Any]:
157 """Get default configuration for a protocol.
159 Args:
160 protocol: Protocol name.
162 Returns:
163 Default configuration dictionary.
164 """
165 configs = {
166 "UART": {
167 "baud_rate": 115200,
168 "data_bits": 8,
169 "parity": "none",
170 "stop_bits": 1,
171 },
172 "SPI": {
173 "clock_polarity": 0,
174 "clock_phase": 0,
175 "bit_order": "MSB",
176 },
177 "I2C": {
178 "clock_rate": 100000, # Standard mode
179 "address_bits": 7,
180 },
181 "CAN": {
182 "baud_rate": 500000,
183 "sample_point": 0.75,
184 },
185 }
186 return configs.get(protocol, {}) # type: ignore[return-value]
def _extract_context(
    trace: WaveformTrace | DigitalTrace,
    sample_idx: int,
    context_samples: int,
) -> WaveformTrace | DigitalTrace | None:
    """Extract a window of samples centred on ``sample_idx``.

    The window covers ``sample_idx - context_samples`` through
    ``sample_idx + context_samples`` inclusive, so the centre sample is
    flanked by ``context_samples`` samples on each side. The window is
    clipped to the bounds of the trace.

    Args:
        trace: Original trace.
        sample_idx: Center sample index.
        context_samples: Number of samples before and after.

    Returns:
        Sub-trace of the same kind as ``trace`` (waveform or digital) sharing
        its metadata, or None if the clipped window is empty.
    """
    data = trace.data
    start = max(0, sample_idx - context_samples)
    # Slice ends are exclusive: the +1 keeps the window symmetric and
    # inclusive of sample_idx + context_samples.
    end = min(len(data), sample_idx + context_samples + 1)

    if end <= start:
        return None

    context_data = data[start:end]

    # NOTE(review): metadata is reused unchanged, so any start-time/offset
    # field it may carry still refers to the full capture — confirm.
    if isinstance(trace, WaveformTrace):
        return WaveformTrace(
            data=context_data,  # type: ignore[arg-type]
            metadata=trace.metadata,
        )
    return DigitalTrace(
        data=context_data,  # type: ignore[arg-type]
        metadata=trace.metadata,
    )
def _decode_uart(
    trace: DigitalTrace,
    config: dict[str, Any],
    context_samples: int,
    error_types: list[str] | None,
    original_trace: WaveformTrace | DigitalTrace,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]]]:
    """Decode a UART capture and build one error record per reported error.

    Args:
        trace: Logic-level trace fed to the UART decoder.
        config: Decoder settings. Reads ``baud_rate`` (default 0),
            ``data_bits`` (default 8), ``parity`` (default "none") and
            ``stop_bits`` (default 1).
        context_samples: Half-width, in samples, of the context window
            attached to each error.
        error_types: Optional case-insensitive substrings. When non-empty,
            only errors containing one of them are reported.
        original_trace: Trace (possibly analog) from which context windows
            are cut.

    Returns:
        ``(packets, errors)``: every decoded packet, plus a list of error
        dictionaries. UART records carry ``address=None``.
    """
    from tracekit.analyzers.protocols.uart import UARTDecoder

    # NOTE(review): baud_rate falls back to 0 when absent from config;
    # confirm UARTDecoder treats 0 as "auto-detect".
    decoder = UARTDecoder(
        baudrate=config.get("baud_rate", 0),
        data_bits=config.get("data_bits", 8),
        parity=config.get("parity", "none"),
        stop_bits=config.get("stop_bits", 1),
    )
    packets = list(decoder.decode(trace))
    rate = trace.metadata.sample_rate

    records: list[dict[str, Any]] = []
    for pkt_no, pkt in enumerate(packets):
        if not pkt.errors:
            continue

        if error_types:
            selected = [
                msg
                for msg in pkt.errors
                if any(t.lower() in msg.lower() for t in error_types)
            ]
        else:
            selected = pkt.errors

        for msg in selected:
            centre = int(pkt.timestamp * rate)
            window = _extract_context(original_trace, centre, context_samples)
            records.append(
                {
                    "type": msg,
                    "timestamp": pkt.timestamp,
                    "packet_index": pkt_no,
                    "address": None,
                    "data": pkt.data,
                    "context": f"Samples {centre - context_samples} to {centre + context_samples}",
                    "context_trace": window,
                }
            )

    return packets, records
def _decode_spi(
    trace: DigitalTrace,
    config: dict[str, Any],
    context_samples: int,
    error_types: list[str] | None,
    original_trace: WaveformTrace | DigitalTrace,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]]]:
    """Decode SPI protocol with error context.

    Args:
        trace: Digital trace to decode.
        config: SPI configuration. Reads ``clock_polarity`` (default 0),
            ``clock_phase`` (default 0), ``word_size`` (default 8) and
            ``bit_order`` (default "msb", lower-cased before use).
        context_samples: Context window size.
        error_types: Optional case-insensitive substrings. When non-empty,
            only errors containing one of them are reported.
        original_trace: Original trace for context extraction.

    Returns:
        Tuple of (packets, errors). Each error record carries the same
        ``address``/``data`` keys as the UART records (``address`` is None).
        It also keeps ``mosi_data`` for existing consumers.
    """
    from tracekit.analyzers.protocols.spi import SPIDecoder

    decoder = SPIDecoder(
        cpol=config.get("clock_polarity", 0),
        cpha=config.get("clock_phase", 0),
        word_size=config.get("word_size", 8),
        bit_order=config.get("bit_order", "msb").lower(),
    )

    sample_rate = trace.metadata.sample_rate

    # NOTE(review): only one channel is available, so the same samples are
    # passed as both clock and MOSI. Decoded words presumably mirror the clock
    # level rather than a real payload; a multi-channel API is needed for
    # meaningful SPI decoding.
    signal = trace.data
    packets = list(decoder.decode(clk=signal, mosi=signal, sample_rate=sample_rate))

    errors: list[dict[str, Any]] = []
    for i, pkt in enumerate(packets):
        if not pkt.errors:
            continue

        relevant_errors = pkt.errors
        if error_types:
            relevant_errors = [
                e for e in pkt.errors if any(t.lower() in e.lower() for t in error_types)
            ]
        if not relevant_errors:
            continue

        # Per-packet values, shared by every error reported for this packet.
        sample_idx = int(pkt.timestamp * sample_rate)
        context_trace = _extract_context(original_trace, sample_idx, context_samples)
        context_label = f"Samples {sample_idx - context_samples} to {sample_idx + context_samples}"

        for err in relevant_errors:
            errors.append(
                {
                    "type": err,
                    "timestamp": pkt.timestamp,
                    "packet_index": i,
                    "address": None,
                    "data": pkt.data,
                    "mosi_data": pkt.data,
                    "context": context_label,
                    "context_trace": context_trace,
                }
            )

    return packets, errors
363def _decode_i2c(
364 trace: DigitalTrace,
365 config: dict[str, Any],
366 context_samples: int,
367 error_types: list[str] | None,
368 original_trace: WaveformTrace | DigitalTrace,
369) -> tuple[list[ProtocolPacket], list[dict[str, Any]]]:
370 """Decode I2C protocol with error context.
372 Args:
373 trace: Digital trace to decode.
374 config: I2C configuration.
375 context_samples: Context window size.
376 error_types: Error types to detect.
377 original_trace: Original trace for context extraction.
379 Returns:
380 Tuple of (packets, errors).
381 """
382 from tracekit.analyzers.protocols.i2c import I2CDecoder
384 address_format = config.get("address_format", "auto")
386 # For single-channel, assume SDA and create synthetic SCL
387 sda = trace.data
388 sample_rate = trace.metadata.sample_rate
390 edges = np.where(np.diff(sda.astype(int)) != 0)[0]
392 if len(edges) < 20:
393 return [], []
395 # Create synthetic SCL
396 scl = np.ones_like(sda, dtype=bool)
397 for i, edge in enumerate(edges):
398 if i % 2 == 0 and edge + 10 < len(scl):
399 scl[edge : edge + 10] = False
401 decoder = I2CDecoder(address_format=address_format)
402 packets = list(decoder.decode(scl=scl, sda=sda, sample_rate=sample_rate))
404 errors: list[dict[str, Any]] = []
406 for i, pkt in enumerate(packets):
407 if pkt.errors:
408 relevant_errors = pkt.errors
409 if error_types: 409 ↛ 410line 409 didn't jump to line 410 because the condition on line 409 was never true
410 relevant_errors = [
411 e for e in pkt.errors if any(t.lower() in e.lower() for t in error_types)
412 ]
414 for err in relevant_errors:
415 sample_idx = int(pkt.timestamp * sample_rate)
416 context_trace = _extract_context(original_trace, sample_idx, context_samples)
418 addr = pkt.annotations.get("address", 0) if pkt.annotations else 0
420 error = {
421 "type": err,
422 "timestamp": pkt.timestamp,
423 "packet_index": i,
424 "address": addr,
425 "data": pkt.data,
426 "context": f"Samples {sample_idx - context_samples} to {sample_idx + context_samples}",
427 "context_trace": context_trace,
428 }
429 errors.append(error)
431 return packets, errors
def _decode_can(
    trace: DigitalTrace,
    config: dict[str, Any],
    context_samples: int,
    error_types: list[str] | None,
    original_trace: WaveformTrace | DigitalTrace,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]]]:
    """Decode a CAN capture and attach context to every reported error.

    Args:
        trace: Logic-level trace fed to the CAN decoder.
        config: Decoder settings. Reads ``baud_rate`` (default 500000, passed
            as the decoder's ``bitrate``) and ``sample_point`` (default 0.75).
        context_samples: Half-width, in samples, of the context window
            attached to each error.
        error_types: Optional case-insensitive substrings. When non-empty,
            only errors containing one of them are reported.
        original_trace: Trace (possibly analog) from which context windows
            are cut.

    Returns:
        ``(packets, errors)``: every decoded packet, plus a list of error
        dictionaries. Each carries the packet's ``arbitration_id`` annotation,
        or 0 when the packet has no annotations or lacks that key.
    """
    from tracekit.analyzers.protocols.can import CANDecoder

    decoder = CANDecoder(
        bitrate=config.get("baud_rate", 500000),
        sample_point=config.get("sample_point", 0.75),
    )
    packets = list(decoder.decode(trace))
    rate = trace.metadata.sample_rate

    records: list[dict[str, Any]] = []
    for pkt_no, pkt in enumerate(packets):
        if not pkt.errors:
            continue

        selected = (
            [msg for msg in pkt.errors if any(t.lower() in msg.lower() for t in error_types)]
            if error_types
            else pkt.errors
        )

        for msg in selected:
            centre = int(pkt.timestamp * rate)
            window = _extract_context(original_trace, centre, context_samples)
            arb_id = pkt.annotations.get("arbitration_id", 0) if pkt.annotations else 0
            records.append(
                {
                    "type": msg,
                    "timestamp": pkt.timestamp,
                    "packet_index": pkt_no,
                    "arbitration_id": arb_id,
                    "data": pkt.data,
                    "context": f"Samples {centre - context_samples} to {centre + context_samples}",
                    "context_trace": window,
                }
            )

    return packets, records
# Public API of this module: only the workflow entry point; the per-protocol
# helpers above are private.
__all__ = ["debug_protocol"]