Coverage for src / tracekit / cli / decode.py: 95%
179 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""TraceKit Decode Command implementing CLI-003.
3Provides CLI for protocol decoding with automatic protocol detection and
4error highlighting.
7Example:
8 $ tracekit decode serial_capture.wfm
9 $ tracekit decode i2c_bus.wfm --protocol I2C
10 $ tracekit decode uart.wfm --protocol UART --baud-rate 115200
11"""
13from __future__ import annotations
15import logging
16from pathlib import Path
17from typing import Any
19import click
20import numpy as np
22from tracekit.cli.main import format_output
23from tracekit.core.types import DigitalTrace, ProtocolPacket, WaveformTrace
# Module logger with an explicit name so log routing does not depend on how
# this module happens to be imported.
logger = logging.getLogger("tracekit.cli.decode")


@click.command()  # type: ignore[misc]
@click.argument("file", type=click.Path(exists=True))  # type: ignore[misc]
@click.option(  # type: ignore[misc]
    "--protocol",
    type=click.Choice(["uart", "spi", "i2c", "can", "auto"], case_sensitive=False),
    default="auto",
    help="Protocol type (default: auto-detect).",
)
@click.option(  # type: ignore[misc]
    "--baud-rate",
    type=int,
    default=None,
    help="Baud rate for UART, or bit rate for CAN (auto-detect if not specified).",
)
@click.option(  # type: ignore[misc]
    "--parity",
    type=click.Choice(["none", "even", "odd"], case_sensitive=False),
    default="none",
    help="Parity for UART (default: none).",
)
@click.option(  # type: ignore[misc]
    "--stop-bits",
    type=click.Choice(["1", "2"]),
    default="1",
    help="Stop bits for UART (default: 1).",
)
@click.option(  # type: ignore[misc]
    "--show-errors",
    is_flag=True,
    help="Show only packets that contain errors.",
)
@click.option(  # type: ignore[misc]
    "--output",
    type=click.Choice(["json", "csv", "html", "table"], case_sensitive=False),
    default="table",
    help="Output format (default: table).",
)
@click.pass_context  # type: ignore[misc]
def decode(
    ctx: click.Context,
    file: str,
    protocol: str,
    baud_rate: int | None,
    parity: str,
    stop_bits: str,
    show_errors: bool,
    output: str,
) -> None:
    """Decode serial protocol data.

    Automatically detects and decodes common serial protocols (UART, SPI, I2C, CAN).
    Can restrict the output to packets that contain errors for debugging.

    Args:
        ctx: Click context object.
        file: Path to waveform file to decode.
        protocol: Protocol type (uart, spi, i2c, can, auto).
        baud_rate: Baud rate for UART or bit rate for CAN (None for auto-detect).
        parity: Parity setting for UART (none, even, odd).
        stop_bits: Number of stop bits for UART (1 or 2).
        show_errors: Show only packets with errors.
        output: Output format (json, csv, html, table).

    Raises:
        Exception: Re-raised when loading or decoding fails and verbosity is
            greater than 1. Otherwise the error is printed to stderr and the
            command exits with status 1.

    Examples:

        \b
        # Auto-detect and decode protocol
        $ tracekit decode serial_capture.wfm

        \b
        # Decode specific protocol with parameters
        $ tracekit decode uart.wfm \\
            --protocol UART \\
            --baud-rate 9600 \\
            --parity even \\
            --stop-bits 2

        \b
        # Show only errors for debugging
        $ tracekit decode problematic.wfm --show-errors

        \b
        # Generate JSON output
        $ tracekit decode i2c.wfm --protocol I2C --output json
    """
    # ctx.obj is filled in by a parent CLI group; it is None when this command
    # is invoked on its own, so fall back to an empty mapping.
    verbose = (ctx.obj or {}).get("verbose", 0)

    if verbose:
        logger.info("Decoding: %s", file)
        logger.info("Protocol: %s", protocol)
        if protocol.lower() == "uart" and baud_rate:
            logger.info("Baud rate: %s", baud_rate)

    try:
        # Import here to avoid circular imports
        from tracekit.loaders import load

        logger.debug("Loading trace from %s", file)
        trace = load(file)

        results = _perform_decoding(
            trace=trace,  # type: ignore[arg-type]
            protocol=protocol,
            baud_rate=baud_rate,
            parity=parity,
            stop_bits=int(stop_bits),
            show_errors=show_errors,
        )

        # Report only the file name, not the full path.
        results["file"] = Path(file).name

        click.echo(format_output(results, output))

    except Exception as e:
        logger.error("Decoding failed: %s", e)
        # At -vv and above, surface the full traceback for debugging.
        if verbose > 1:
            raise
        click.echo(f"Error: {e}", err=True)
        ctx.exit(1)
def _to_digital(trace: WaveformTrace | DigitalTrace) -> DigitalTrace:
    """Convert a waveform trace to a digital trace.

    Digital traces are returned unchanged. Analog samples are thresholded at
    the midpoint between their minimum and maximum values: samples strictly
    above the midpoint become True.

    Args:
        trace: Input trace (waveform or digital).

    Returns:
        Digital trace with boolean data and the input trace's metadata.
    """
    if isinstance(trace, DigitalTrace):
        return trace

    data = np.asarray(trace.data)

    if data.size == 0:
        # np.max/np.min raise on empty arrays; an empty capture simply maps to
        # an empty digital trace.
        digital_data = np.zeros(data.shape, dtype=bool)
    else:
        # nanmax/nanmin ignore invalid samples. With plain max/min a single NaN
        # would make the threshold NaN and force every sample to False.
        threshold = (np.nanmax(data) + np.nanmin(data)) / 2
        digital_data = data > threshold

    return DigitalTrace(
        data=digital_data,
        metadata=trace.metadata,
    )
def _perform_decoding(
    trace: WaveformTrace | DigitalTrace,
    protocol: str,
    baud_rate: int | None,
    parity: str,
    stop_bits: int,
    show_errors: bool,
) -> dict[str, Any]:
    """Optionally auto-detect, then decode a protocol and summarize the result.

    Args:
        trace: Trace to decode.
        protocol: Protocol name, case-insensitive ("uart", "spi", "i2c",
            "can") or "auto".
        baud_rate: Optional baud rate for UART (also passed to CAN as its bit
            rate). In auto mode, if the detector suggests a UART baud rate and
            none was given, the suggestion is used.
        parity: Parity setting for UART.
        stop_bits: Stop bits for UART.
        show_errors: If True, only packets carrying errors are listed.

    Returns:
        Dictionary with capture info, protocol-specific info, summary counts,
        up to 100 packet entries and up to 20 error entries. A packet entry's
        "index" is its position in the full decoded packet list, so it matches
        the "packet_index" of the error entries even when ``show_errors``
        filters the list. An unsupported protocol produces an "error" entry
        and no packets.
    """
    # Import protocol detection lazily, like the other tracekit imports here.
    from tracekit.inference.protocol import detect_protocol

    sample_rate = trace.metadata.sample_rate
    duration_ms = len(trace.data) / sample_rate * 1e3

    results: dict[str, Any] = {
        "sample_rate": f"{sample_rate / 1e6:.1f} MHz",
        "samples": len(trace.data),
        "duration": f"{duration_ms:.3f} ms",
    }

    # Normalize case so callers passing e.g. "UART" reach the right decoder.
    detected_protocol = protocol.lower()

    if detected_protocol == "auto":
        try:
            detection = detect_protocol(trace, min_confidence=0.5, return_candidates=True)  # type: ignore[arg-type]
            detected_protocol = detection["protocol"].lower()
            results["auto_detection"] = {
                "protocol": detection["protocol"],
                "confidence": f"{detection['confidence']:.1%}",
                "candidates": [
                    {"protocol": c["protocol"], "confidence": f"{c['confidence']:.1%}"}
                    for c in detection.get("candidates", [])[:3]
                ],
            }
            # Adopt the detector's suggested UART baud rate unless one was given.
            if "config" in detection and detected_protocol == "uart" and baud_rate is None:
                baud_rate = detection["config"].get("baud_rate")
        except Exception as e:
            logger.warning("Auto-detection failed: %s, defaulting to UART", e)
            detected_protocol = "uart"

    results["protocol"] = detected_protocol.upper()

    # The decoders operate on boolean sample data.
    digital_trace = _to_digital(trace)

    packets: list[ProtocolPacket] = []
    errors: list[dict[str, Any]] = []
    protocol_info: dict[str, Any]

    if detected_protocol == "uart":
        packets, errors, protocol_info = _decode_uart(
            digital_trace, baud_rate, parity, stop_bits, show_errors
        )
    elif detected_protocol == "spi":
        packets, errors, protocol_info = _decode_spi(digital_trace, show_errors)
    elif detected_protocol == "i2c":
        packets, errors, protocol_info = _decode_i2c(digital_trace, show_errors)
    elif detected_protocol == "can":
        packets, errors, protocol_info = _decode_can(digital_trace, baud_rate, show_errors)
    else:
        # e.g. auto-detection returned a protocol this command cannot decode.
        protocol_info = {"error": f"Unsupported protocol: {detected_protocol}"}

    results.update(protocol_info)

    # Keep each packet paired with its position in the full decode so the
    # reported index agrees with "packet_index" in the error details.
    indexed_packets = list(enumerate(packets))
    if show_errors:
        indexed_packets = [(i, p) for i, p in indexed_packets if p.errors]

    results["packets_decoded"] = len(indexed_packets)
    results["errors_found"] = len(errors)

    results["packets"] = [
        {
            "index": i,
            "timestamp": f"{p.timestamp * 1e3:.3f} ms",
            "data": p.data.hex() if p.data else "",
            "errors": p.errors,
            **{k: v for k, v in (p.annotations or {}).items() if k != "data_bits"},
        }
        for i, p in indexed_packets[:100]  # Limit to first 100 packets
    ]

    if len(indexed_packets) > 100:
        results["note"] = f"Showing first 100 of {len(indexed_packets)} packets"

    if errors:
        results["error_details"] = errors[:20]  # Limit to first 20 errors

    return results
def _decode_uart(
    trace: DigitalTrace,
    baud_rate: int | None,
    parity: str,
    stop_bits: int,
    show_errors: bool,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]], dict[str, Any]]:
    """Run the UART decoder (8 data bits) over a digital trace.

    Args:
        trace: Digital trace carrying the UART line.
        baud_rate: Baud rate. None (or 0) passes 0 to the decoder, which the
            original code notes triggers auto-detection.
        parity: Parity mode ("none", "even" or "odd").
        stop_bits: Number of stop bits.
        show_errors: Unused here; kept so the signature matches the other
            decoders.

    Returns:
        Tuple of (packets, errors, protocol_info). ``errors`` holds one entry
        for every error string attached to a packet.
    """
    from tracekit.analyzers.protocols.uart import UARTDecoder

    word_bits = 8
    uart = UARTDecoder(
        baudrate=baud_rate or 0,  # 0 triggers auto-detection
        data_bits=word_bits,
        parity=parity,  # type: ignore[arg-type]
        stop_bits=stop_bits,
    )
    decoded = list(uart.decode(trace))

    # One entry per (packet, error) pair, in decode order.
    error_entries = [
        {
            "packet_index": idx,
            "timestamp": f"{frame.timestamp * 1e3:.3f} ms",
            "type": problem,
            "data": frame.data.hex() if frame.data else "",
        }
        for idx, frame in enumerate(decoded)
        for problem in (frame.errors or [])
    ]

    # Report the rate the decoder actually used if it exposes one.
    effective_baud = getattr(uart, "_baudrate", baud_rate)

    info = {
        "baud_rate": effective_baud,
        "parity": parity,
        "stop_bits": stop_bits,
        "data_bits": word_bits,
    }
    return decoded, error_entries, info
def _decode_spi(
    trace: DigitalTrace,
    show_errors: bool,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]], dict[str, Any]]:
    """Decode SPI (mode 0, 8-bit words) from a single-channel trace.

    SPI normally needs separate CLK and MOSI (optionally MISO/CS) lines. With
    only one channel, the same samples are passed to the decoder as both the
    clock and MOSI, so the result is indicative only.

    Args:
        trace: Digital trace to decode (used as both CLK and MOSI).
        show_errors: Unused here; kept so the signature matches the other
            decoders.

    Returns:
        Tuple of (packets, errors, protocol_info).
    """
    from tracekit.analyzers.protocols.spi import SPIDecoder

    signal = trace.data
    rate = trace.metadata.sample_rate

    spi = SPIDecoder(cpol=0, cpha=0, word_size=8)
    decoded = list(spi.decode(clk=signal, mosi=signal, sample_rate=rate))

    error_entries = [
        {
            "packet_index": idx,
            "timestamp": f"{frame.timestamp * 1e3:.3f} ms",
            "type": problem,
        }
        for idx, frame in enumerate(decoded)
        for problem in (frame.errors or [])
    ]

    # Adjacent transitions are half a clock period apart, hence the factor 2.
    transitions = np.nonzero(np.diff(signal.astype(int)) != 0)[0]
    clock_freq = 0
    if len(transitions) > 1:
        half_period = np.mean(np.diff(transitions)) / rate
        if half_period > 0:
            clock_freq = 1 / (2 * half_period)

    freq_text = "Unknown"
    if clock_freq > 0:
        freq_text = f"{clock_freq / 1e6:.2f} MHz"

    info = {
        "clock_frequency": freq_text,
        "mode": "0 (CPOL=0, CPHA=0)",
        "word_size": 8,
        "note": "Single-channel decode. For full SPI decode, provide separate CLK/MOSI/MISO signals.",
    }
    return decoded, error_entries, info
413def _decode_i2c(
414 trace: DigitalTrace,
415 show_errors: bool,
416) -> tuple[list[ProtocolPacket], list[dict[str, Any]], dict[str, Any]]:
417 """Decode I2C protocol.
419 Note: I2C requires two signals (SCL, SDA). This function assumes
420 the trace is SDA and attempts to detect SCL from timing patterns.
422 Args:
423 trace: Digital trace to decode.
424 show_errors: Whether to filter to errors only.
426 Returns:
427 Tuple of (packets, errors, protocol_info).
428 """
429 from tracekit.analyzers.protocols.i2c import I2CDecoder
431 # For single-channel, assume it's SDA and create synthetic SCL from edges
432 sda = trace.data
433 sample_rate = trace.metadata.sample_rate
435 # Try to find clock pattern from edge timing
436 edges = np.where(np.diff(sda.astype(int)) != 0)[0]
437 if len(edges) < 20:
438 # Not enough edges for I2C
439 return [], [], {"error": "Insufficient edges for I2C decode"}
441 # Create synthetic SCL (toggle at each edge for analysis)
442 scl = np.ones_like(sda, dtype=bool)
443 for i, edge in enumerate(edges):
444 if i % 2 == 0 and edge + 1 < len(scl):
445 scl[edge : edge + 10] = False # Create clock pulses
447 decoder = I2CDecoder()
448 packets = list(decoder.decode(scl=scl, sda=sda, sample_rate=sample_rate))
449 errors = []
451 addresses_seen: set[int] = set()
452 for i, pkt in enumerate(packets):
453 addr = pkt.annotations.get("address", 0) if pkt.annotations else 0
454 addresses_seen.add(addr)
456 if pkt.errors:
457 for err in pkt.errors:
458 errors.append(
459 {
460 "packet_index": i,
461 "timestamp": f"{pkt.timestamp * 1e3:.3f} ms",
462 "type": err,
463 "address": f"0x{addr:02X}",
464 }
465 )
467 # Estimate clock rate from edge intervals
468 if len(edges) > 1: 468 ↛ 472line 468 didn't jump to line 472 because the condition on line 468 was always true
469 avg_interval = np.mean(np.diff(edges)) / sample_rate
470 clock_rate = 1 / (2 * avg_interval) if avg_interval > 0 else 0
471 else:
472 clock_rate = 0
474 protocol_info = {
475 "clock_frequency": f"{clock_rate / 1e3:.1f} kHz" if clock_rate > 0 else "Unknown",
476 "addresses_seen": [f"0x{a:02X}" for a in sorted(addresses_seen)],
477 "transactions": len(packets),
478 "note": "Single-channel decode. For accurate I2C decode, provide separate SCL/SDA signals.",
479 }
481 return packets, errors, protocol_info
def _decode_can(
    trace: DigitalTrace,
    baud_rate: int | None,
    show_errors: bool,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]], dict[str, Any]]:
    """Decode CAN frames from a digital trace.

    Args:
        trace: Digital trace to decode.
        baud_rate: CAN bit rate in bit/s. When None, 500k, 250k, 125k and 1M
            are tried and the rate yielding the most frames is used (500k if
            none yields any).
        show_errors: Unused here; kept so the signature matches the other
            decoders.

    Returns:
        Tuple of (packets, errors, protocol_info).
    """
    from tracekit.analyzers.protocols.can import CANDecoder

    best_packets: list[ProtocolPacket] = []
    if baud_rate is None:
        baud_rate = 500000
        for rate in (500000, 250000, 125000, 1000000):
            try:
                candidate = list(CANDecoder(bitrate=rate).decode(trace))
            except Exception as exc:
                # Best effort: a rate the decoder cannot handle is skipped.
                logger.debug("CAN decode at %d bit/s failed: %s", rate, exc)
                continue
            if len(candidate) > len(best_packets):
                best_packets = candidate
                baud_rate = rate

    # Reuse the frames found during the rate scan instead of decoding again.
    # Decode explicitly if a rate was given or the scan found nothing.
    packets = best_packets or list(CANDecoder(bitrate=baud_rate).decode(trace))
    errors: list[dict[str, Any]] = []

    arbitration_ids: set[int] = set()
    for i, pkt in enumerate(packets):
        arb_id = pkt.annotations.get("arbitration_id") if pkt.annotations else None
        # 0x000 is a valid (highest-priority) CAN ID, so a missing annotation
        # must not be recorded as 0.
        if arb_id is not None:
            arbitration_ids.add(arb_id)

        for err in pkt.errors or []:
            errors.append(
                {
                    "packet_index": i,
                    "timestamp": f"{pkt.timestamp * 1e3:.3f} ms",
                    "type": err,
                    "arbitration_id": f"0x{arb_id:03X}" if arb_id is not None else "unknown",
                }
            )

    protocol_info = {
        "bit_rate": f"{baud_rate / 1000:.0f} kbps",
        "messages": len(packets),
        "arbitration_ids": [f"0x{a:03X}" for a in sorted(arbitration_ids)[:10]],
        "extended_frames": sum(
            1 for p in packets if p.annotations and p.annotations.get("is_extended")
        ),
    }

    if len(arbitration_ids) > 10:
        protocol_info["note"] = f"Showing first 10 of {len(arbitration_ids)} arbitration IDs"

    return packets, errors, protocol_info