Coverage for src / tracekit / cli / decode.py: 95%

179 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""TraceKit Decode Command implementing CLI-003. 

2 

3Provides CLI for protocol decoding with automatic protocol detection and 

4error highlighting. 

5 

6 

7Example: 

8 $ tracekit decode serial_capture.wfm 

9 $ tracekit decode i2c_bus.wfm --protocol I2C 

10 $ tracekit decode uart.wfm --protocol UART --baud-rate 115200 

11""" 

12 

13from __future__ import annotations 

14 

15import logging 

16from pathlib import Path 

17from typing import Any 

18 

19import click 

20import numpy as np 

21 

22from tracekit.cli.main import format_output 

23from tracekit.core.types import DigitalTrace, ProtocolPacket, WaveformTrace 

24 

logger = logging.getLogger("tracekit.cli.decode")


@click.command()  # type: ignore[misc]
@click.argument("file", type=click.Path(exists=True))  # type: ignore[misc]
@click.option(  # type: ignore[misc]
    "--protocol",
    type=click.Choice(["uart", "spi", "i2c", "can", "auto"], case_sensitive=False),
    default="auto",
    help="Protocol type (default: auto-detect).",
)
@click.option(  # type: ignore[misc]
    "--baud-rate",
    type=int,
    default=None,
    help="Baud rate for UART (auto-detect if not specified).",
)
@click.option(  # type: ignore[misc]
    "--parity",
    type=click.Choice(["none", "even", "odd"], case_sensitive=False),
    default="none",
    help="Parity for UART (default: none).",
)
@click.option(  # type: ignore[misc]
    "--stop-bits",
    type=click.Choice(["1", "2"]),
    default="1",
    help="Stop bits for UART (default: 1).",
)
@click.option(  # type: ignore[misc]
    "--show-errors",
    is_flag=True,
    help="Show only errors with context.",
)
@click.option(  # type: ignore[misc]
    "--output",
    type=click.Choice(["json", "csv", "html", "table"], case_sensitive=False),
    default="table",
    help="Output format (default: table).",
)
@click.pass_context  # type: ignore[misc]
def decode(
    ctx: click.Context,
    file: str,
    protocol: str,
    baud_rate: int | None,
    parity: str,
    stop_bits: str,
    show_errors: bool,
    output: str,
) -> None:
    """Decode serial protocol data.

    Automatically detects and decodes common serial protocols (UART, SPI, I2C, CAN).
    Can highlight errors with surrounding context for debugging.

    Args:
        ctx: Click context object.
        file: Path to waveform file to decode.
        protocol: Protocol type (uart, spi, i2c, can, auto).
        baud_rate: Baud rate for UART (None for auto-detect).
        parity: Parity setting for UART (none, even, odd).
        stop_bits: Number of stop bits for UART (1 or 2).
        show_errors: Show only packets with errors.
        output: Output format (json, csv, html, table).

    Raises:
        Exception: Re-raised if decoding fails or the file cannot be loaded
            and verbosity is above 1; otherwise the error is printed to
            stderr and the command exits with status 1.

    Examples:

    \b
    # Auto-detect and decode protocol
    $ tracekit decode serial_capture.wfm

    \b
    # Decode specific protocol with parameters
    $ tracekit decode uart.wfm \\
    --protocol UART \\
    --baud-rate 9600 \\
    --parity even \\
    --stop-bits 2

    \b
    # Show only errors for debugging
    $ tracekit decode problematic.wfm --show-errors

    \b
    # Generate JSON output
    $ tracekit decode i2c.wfm --protocol I2C --output json
    """
    # ctx.obj is filled in by a parent CLI group; fall back to an empty mapping
    # so the command does not crash when invoked on its own (ctx.obj is None).
    settings = ctx.obj or {}
    verbose = settings.get("verbose", 0)

    if verbose:
        logger.info("Decoding: %s", file)
        logger.info("Protocol: %s", protocol)
        if protocol.lower() == "uart" and baud_rate:
            logger.info("Baud rate: %s", baud_rate)

    try:
        # Import here to avoid circular imports
        from tracekit.loaders import load

        logger.debug("Loading trace from %s", file)
        trace = load(file)

        results = _perform_decoding(
            trace=trace,  # type: ignore[arg-type]
            protocol=protocol,
            baud_rate=baud_rate,
            parity=parity,
            stop_bits=int(stop_bits),
            show_errors=show_errors,
        )

        # Report only the file name, not the full path.
        results["file"] = Path(file).name

        click.echo(format_output(results, output))

    except Exception as e:
        logger.error("Decoding failed: %s", e)
        # Very verbose runs get the full traceback for debugging.
        if verbose > 1:
            raise
        click.echo(f"Error: {e}", err=True)
        ctx.exit(1)

155 

156 

def _to_digital(trace: WaveformTrace | DigitalTrace) -> DigitalTrace:
    """Convert a waveform trace to a digital trace.

    Digital traces are returned unchanged. Analog waveforms are thresholded at
    the midpoint between their minimum and maximum sample values.

    Args:
        trace: Input trace (waveform or digital).

    Returns:
        Digital trace with boolean data, carrying the input trace's metadata.
        An empty waveform yields an empty boolean array.
    """
    if isinstance(trace, DigitalTrace):
        return trace

    # Compute in float64 so that max + min cannot overflow for integer sample
    # dtypes (e.g. int8/int16 raw ADC counts), which would corrupt the threshold.
    samples = np.asarray(trace.data, dtype=np.float64)

    if samples.size == 0:
        # np.max/np.min raise ValueError on empty input; an empty capture
        # simply digitizes to an empty boolean array.
        digital_data = np.zeros(0, dtype=bool)
    else:
        # Midpoint threshold for digitization
        threshold = (samples.max() + samples.min()) / 2
        digital_data = samples > threshold

    return DigitalTrace(
        data=digital_data,
        metadata=trace.metadata,
    )

178 

179 

def _perform_decoding(
    trace: WaveformTrace | DigitalTrace,
    protocol: str,
    baud_rate: int | None,
    parity: str,
    stop_bits: int,
    show_errors: bool,
    *,
    max_packets: int = 100,
    max_errors: int = 20,
) -> dict[str, Any]:
    """Perform protocol decoding using the protocol-specific decoders.

    Args:
        trace: Trace to decode.
        protocol: Protocol type ("uart", "spi", "i2c", "can") or "auto";
            compared case-insensitively.
        baud_rate: Optional baud rate for UART (also used as the CAN bit rate).
        parity: Parity setting for UART.
        stop_bits: Stop bits for UART.
        show_errors: Whether to keep only packets that carry errors.
        max_packets: Maximum number of packet rows included in the output.
        max_errors: Maximum number of error entries included in the output.

    Returns:
        Dictionary of decoding results. Packet rows keep the packet's position
        in the decoder output as ``index``, so they line up with the
        ``packet_index`` values in ``error_details`` even when filtering.
    """
    sample_rate = trace.metadata.sample_rate
    num_samples = len(trace.data)
    duration_ms = num_samples / sample_rate * 1e3

    results: dict[str, Any] = {
        "sample_rate": f"{sample_rate / 1e6:.1f} MHz",
        "samples": num_samples,
        "duration": f"{duration_ms:.3f} ms",
    }

    # Normalize so direct callers passing e.g. "UART" reach the same branches
    # as the lowercase values produced by the click Choice.
    detected_protocol = protocol.lower()

    if detected_protocol == "auto":
        # Only auto-detection needs the inference package.
        from tracekit.inference.protocol import detect_protocol

        try:
            detection = detect_protocol(trace, min_confidence=0.5, return_candidates=True)  # type: ignore[arg-type]
            detected_protocol = detection["protocol"].lower()
            results["auto_detection"] = {
                "protocol": detection["protocol"],
                "confidence": f"{detection['confidence']:.1%}",
                "candidates": [
                    {"protocol": c["protocol"], "confidence": f"{c['confidence']:.1%}"}
                    for c in detection.get("candidates", [])[:3]
                ],
            }
            # Reuse the detector's suggested UART baud rate unless the user gave one.
            if "config" in detection and detected_protocol == "uart" and baud_rate is None:
                baud_rate = detection["config"].get("baud_rate")
        except Exception as e:
            logger.warning("Auto-detection failed: %s, defaulting to UART", e)
            detected_protocol = "uart"

    results["protocol"] = detected_protocol.upper()

    # Convert to digital trace for decoding
    digital_trace = _to_digital(trace)

    packets: list[ProtocolPacket] = []
    errors: list[dict[str, Any]] = []

    if detected_protocol == "uart":
        packets, errors, protocol_info = _decode_uart(
            digital_trace, baud_rate, parity, stop_bits, show_errors
        )
    elif detected_protocol == "spi":
        packets, errors, protocol_info = _decode_spi(digital_trace, show_errors)
    elif detected_protocol == "i2c":
        packets, errors, protocol_info = _decode_i2c(digital_trace, show_errors)
    elif detected_protocol == "can":
        packets, errors, protocol_info = _decode_can(digital_trace, baud_rate, show_errors)
    else:
        # Auto-detection may name a protocol this command cannot decode; say
        # so explicitly instead of silently reporting zero packets.
        logger.warning("No decoder available for protocol %r", detected_protocol)
        protocol_info = {"warning": f"No decoder available for protocol '{detected_protocol}'"}

    results.update(protocol_info)

    # Pair each packet with its position in the decoder output so filtered
    # rows keep indices that match "packet_index" in the error details.
    indexed_packets = list(enumerate(packets))
    if show_errors:
        indexed_packets = [(i, p) for i, p in indexed_packets if p.errors]

    results["packets_decoded"] = len(indexed_packets)
    results["errors_found"] = len(errors)

    # Decoder annotations are merged into each row, but must not overwrite the
    # core fields; "data_bits" is intentionally omitted from the output.
    reserved_keys = {"index", "timestamp", "data", "errors", "data_bits"}
    results["packets"] = [
        {
            "index": index,
            "timestamp": f"{pkt.timestamp * 1e3:.3f} ms",
            "data": pkt.data.hex() if pkt.data else "",
            "errors": pkt.errors,
            **{
                k: v
                for k, v in (pkt.annotations or {}).items()
                if k not in reserved_keys
            },
        }
        for index, pkt in indexed_packets[:max_packets]
    ]

    if len(indexed_packets) > max_packets:
        truncation = f"Showing first {max_packets} of {len(indexed_packets)} packets"
        # A decoder may already have placed a "note" in results; keep both.
        existing_note = results.get("note")
        results["note"] = f"{existing_note} {truncation}" if existing_note else truncation

    if errors:
        results["error_details"] = errors[:max_errors]

    return results

294 

295 

def _decode_uart(
    trace: DigitalTrace,
    baud_rate: int | None,
    parity: str,
    stop_bits: int,
    show_errors: bool,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]], dict[str, Any]]:
    """Run the UART decoder over a digital trace.

    Args:
        trace: Digital trace holding the UART line.
        baud_rate: Baud rate, or None to request decoder auto-detection
            (passed to the decoder as 0).
        parity: Parity mode forwarded to the decoder.
        stop_bits: Stop-bit count forwarded to the decoder.
        show_errors: Not used here; filtering is done by the caller.

    Returns:
        Tuple of (packets, errors, protocol_info). ``errors`` holds one entry
        per error string reported on any packet.
    """
    from tracekit.analyzers.protocols.uart import UARTDecoder

    word_bits = 8
    uart = UARTDecoder(
        baudrate=baud_rate or 0,  # 0 triggers auto-detection
        data_bits=word_bits,
        parity=parity,  # type: ignore[arg-type]
        stop_bits=stop_bits,
    )

    decoded = list(uart.decode(trace))

    # Flatten every packet's error list into standalone report entries.
    error_entries: list[dict[str, Any]] = [
        {
            "packet_index": idx,
            "timestamp": f"{packet.timestamp * 1e3:.3f} ms",
            "type": problem,
            "data": packet.data.hex() if packet.data else "",
        }
        for idx, packet in enumerate(decoded)
        if packet.errors
        for problem in packet.errors
    ]

    # Report the decoder's private _baudrate when it has one (presumably the
    # rate it actually used); otherwise echo the requested rate.
    effective_baud = getattr(uart, "_baudrate", baud_rate)

    info = {
        "baud_rate": effective_baud,
        "parity": parity,
        "stop_bits": stop_bits,
        "data_bits": word_bits,
    }
    return decoded, error_entries, info

352 

353 

def _decode_spi(
    trace: DigitalTrace,
    show_errors: bool,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]], dict[str, Any]]:
    """Decode SPI from a single captured channel.

    SPI normally needs separate CLK, MOSI, optionally MISO, and CS signals.
    With one channel, the same samples are handed to the decoder as both CLK
    and MOSI, so the result is a best-effort analysis, not a full bus decode.

    Args:
        trace: Digital trace, treated as both the clock and the data line.
        show_errors: Not used here; filtering is done by the caller.

    Returns:
        Tuple of (packets, errors, protocol_info).
    """
    from tracekit.analyzers.protocols.spi import SPIDecoder

    spi = SPIDecoder(cpol=0, cpha=0, word_size=8)

    line = trace.data  # doubles as CLK and MOSI
    fs = trace.metadata.sample_rate

    decoded = list(spi.decode(clk=line, mosi=line, sample_rate=fs))

    error_entries: list[dict[str, Any]] = [
        {
            "packet_index": idx,
            "timestamp": f"{packet.timestamp * 1e3:.3f} ms",
            "type": problem,
        }
        for idx, packet in enumerate(decoded)
        if packet.errors
        for problem in packet.errors
    ]

    # Indices where the line changes level. On a clock line, consecutive
    # transitions are half a period apart, hence the factor of 2 below.
    transitions = np.nonzero(np.diff(line.astype(int)))[0]
    clock_freq = 0
    if len(transitions) > 1:
        half_period = np.mean(np.diff(transitions)) / fs
        if half_period > 0:
            clock_freq = 1 / (2 * half_period)

    if clock_freq > 0:
        freq_text = f"{clock_freq / 1e6:.2f} MHz"
    else:
        freq_text = "Unknown"

    info = {
        "clock_frequency": freq_text,
        "mode": "0 (CPOL=0, CPHA=0)",
        "word_size": 8,
        "note": "Single-channel decode. For full SPI decode, provide separate CLK/MOSI/MISO signals.",
    }
    return decoded, error_entries, info

411 

412 

def _decode_i2c(
    trace: DigitalTrace,
    show_errors: bool,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]], dict[str, Any]]:
    """Decode I2C protocol from a single channel.

    Note: I2C requires two signals (SCL, SDA). This function assumes the
    trace is SDA and synthesizes an SCL signal from the SDA edge timing.

    Args:
        trace: Digital trace to decode (treated as SDA).
        show_errors: Not used here; filtering is done by the caller.

    Returns:
        Tuple of (packets, errors, protocol_info). If the trace has fewer than
        20 edges, returns empty lists and an ``{"error": ...}`` info dict.
        Packets without an ``address`` annotation are left out of
        ``addresses_seen``, and their error entries show address "unknown".
    """
    from tracekit.analyzers.protocols.i2c import I2CDecoder

    sda = trace.data
    sample_rate = trace.metadata.sample_rate

    # Indices where SDA changes level.
    edges = np.nonzero(np.diff(sda.astype(int)))[0]
    if len(edges) < 20:
        # Not enough edges for I2C
        return [], [], {"error": "Insufficient edges for I2C decode"}

    # Synthetic SCL: idle high, driven low for a short window after every other
    # SDA edge. np.diff yields indices <= len(sda) - 2 and slicing clips at the
    # end of the array, so no explicit bounds check is needed.
    scl_low_samples = 10
    scl = np.ones_like(sda, dtype=bool)
    for edge in edges[::2]:
        scl[edge : edge + scl_low_samples] = False

    decoder = I2CDecoder()
    packets = list(decoder.decode(scl=scl, sda=sda, sample_rate=sample_rate))
    errors: list[dict[str, Any]] = []

    addresses_seen: set[int] = set()
    for i, pkt in enumerate(packets):
        addr = pkt.annotations.get("address") if pkt.annotations else None
        # Record only real addresses: a placeholder 0 would masquerade as the
        # I2C general-call address 0x00.
        if addr is not None:
            addresses_seen.add(addr)

        for err in pkt.errors or []:
            errors.append(
                {
                    "packet_index": i,
                    "timestamp": f"{pkt.timestamp * 1e3:.3f} ms",
                    "type": err,
                    "address": f"0x{addr:02X}" if addr is not None else "unknown",
                }
            )

    # The guard above ensures >= 20 strictly increasing edges, so the mean
    # interval is always defined; only a non-positive sample rate can break it.
    avg_interval = np.mean(np.diff(edges)) / sample_rate
    clock_rate = 1 / (2 * avg_interval) if avg_interval > 0 else 0

    protocol_info = {
        "clock_frequency": f"{clock_rate / 1e3:.1f} kHz" if clock_rate > 0 else "Unknown",
        "addresses_seen": [f"0x{a:02X}" for a in sorted(addresses_seen)],
        "transactions": len(packets),
        "note": "Single-channel decode. For accurate I2C decode, provide separate SCL/SDA signals.",
    }

    return packets, errors, protocol_info

482 

483 

def _decode_can(
    trace: DigitalTrace,
    baud_rate: int | None,
    show_errors: bool,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]], dict[str, Any]]:
    """Decode CAN protocol.

    When no bit rate is given, the common rates 500k, 250k, 125k and 1M bps are
    tried in that order. The rate yielding the most frames wins; ties go to the
    earlier rate, and 500 kbps is used if none yields frames. The winning
    probe's frames are reused rather than decoding the trace a second time.

    Args:
        trace: Digital trace to decode.
        baud_rate: CAN bit rate (None for common-rate detection).
        show_errors: Not used here; filtering is done by the caller.

    Returns:
        Tuple of (packets, errors, protocol_info). Packets without an
        ``arbitration_id`` annotation are left out of ``arbitration_ids``, and
        their error entries show "unknown".
    """
    from tracekit.analyzers.protocols.can import CANDecoder

    packets: list[ProtocolPacket] | None = None

    if baud_rate is None:
        best_rate = 500000
        max_packets = 0
        for rate in (500000, 250000, 125000, 1000000):
            try:
                candidate = list(CANDecoder(bitrate=rate).decode(trace))
            except Exception as e:
                # Best-effort probe: a failing rate is skipped, not fatal.
                logger.debug("CAN probe at %d bps failed: %s", rate, e)
                continue
            if len(candidate) > max_packets:
                max_packets = len(candidate)
                best_rate = rate
                packets = candidate  # keep the frames; avoids a second decode
        baud_rate = best_rate

    if packets is None:
        # Explicit rate, or no probe produced any frames.
        packets = list(CANDecoder(bitrate=baud_rate).decode(trace))

    errors: list[dict[str, Any]] = []
    arbitration_ids: set[int] = set()
    for i, pkt in enumerate(packets):
        arb_id = pkt.annotations.get("arbitration_id") if pkt.annotations else None
        # Don't let a placeholder 0 pose as a real (and valid) CAN ID 0x000.
        if arb_id is not None:
            arbitration_ids.add(arb_id)

        for err in pkt.errors or []:
            errors.append(
                {
                    "packet_index": i,
                    "timestamp": f"{pkt.timestamp * 1e3:.3f} ms",
                    "type": err,
                    "arbitration_id": f"0x{arb_id:03X}" if arb_id is not None else "unknown",
                }
            )

    protocol_info = {
        "bit_rate": f"{baud_rate / 1000:.0f} kbps",
        "messages": len(packets),
        "arbitration_ids": [f"0x{a:03X}" for a in sorted(arbitration_ids)[:10]],
        "extended_frames": sum(
            1 for p in packets if p.annotations and p.annotations.get("is_extended")
        ),
    }

    if len(arbitration_ids) > 10:
        protocol_info["note"] = f"Showing first 10 of {len(arbitration_ids)} arbitration IDs"

    return packets, errors, protocol_info