Coverage for src / tracekit / discovery / auto_decoder.py: 21%

103 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Automatic protocol decoding without user configuration. 

2 

3This module provides one-shot protocol decode that auto-detects parameters 

4(baud rate, polarity, bit order) and decodes common protocols. 

5 

6 

7Example: 

8 >>> from tracekit.discovery import decode_protocol 

9 >>> result = decode_protocol(trace) 

10 >>> print(f"Protocol: {result.protocol}") 

11 >>> print(f"Decoded {len(result.data)} bytes") 

12 >>> for byte_data in result.data[:10]: 

13 ... print(f"0x{byte_data.value:02X} (confidence: {byte_data.confidence:.2f})") 

14 

15References: 

16 Protocol specifications: UART (EIA/TIA-232), SPI (Motorola), I2C (NXP) 

17""" 

18 

19from __future__ import annotations 

20 

21from dataclasses import dataclass, field 

22from typing import TYPE_CHECKING, Any, Literal 

23 

24import numpy as np 

25 

26from tracekit.core.types import DigitalTrace, WaveformTrace 

27from tracekit.discovery.signal_detector import characterize_signal 

28 

29if TYPE_CHECKING: 

30 from numpy.typing import NDArray 

31 

# Protocol names a DecodeResult can report. "unknown" is used when neither
# auto-detection nor the caller's hint selects a supported decoder.
ProtocolType = Literal["UART", "SPI", "I2C", "unknown"]

33 

34 

@dataclass
class DecodedByte:
    """One byte recovered by a protocol decoder, together with its confidence.

    Attributes:
        value: The decoded byte; must lie in the inclusive range 0..255.
        offset: Position of this byte within the decoded stream.
        confidence: Decoder confidence for this byte, in [0.0, 1.0].
        has_error: True when the decoder flagged a problem with this byte.
        error_type: Category of the flagged problem, or None.
        error_description: Human-readable explanation of the problem, or None.

    Raises:
        ValueError: Raised at construction time when ``value`` or
            ``confidence`` is outside its allowed range (``value`` is
            checked first).

    Example:
        >>> byte_data = DecodedByte(value=0x48, offset=0, confidence=0.95)
        >>> print(f"Byte: 0x{byte_data.value:02X}, char: {chr(byte_data.value)}")
        Byte: 0x48, char: H
    """

    # Required fields (no defaults), in positional-constructor order.
    value: int
    offset: int
    confidence: float
    # Optional error annotations; a clean byte leaves all three at defaults.
    has_error: bool = False
    error_type: str | None = None
    error_description: str | None = None

    def __post_init__(self) -> None:
        """Reject byte values and confidences that are out of range."""
        # Chained comparisons are kept (rather than `x < lo or x > hi`) so
        # that NaN inputs are rejected too: every comparison with NaN is False.
        value_in_range = 0 <= self.value <= 255
        if not value_in_range:
            raise ValueError(f"Byte value must be 0-255, got {self.value}")

        confidence_in_range = 0.0 <= self.confidence <= 1.0
        if not confidence_in_range:
            raise ValueError(f"Confidence must be 0.0-1.0, got {self.confidence}")

65 

66 

@dataclass
class DecodeResult:
    """Outcome of an automatic protocol decode.

    Attributes:
        protocol: Protocol that was decoded, or ``"unknown"``.
        overall_confidence: Aggregate decode confidence, expected in
            [0.0, 1.0] (not validated here).
        detected_params: Protocol parameters that were detected or supplied.
        data: Decoded bytes, each carrying its own confidence.
        frame_count: Number of frames/packets the decoder produced.
        error_count: Number of errors reported by the decode.

    Example:
        >>> result = decode_protocol(trace)
        >>> print(f"Protocol: {result.protocol}")
        >>> print(f"Confidence: {result.overall_confidence:.2f}")
        >>> print(f"Parameters: {result.detected_params}")
        >>> print(f"Decoded {len(result.data)} bytes with {result.error_count} errors")
    """

    # Required: which protocol and how confident the decode is.
    protocol: ProtocolType
    overall_confidence: float
    # Containers use default factories so each result owns fresh objects
    # instead of sharing one mutable default between instances.
    detected_params: dict[str, Any] = field(default_factory=dict)
    data: list[DecodedByte] = field(default_factory=list)
    # Counters start at zero for results that decoded nothing.
    frame_count: int = 0
    error_count: int = 0

93 

94 

def decode_protocol(
    trace: WaveformTrace | DigitalTrace,
    *,
    protocol_hint: ProtocolType | None = None,
    params_hint: dict[str, Any] | None = None,
    confidence_threshold: float = 0.7,
    return_errors: bool = True,
) -> DecodeResult:
    """Decode a trace without requiring the caller to configure a protocol.

    When ``protocol_hint`` is omitted, the trace is classified with
    ``characterize_signal`` and the signal type is mapped to a protocol:
    ``"uart"`` and generic ``"digital"`` signals are tried as UART,
    ``"spi"`` as SPI and ``"i2c"`` as I2C. Any other signal type, or any
    hint that is not one of ``"UART"``, ``"SPI"`` or ``"I2C"``, produces an
    ``"unknown"`` result with zero confidence and no data.

    Args:
        trace: Input waveform or digital trace.
        protocol_hint: Optional protocol to use instead of auto-detection.
        params_hint: Optional parameter hints forwarded to the decoder
            (e.g. ``{"baud_rate": 9600}`` for UART).
        confidence_threshold: Forwarded unchanged to the protocol decoder.
        return_errors: Forwarded unchanged to the protocol decoder.

    Returns:
        DecodeResult with protocol, parameters, and decoded data.

    Raises:
        ValueError: If the trace is empty.

    Example:
        >>> # Automatic decode with no configuration
        >>> result = decode_protocol(trace)
        >>> if result.overall_confidence >= 0.8:
        ...     print(f"High confidence {result.protocol} decode")
        ...     for byte_data in result.data:
        ...         print(f"0x{byte_data.value:02X}")

        >>> # Decode with a hint to skip detection
        >>> result = decode_protocol(trace, protocol_hint='UART')
        >>> print(f"Baud rate: {result.detected_params['baud_rate']}")

    References:
        DISC-010: One-Shot Protocol Decode
    """
    if len(trace) == 0:
        raise ValueError("Cannot decode empty trace")

    if protocol_hint is None:
        # Generic digital signals get a UART attempt as the most common case;
        # unmapped signal types leave the hint as None -> "unknown" below.
        signal_to_protocol = {
            "uart": "UART",
            "spi": "SPI",
            "i2c": "I2C",
            "digital": "UART",
        }
        detected_type = characterize_signal(trace).signal_type
        protocol_hint = signal_to_protocol.get(detected_type)  # type: ignore[assignment]

    decoders = {
        "UART": _decode_uart_auto,
        "SPI": _decode_spi_auto,
        "I2C": _decode_i2c_auto,
    }
    chosen = decoders.get(protocol_hint)  # type: ignore[arg-type]
    if chosen is None:
        return DecodeResult(
            protocol="unknown",
            overall_confidence=0.0,
            detected_params={},
            data=[],
        )
    return chosen(trace, params_hint, confidence_threshold, return_errors)

181 

182 

def _decode_uart_auto(
    trace: WaveformTrace | DigitalTrace,
    params_hint: dict[str, Any] | None,
    confidence_threshold: float,
    return_errors: bool,
) -> DecodeResult:
    """Auto-decode UART, detecting the baud rate when it is not hinted.

    Framing defaults to 8N1. Any of ``baud_rate``, ``data_bits``, ``parity``
    and ``stop_bits`` present in ``params_hint`` overrides the default (for
    ``baud_rate`` it also skips detection).

    Each decoded byte gets confidence 0.95, or 0.65 when the packet it came
    from reported errors. ``overall_confidence`` is the mean byte confidence
    rounded to two decimals (0.0 when nothing was decoded).

    Args:
        trace: Input trace.
        params_hint: Optional parameter overrides (see above).
        confidence_threshold: Accepted for signature parity with the other
            decoders; currently not applied by this function.
        return_errors: Accepted for signature parity; currently not applied,
            so error details are always attached to the returned bytes.

    Returns:
        DecodeResult for UART. If the underlying decoder raises, a result
        with confidence 0.3, only ``baud_rate`` in ``detected_params`` and no
        data is returned instead.
    """
    from tracekit.analyzers.protocols.uart import UARTDecoder

    hints: dict[str, Any] = params_hint or {}

    if "baud_rate" in hints:
        baud_rate = hints["baud_rate"]
    else:
        # Baud detection thresholds samples numerically, so digital traces
        # are converted to float first. Only done when detection is needed.
        if isinstance(trace, WaveformTrace):
            samples = trace.data
        else:
            samples = trace.data.astype(np.float64)
        baud_rate = _detect_baud_rate(samples, trace.metadata.sample_rate)

    # 8N1 defaults unless hinted otherwise.
    data_bits = hints.get("data_bits", 8)
    parity = hints.get("parity", "none")
    stop_bits = hints.get("stop_bits", 1)

    decoder = UARTDecoder(
        baudrate=baud_rate,
        data_bits=data_bits,
        parity=parity,  # type: ignore[arg-type]
        stop_bits=stop_bits,
    )

    try:
        packets = list(decoder.decode(trace))
    except Exception:
        # Deliberate best-effort: a decoder failure yields a low-confidence,
        # empty result rather than propagating to the caller.
        return DecodeResult(
            protocol="UART",
            overall_confidence=0.3,
            detected_params={"baud_rate": baud_rate},
            data=[],
        )

    decoded_bytes: list[DecodedByte] = []
    error_count = 0

    for packet in packets:
        has_error = len(packet.errors) > 0
        first_error = packet.errors[0] if has_error else None
        confidence = 0.65 if has_error else 0.95
        for byte_val in packet.data:
            if has_error:
                # Counts bytes that came from packets reporting errors.
                error_count += 1
            decoded_bytes.append(
                DecodedByte(
                    value=byte_val,
                    # Offset within the decoded byte stream (not the packet
                    # index), so bytes of a multi-byte packet stay distinct.
                    offset=len(decoded_bytes),
                    confidence=confidence,
                    has_error=has_error,
                    error_type=first_error,
                    error_description=first_error,
                )
            )

    if decoded_bytes:
        overall_confidence = round(float(np.mean([b.confidence for b in decoded_bytes])), 2)
    else:
        overall_confidence = 0.0

    return DecodeResult(
        protocol="UART",
        overall_confidence=overall_confidence,
        detected_params={
            "baud_rate": baud_rate,
            "data_bits": data_bits,
            "parity": parity,
            "stop_bits": stop_bits,
        },
        data=decoded_bytes,
        frame_count=len(packets),
        error_count=error_count,
    )

286 

287 

def _decode_spi_auto(
    trace: WaveformTrace | DigitalTrace,
    params_hint: dict[str, Any] | None,
    confidence_threshold: float,
    return_errors: bool,
) -> DecodeResult:
    """Placeholder SPI decode for a single-channel trace.

    SPI needs a clock line plus at least one data line, so a lone trace cannot
    be decoded. This always returns an empty SPI result with a fixed
    confidence of 0.4 and an explanatory note in ``detected_params``. The
    arguments exist so the signature matches the other ``_decode_*_auto``
    helpers; none of them is read.

    Args:
        trace: Input trace (unused).
        params_hint: Optional parameter hints (unused).
        confidence_threshold: Minimum confidence threshold (unused).
        return_errors: Whether to include errors (unused).

    Returns:
        DecodeResult for SPI with no decoded data.
    """
    explanation = {"note": "SPI requires clock and data channels"}
    return DecodeResult(
        protocol="SPI",
        overall_confidence=0.4,
        detected_params=explanation,
        data=[],
        frame_count=0,
        error_count=0,
    )

317 

318 

def _decode_i2c_auto(
    trace: WaveformTrace | DigitalTrace,
    params_hint: dict[str, Any] | None,
    confidence_threshold: float,
    return_errors: bool,
) -> DecodeResult:
    """Placeholder I2C decode for a single-channel trace.

    I2C is a two-wire bus (SDA data plus SCL clock), so a lone trace cannot
    be decoded. This always returns an empty I2C result with a fixed
    confidence of 0.4 and an explanatory note in ``detected_params``. The
    arguments exist so the signature matches the other ``_decode_*_auto``
    helpers; none of them is read.

    Args:
        trace: Input trace (unused).
        params_hint: Optional parameter hints (unused).
        confidence_threshold: Minimum confidence threshold (unused).
        return_errors: Whether to include errors (unused).

    Returns:
        DecodeResult for I2C with no decoded data.
    """
    explanation = {"note": "I2C requires SDA and SCL channels"}
    return DecodeResult(
        protocol="I2C",
        overall_confidence=0.4,
        detected_params=explanation,
        data=[],
        frame_count=0,
        error_count=0,
    )

348 

349 

def _detect_baud_rate(
    data: NDArray[np.floating[Any]],
    sample_rate: float,
    *,
    default_baud: int = 115200,
    tolerance: float = 0.05,
) -> int:
    """Auto-detect the UART baud rate of a single-channel signal.

    The signal is thresholded at the midpoint of its range, and the run
    lengths between level transitions are measured. The most common run
    length is taken as the bit period.

    Run lengths are histogrammed on a log scale so that long idle gaps
    between frames do not widen the bins holding the short bit-period runs.
    The bit period is the median of the actual run lengths in the peak bin
    rather than the bin centre, which avoids a binning bias. (With linear
    bins, 9600 baud at 1 MHz with 5000-sample idle gaps was reported as
    ~6500 baud.)

    Args:
        data: Signal data array.
        sample_rate: Sample rate in Hz.
        default_baud: Returned when there is too little signal to estimate
            (empty input or fewer than 10 edges), or when the estimate is not
            a positive rate.
        tolerance: Maximum relative distance at which the estimate is
            snapped to a standard baud rate.

    Returns:
        Detected baud rate in bps: the nearest standard rate when it lies
        within ``tolerance`` of the estimate, otherwise the raw estimate.
    """
    samples = np.asarray(data)
    if samples.size == 0:
        # np.max/np.min would raise on empty input.
        return default_baud

    # Threshold the signal to digital at the midpoint of its range.
    threshold = (np.max(samples) + np.min(samples)) / 2
    digital = samples > threshold

    # Index of the last sample before each level change.
    edges = np.flatnonzero(np.diff(digital.astype(int)) != 0)
    if len(edges) < 10:
        return default_baud

    # Run lengths between consecutive edges, in samples (always >= 1).
    intervals = np.diff(edges)
    log_intervals = np.log(intervals)

    hist, bin_edges = np.histogram(log_intervals, bins=50)
    peak = int(np.argmax(hist))
    lo, hi = bin_edges[peak], bin_edges[peak + 1]
    in_peak = intervals[(log_intervals >= lo) & (log_intervals <= hi)]
    if in_peak.size:
        bit_period_samples = float(np.median(in_peak))
    else:
        # Defensive only: the peak bin is non-empty by construction.
        bit_period_samples = float(np.exp((lo + hi) / 2))

    estimated_baud = int(round(sample_rate / bit_period_samples))
    if estimated_baud <= 0:
        # Bit period longer than a second: no meaningful rate, and the
        # relative-error check below would divide by zero.
        return default_baud

    common_bauds = (
        300,
        600,
        1200,
        2400,
        4800,
        9600,
        14400,
        19200,
        28800,
        38400,
        57600,
        115200,
        230400,
        460800,
        921600,
    )
    closest_baud = min(common_bauds, key=lambda rate: abs(rate - estimated_baud))

    if abs(closest_baud - estimated_baud) / estimated_baud < tolerance:
        return closest_baud
    # No standard rate is close enough; report the raw estimate.
    return estimated_baud

408 

409 

# Public API of this module. The underscore-prefixed helpers
# (_decode_uart_auto, _decode_spi_auto, _decode_i2c_auto, _detect_baud_rate)
# are intentionally not exported.
__all__ = [
    "DecodeResult",
    "DecodedByte",
    "ProtocolType",
    "decode_protocol",
]