Coverage for src / tracekit / workflows / protocol.py: 97%

146 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Protocol debug workflow. 

2 

3This module implements auto-detect protocol decoding with error context. 

4 

5 

6Example: 

7 >>> import tracekit as tk 

8 >>> trace = tk.load('serial_capture.wfm') 

9 >>> result = tk.debug_protocol(trace) 

10 >>> print(f"Protocol: {result['protocol']}") 

11 >>> print(f"Errors: {len(result['errors'])}") 

12 

13References: 

14 UART: TIA-232-F 

15 I2C: NXP UM10204 

16 SPI: Motorola SPI Block Guide 

17 CAN: ISO 11898 

18""" 

19 

20from __future__ import annotations 

21 

22from typing import Any 

23 

24import numpy as np 

25 

26from tracekit.core.exceptions import AnalysisError 

27from tracekit.core.types import DigitalTrace, ProtocolPacket, WaveformTrace 

28 

29 

def debug_protocol(
    trace: WaveformTrace | DigitalTrace,
    *,
    protocol: str | None = None,
    context_samples: int = 100,
    error_types: list[str] | None = None,
    decode_all: bool = False,
) -> dict[str, Any]:
    """Auto-detect and decode protocol with error context.

    Automatically detects protocol type (UART, SPI, I2C, CAN) if not specified,
    decodes packets, and highlights errors with surrounding context samples.

    Args:
        trace: Signal to decode. Analog waveforms are thresholded to digital
            before decoding; the original trace is used for protocol detection
            and for extracting error context.
        protocol: Protocol type override ('UART', 'SPI', 'I2C', 'CAN', 'auto'),
            case-insensitive. If None or 'auto', auto-detected.
        context_samples: Number of samples to include before/after errors.
        error_types: Error types to report, matched as case-insensitive
            substrings of each packet's error strings. If None (or empty),
            all errors are reported.
        decode_all: If True, return all decoded packets. If False, return only
            packets that carry at least one error.

    Returns:
        Dictionary containing:
            - protocol: Detected or specified protocol type (upper-case)
            - baud_rate: The config's ``baud_rate``, or ``clock_rate`` if
              that is missing/falsy (None if neither is present)
            - packets: List of decoded ProtocolPacket objects (see
              ``decode_all``)
            - errors: List of error dictionaries with context
            - config: Protocol configuration used
            - statistics: ``total_packets`` (every decoded packet, independent
              of ``decode_all``), ``error_count``, ``error_rate`` (errors per
              decoded packet, 0 when nothing decoded) and ``confidence``
              (detection confidence, 1.0 for an explicit override)

    Raises:
        AnalysisError: If the resolved protocol is not one of the supported
            types. Exceptions raised by protocol detection or by the decoders
            propagate unchanged.

    Example:
        >>> trace = tk.load('uart_data.wfm')
        >>> result = tk.debug_protocol(trace)
        >>> print(f"Protocol: {result['protocol']}")
        >>> print(f"Baud Rate: {result['baud_rate']}")
        >>> for error in result['errors']:
        ...     print(f"Error at {error['timestamp']}: {error['type']}")

    References:
        sigrok Protocol Decoder API
        UART: TIA-232-F (Serial communication)
        I2C: NXP UM10204 (I2C-bus specification)
    """
    from tracekit.inference.protocol import detect_protocol

    # Decoders work on a thresholded copy of the signal.
    digital_trace = _to_digital(trace)

    if protocol is None or protocol.lower() == "auto":
        detection = detect_protocol(trace, min_confidence=0.5, return_candidates=True)  # type: ignore[arg-type]
        protocol = detection["protocol"]
        config = detection["config"]
        confidence = detection["confidence"]
    else:
        # An explicit override is trusted fully and uses built-in defaults.
        confidence = 1.0
        config = _get_default_protocol_config(protocol.upper())

    protocol = protocol.upper()

    decoders = {
        "UART": _decode_uart,
        "SPI": _decode_spi,
        "I2C": _decode_i2c,
        "CAN": _decode_can,
    }
    decode = decoders.get(protocol)
    if decode is None:
        raise AnalysisError(f"Unsupported protocol: {protocol}")

    packets, errors = decode(digital_trace, config, context_samples, error_types, trace)

    # Count packets BEFORE the optional error-only filter: counting afterwards
    # made total_packets equal the number of *erroneous* packets and pushed
    # error_rate to >= 1 whenever decode_all was False.
    total_packets = len(packets)
    error_count = len(errors)

    if not decode_all:
        packets = [p for p in packets if p.errors]

    statistics = {
        "total_packets": total_packets,
        "error_count": error_count,
        "error_rate": error_count / total_packets if total_packets > 0 else 0,
        "confidence": confidence,
    }

    return {
        "protocol": protocol,
        "baud_rate": config.get("baud_rate") or config.get("clock_rate"),
        "packets": packets,
        "errors": errors,
        "config": config,
        "statistics": statistics,
    }

132 

133 

def _to_digital(trace: WaveformTrace | DigitalTrace) -> DigitalTrace:
    """Convert waveform trace to digital trace.

    Digital traces are returned unchanged. Analog samples are sliced at the
    midpoint between the minimum and maximum sample values; samples strictly
    above that threshold become True.

    Args:
        trace: Input trace.

    Returns:
        Digital trace. For a converted waveform, the result reuses the input's
        metadata object.

    Raises:
        AnalysisError: If the waveform contains no samples, since no threshold
            can be derived from an empty array.
    """
    if isinstance(trace, DigitalTrace):
        return trace

    data = trace.data
    # np.max/np.min raise a bare "zero-size array" ValueError on empty input;
    # report it in the module's own error type instead.
    if np.size(data) == 0:
        raise AnalysisError("Cannot convert an empty trace to digital")

    threshold = (np.max(data) + np.min(data)) / 2

    return DigitalTrace(
        data=data > threshold,
        metadata=trace.metadata,
    )

154 

155 

def _get_default_protocol_config(protocol: str) -> dict[str, Any]:
    """Build the default decoder settings for one protocol.

    The lookup is exact, so ``protocol`` must already be upper-case ('UART',
    'SPI', 'I2C' or 'CAN'). A brand-new dict is returned on every call, so
    callers may modify it without affecting later calls.

    Args:
        protocol: Upper-case protocol name.

    Returns:
        Default configuration dictionary, or an empty dict for an unknown name.
    """
    if protocol == "UART":
        return {"baud_rate": 115200, "data_bits": 8, "parity": "none", "stop_bits": 1}
    if protocol == "SPI":
        return {"clock_polarity": 0, "clock_phase": 0, "bit_order": "MSB"}
    if protocol == "I2C":
        # 100 kHz: I2C standard-mode clock.
        return {"clock_rate": 100000, "address_bits": 7}
    if protocol == "CAN":
        return {"baud_rate": 500000, "sample_point": 0.75}
    return {}

187 

188 

def _extract_context(
    trace: WaveformTrace | DigitalTrace,
    sample_idx: int,
    context_samples: int,
) -> WaveformTrace | DigitalTrace | None:
    """Extract context samples around a point.

    The window covers ``context_samples`` samples before the centre, the
    centre sample itself, and ``context_samples`` samples after it. It is
    clipped to the bounds of the trace.

    Args:
        trace: Original trace.
        sample_idx: Center sample index.
        context_samples: Number of samples before and after.

    Returns:
        A WaveformTrace when ``trace`` is a WaveformTrace, otherwise a
        DigitalTrace, holding the window's samples. Returns None if the
        clipped window is empty.
    """
    data = trace.data
    start = max(0, sample_idx - context_samples)
    # Slice ends are exclusive; without the +1 the window held only
    # context_samples - 1 samples after the centre.
    end = min(len(data), sample_idx + context_samples + 1)

    if end <= start:
        return None

    context_data = data[start:end]

    # NOTE(review): metadata is reused verbatim from the parent trace. If it
    # carries a start-time/offset field, that field will not reflect `start`;
    # confirm against the trace types.
    trace_cls = WaveformTrace if isinstance(trace, WaveformTrace) else DigitalTrace
    return trace_cls(
        data=context_data,  # type: ignore[arg-type]
        metadata=trace.metadata,
    )

223 

224 

def _decode_uart(
    trace: DigitalTrace,
    config: dict[str, Any],
    context_samples: int,
    error_types: list[str] | None,
    original_trace: WaveformTrace | DigitalTrace,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]]]:
    """Decode UART protocol with error context.

    Args:
        trace: Digital trace to decode.
        config: UART configuration. Reads ``baud_rate`` (default 0),
            ``data_bits`` (8), ``parity`` ('none') and ``stop_bits`` (1).
        context_samples: Context window size.
        error_types: Case-insensitive substrings. Only packet errors that
            contain one of them are reported; None or empty reports all.
        original_trace: Original trace for context extraction.

    Returns:
        Tuple of (all decoded packets, error dictionaries). Each error dict
        has the keys type, timestamp, packet_index, address (always None for
        UART), data, context and context_trace.
    """
    from tracekit.analyzers.protocols.uart import UARTDecoder

    # NOTE(review): a missing baud_rate is passed as 0. Whether UARTDecoder
    # treats 0 as "auto-detect" is not visible here; confirm.
    decoder = UARTDecoder(
        baudrate=config.get("baud_rate", 0),
        data_bits=config.get("data_bits", 8),
        parity=config.get("parity", "none"),
        stop_bits=config.get("stop_bits", 1),
    )

    packets = list(decoder.decode(trace))
    errors: list[dict[str, Any]] = []

    sample_rate = trace.metadata.sample_rate
    # Lower-case the filter terms once, not once per error string.
    wanted = [t.lower() for t in error_types] if error_types else None

    for i, pkt in enumerate(packets):
        if not pkt.errors:
            continue
        if wanted is None:
            relevant_errors = pkt.errors
        else:
            relevant_errors = [e for e in pkt.errors if any(t in e.lower() for t in wanted)]
        if not relevant_errors:
            continue

        # Round rather than truncate: float products such as 0.29 * 100
        # evaluate to 28.999..., which int() would turn into the wrong sample.
        sample_idx = int(round(pkt.timestamp * sample_rate))
        # The window depends only on the packet, so build it once and share it
        # between all of this packet's errors.
        context_trace = _extract_context(original_trace, sample_idx, context_samples)
        context = f"Samples {sample_idx - context_samples} to {sample_idx + context_samples}"

        for err in relevant_errors:
            errors.append(
                {
                    "type": err,
                    "timestamp": pkt.timestamp,
                    "packet_index": i,
                    "address": None,
                    "data": pkt.data,
                    "context": context,
                    "context_trace": context_trace,
                }
            )

    return packets, errors

288 

289 

def _decode_spi(
    trace: DigitalTrace,
    config: dict[str, Any],
    context_samples: int,
    error_types: list[str] | None,
    original_trace: WaveformTrace | DigitalTrace,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]]]:
    """Decode SPI protocol with error context.

    Args:
        trace: Digital trace to decode.
        config: SPI configuration. Reads ``clock_polarity`` (default 0),
            ``clock_phase`` (0), ``word_size`` (8) and ``bit_order`` ('msb',
            lower-cased before use).
        context_samples: Context window size.
        error_types: Case-insensitive substrings. Only packet errors that
            contain one of them are reported; None or empty reports all.
        original_trace: Original trace for context extraction.

    Returns:
        Tuple of (all decoded packets, error dictionaries). Each error dict
        has the keys type, timestamp, packet_index, mosi_data, context and
        context_trace.
    """
    from tracekit.analyzers.protocols.spi import SPIDecoder

    decoder = SPIDecoder(
        cpol=config.get("clock_polarity", 0),
        cpha=config.get("clock_phase", 0),
        word_size=config.get("word_size", 8),
        bit_order=config.get("bit_order", "msb").lower(),
    )

    sample_rate = trace.metadata.sample_rate

    # NOTE(review): only one channel is available, so the same samples are fed
    # as both the clock and MOSI. Decoded words are therefore derived from the
    # clock line itself; a multi-channel API is needed for meaningful SPI data.
    packets = list(
        decoder.decode(
            clk=trace.data,
            mosi=trace.data,
            sample_rate=sample_rate,
        )
    )

    errors: list[dict[str, Any]] = []
    # Lower-case the filter terms once, not once per error string.
    wanted = [t.lower() for t in error_types] if error_types else None

    for i, pkt in enumerate(packets):
        if not pkt.errors:
            continue
        if wanted is None:
            relevant_errors = pkt.errors
        else:
            relevant_errors = [e for e in pkt.errors if any(t in e.lower() for t in wanted)]
        if not relevant_errors:
            continue

        # Round rather than truncate: float products such as 0.29 * 100
        # evaluate to 28.999..., which int() would turn into the wrong sample.
        sample_idx = int(round(pkt.timestamp * sample_rate))
        # The window depends only on the packet; build it once per packet.
        context_trace = _extract_context(original_trace, sample_idx, context_samples)
        context = f"Samples {sample_idx - context_samples} to {sample_idx + context_samples}"

        for err in relevant_errors:
            errors.append(
                {
                    "type": err,
                    "timestamp": pkt.timestamp,
                    "packet_index": i,
                    "mosi_data": pkt.data,
                    "context": context,
                    "context_trace": context_trace,
                }
            )

    return packets, errors

361 

362 

def _decode_i2c(
    trace: DigitalTrace,
    config: dict[str, Any],
    context_samples: int,
    error_types: list[str] | None,
    original_trace: WaveformTrace | DigitalTrace,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]]]:
    """Decode I2C protocol with error context.

    Only one channel is available, so it is treated as SDA and a synthetic
    SCL is fabricated from SDA's edges. This is a heuristic, not a
    reconstruction of the real clock.

    Args:
        trace: Digital trace to decode (treated as SDA).
        config: I2C configuration. Only ``address_format`` (default 'auto')
            is read.
        context_samples: Context window size.
        error_types: Case-insensitive substrings. Only packet errors that
            contain one of them are reported; None or empty reports all.
        original_trace: Original trace for context extraction.

    Returns:
        Tuple of (packets, errors). Both lists are empty when SDA has fewer
        than 20 edges. Each error dict has the keys type, timestamp,
        packet_index, address (the packet's ``address`` annotation, default
        0), data, context and context_trace.
    """
    from tracekit.analyzers.protocols.i2c import I2CDecoder

    # NOTE(review): the default I2C config provides ``address_bits``, but this
    # reads ``address_format``, so default configs always fall back to 'auto'.
    # Confirm whether the two keys should be reconciled.
    address_format = config.get("address_format", "auto")

    sda = trace.data
    sample_rate = trace.metadata.sample_rate

    edges = np.where(np.diff(sda.astype(int)) != 0)[0]

    # Too few transitions to fabricate a plausible clock.
    if len(edges) < 20:
        return [], []

    # Synthetic SCL: idle high, held low for 10 samples starting at every
    # other SDA edge. Edges within 10 samples of the end are skipped.
    scl = np.ones_like(sda, dtype=bool)
    for edge in edges[::2]:
        if edge + 10 < len(scl):
            scl[edge : edge + 10] = False

    decoder = I2CDecoder(address_format=address_format)
    packets = list(decoder.decode(scl=scl, sda=sda, sample_rate=sample_rate))

    errors: list[dict[str, Any]] = []
    # Lower-case the filter terms once, not once per error string.
    wanted = [t.lower() for t in error_types] if error_types else None

    for i, pkt in enumerate(packets):
        if not pkt.errors:
            continue
        if wanted is None:
            relevant_errors = pkt.errors
        else:
            relevant_errors = [e for e in pkt.errors if any(t in e.lower() for t in wanted)]
        if not relevant_errors:
            continue

        # Round rather than truncate: float products such as 0.29 * 100
        # evaluate to 28.999..., which int() would turn into the wrong sample.
        sample_idx = int(round(pkt.timestamp * sample_rate))
        # Window and address depend only on the packet; compute them once.
        context_trace = _extract_context(original_trace, sample_idx, context_samples)
        context = f"Samples {sample_idx - context_samples} to {sample_idx + context_samples}"
        addr = pkt.annotations.get("address", 0) if pkt.annotations else 0

        for err in relevant_errors:
            errors.append(
                {
                    "type": err,
                    "timestamp": pkt.timestamp,
                    "packet_index": i,
                    "address": addr,
                    "data": pkt.data,
                    "context": context,
                    "context_trace": context_trace,
                }
            )

    return packets, errors

432 

433 

def _decode_can(
    trace: DigitalTrace,
    config: dict[str, Any],
    context_samples: int,
    error_types: list[str] | None,
    original_trace: WaveformTrace | DigitalTrace,
) -> tuple[list[ProtocolPacket], list[dict[str, Any]]]:
    """Decode CAN protocol with error context.

    Args:
        trace: Digital trace to decode.
        config: CAN configuration. Reads ``baud_rate`` (default 500000, passed
            to the decoder as ``bitrate``) and ``sample_point`` (0.75).
        context_samples: Context window size.
        error_types: Case-insensitive substrings. Only packet errors that
            contain one of them are reported; None or empty reports all.
        original_trace: Original trace for context extraction.

    Returns:
        Tuple of (all decoded packets, error dictionaries). Each error dict
        has the keys type, timestamp, packet_index, arbitration_id (the
        packet's ``arbitration_id`` annotation, default 0), data, context and
        context_trace.
    """
    from tracekit.analyzers.protocols.can import CANDecoder

    decoder = CANDecoder(
        bitrate=config.get("baud_rate", 500000),
        sample_point=config.get("sample_point", 0.75),
    )
    packets = list(decoder.decode(trace))

    errors: list[dict[str, Any]] = []
    sample_rate = trace.metadata.sample_rate
    # Lower-case the filter terms once, not once per error string.
    wanted = [t.lower() for t in error_types] if error_types else None

    for i, pkt in enumerate(packets):
        if not pkt.errors:
            continue
        if wanted is None:
            relevant_errors = pkt.errors
        else:
            relevant_errors = [e for e in pkt.errors if any(t in e.lower() for t in wanted)]
        if not relevant_errors:
            continue

        # Round rather than truncate: float products such as 0.29 * 100
        # evaluate to 28.999..., which int() would turn into the wrong sample.
        sample_idx = int(round(pkt.timestamp * sample_rate))
        # Window and arbitration ID depend only on the packet; compute once.
        context_trace = _extract_context(original_trace, sample_idx, context_samples)
        context = f"Samples {sample_idx - context_samples} to {sample_idx + context_samples}"
        arb_id = pkt.annotations.get("arbitration_id", 0) if pkt.annotations else 0

        for err in relevant_errors:
            errors.append(
                {
                    "type": err,
                    "timestamp": pkt.timestamp,
                    "packet_index": i,
                    "arbitration_id": arb_id,
                    "data": pkt.data,
                    "context": context,
                    "context_trace": context_trace,
                }
            )

    return packets, errors

490 

491 

__all__ = [
    # Only the workflow entry point is public; the underscore-prefixed
    # helpers in this module are internal.
    "debug_protocol",
]