Coverage for src / tracekit / analyzers / protocols / usb.py: 77%

171 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""USB protocol decoder. 

2 

3This module provides USB Low Speed (1.5 Mbps) and Full Speed (12 Mbps) 

4protocol decoding with NRZI encoding, bit stuffing, and CRC validation. 

5 

6 

7Example: 

8 >>> from tracekit.analyzers.protocols.usb import USBDecoder 

9 >>> decoder = USBDecoder(speed="full") 

10 >>> for packet in decoder.decode(dp=dp, dm=dm): 

11 ... print(f"PID: {packet.annotations['pid_name']}") 

12 

13References: 

14 USB 2.0 Specification (usb.org) 

15""" 

16 

17from __future__ import annotations 

18 

19from enum import Enum 

20from typing import TYPE_CHECKING, Literal 

21 

22from tracekit.analyzers.protocols.base import ( 

23 AnnotationLevel, 

24 ChannelDef, 

25 OptionDef, 

26 SyncDecoder, 

27) 

28from tracekit.core.types import DigitalTrace, ProtocolPacket 

29 

30if TYPE_CHECKING: 

31 from collections.abc import Iterator 

32 

33 import numpy as np 

34 from numpy.typing import NDArray 

35 

36 

37class USBSpeed(Enum): 

38 """USB speed modes.""" 

39 

40 LOW_SPEED = 1_500_000 # 1.5 Mbps 

41 FULL_SPEED = 12_000_000 # 12 Mbps 

42 

43 

44class USBPID(Enum): 

45 """USB Packet Identifiers.""" 

46 

47 # Token PIDs 

48 OUT = 0b0001 

49 IN = 0b1001 

50 SOF = 0b0101 

51 SETUP = 0b1101 

52 # Data PIDs 

53 DATA0 = 0b0011 

54 DATA1 = 0b1011 

55 DATA2 = 0b0111 

56 MDATA = 0b1111 

57 # Handshake PIDs 

58 ACK = 0b0010 

59 NAK = 0b1010 

60 STALL = 0b1110 

61 NYET = 0b0110 

62 # Special PIDs 

63 PRE = 0b1100 

64 ERR = 0b1100 

65 SPLIT = 0b1000 

66 PING = 0b0100 

67 

68 

69# PID names for display 

70PID_NAMES = { 

71 0b0001: "OUT", 

72 0b1001: "IN", 

73 0b0101: "SOF", 

74 0b1101: "SETUP", 

75 0b0011: "DATA0", 

76 0b1011: "DATA1", 

77 0b0111: "DATA2", 

78 0b1111: "MDATA", 

79 0b0010: "ACK", 

80 0b1010: "NAK", 

81 0b1110: "STALL", 

82 0b0110: "NYET", 

83 0b1100: "PRE/ERR", 

84 0b1000: "SPLIT", 

85 0b0100: "PING", 

86} 

87 

88 

89class USBDecoder(SyncDecoder): 

90 """USB protocol decoder. 

91 

92 Decodes USB Low Speed and Full Speed transactions including 

93 NRZI decoding, bit unstuffing, and CRC validation. 

94 

95 Attributes: 

96 id: "usb" 

97 name: "USB" 

98 channels: [dp, dm] (required) 

99 

100 Example: 

101 >>> decoder = USBDecoder(speed="full") 

102 >>> for packet in decoder.decode(dp=dp, dm=dm, sample_rate=100e6): 

103 ... print(f"PID: {packet.annotations['pid_name']}") 

104 """ 

105 

106 id = "usb" 

107 name = "USB" 

108 longname = "Universal Serial Bus" 

109 desc = "USB Low/Full Speed protocol decoder" 

110 

111 channels = [ # noqa: RUF012 

112 ChannelDef("dp", "D+", "USB D+ signal", required=True), 

113 ChannelDef("dm", "D-", "USB D- signal", required=True), 

114 ] 

115 

116 optional_channels = [] # noqa: RUF012 

117 

118 options = [ # noqa: RUF012 

119 OptionDef("speed", "Speed", "USB speed", default="full", values=["low", "full"]), 

120 ] 

121 

122 annotations = [ # noqa: RUF012 

123 ("sync", "SYNC field"), 

124 ("pid", "Packet ID"), 

125 ("data", "Data payload"), 

126 ("crc", "CRC field"), 

127 ("eop", "End of Packet"), 

128 ("error", "Error"), 

129 ] 

130 

131 def __init__( 

132 self, 

133 speed: Literal["low", "full"] = "full", 

134 ) -> None: 

135 """Initialize USB decoder. 

136 

137 Args: 

138 speed: USB speed ("low" or "full"). 

139 """ 

140 super().__init__(speed=speed) 

141 self._speed = USBSpeed.LOW_SPEED if speed == "low" else USBSpeed.FULL_SPEED 

142 

143 def decode( # type: ignore[override] 

144 self, 

145 trace: DigitalTrace | None = None, 

146 *, 

147 dp: NDArray[np.bool_] | None = None, 

148 dm: NDArray[np.bool_] | None = None, 

149 sample_rate: float = 1.0, 

150 ) -> Iterator[ProtocolPacket]: 

151 """Decode USB packets. 

152 

153 Args: 

154 trace: Optional primary trace. 

155 dp: D+ signal. 

156 dm: D- signal. 

157 sample_rate: Sample rate in Hz. 

158 

159 Yields: 

160 Decoded USB packets as ProtocolPacket objects. 

161 

162 Example: 

163 >>> decoder = USBDecoder(speed="full") 

164 >>> for pkt in decoder.decode(dp=dp, dm=dm, sample_rate=100e6): 

165 ... print(f"Address: {pkt.annotations.get('address', 'N/A')}") 

166 """ 

167 if dp is None or dm is None: 

168 return 

169 

170 n_samples = min(len(dp), len(dm)) 

171 dp = dp[:n_samples] 

172 dm = dm[:n_samples] 

173 

174 # Decode differential signal to single-ended 

175 # J state: LS: D-=1, D+=0; FS: D+=1, D-=0 

176 # K state: LS: D+=1, D-=0; FS: D-=1, D+=0 

177 # SE0: D+=0, D-=0 (idle/EOP) 

178 # SE1: D+=1, D-=1 (illegal) 

179 

180 if self._speed == USBSpeed.LOW_SPEED: 

181 # Low speed: J = D-, K = D+ 

182 diff_signal = ~dp & dm # J=1, K=0 

183 else: 

184 # Full speed: J = D+, K = D- 

185 diff_signal = dp & ~dm # J=1, K=0 

186 

187 # Find packet boundaries (SYNC followed by data, ending with EOP) 

188 # EOP is SE0 for at least 2 bit times 

189 

190 bit_period = sample_rate / self._speed.value 

191 int(2 * bit_period) 

192 

193 # Find SE0 regions (both D+ and D- are 0) 

194 se0 = ~dp & ~dm 

195 

196 trans_num = 0 

197 idx = 0 

198 

199 while idx < len(diff_signal): 

200 # Look for SYNC pattern in NRZI-decoded signal 

201 # SYNC is 0x80 (10000000) after NRZI and bit unstuffing 

202 # In NRZI: no transition = 1, transition = 0 

203 

204 # Decode NRZI starting from current position 

205 nrzi_start = self._find_sync_pattern(diff_signal, idx, bit_period) 

206 if nrzi_start is None: 

207 break 

208 

209 # Extract packet bits 

210 packet_bits, bit_errors = self._extract_packet_bits( 

211 diff_signal, nrzi_start, bit_period, se0 

212 ) 

213 

214 if len(packet_bits) < 16: # Minimum: SYNC(8) + PID(8) 214 ↛ 215line 214 didn't jump to line 215 because the condition on line 214 was never true

215 idx = nrzi_start + int(bit_period) 

216 continue 

217 

218 # Skip SYNC field (first 8 bits) 

219 data_bits = packet_bits[8:] 

220 

221 # Extract PID (8 bits) 

222 if len(data_bits) < 8: 222 ↛ 223line 222 didn't jump to line 223 because the condition on line 222 was never true

223 idx = nrzi_start + int(bit_period) 

224 continue 

225 

226 pid_byte = self._bits_to_byte(data_bits[:8]) 

227 pid_value = pid_byte & 0x0F 

228 pid_check = (pid_byte >> 4) & 0x0F 

229 

230 errors = list(bit_errors) 

231 

232 # Validate PID (upper 4 bits should be complement of lower 4 bits) 

233 if pid_value ^ pid_check != 0x0F: 233 ↛ 236line 233 didn't jump to line 236 because the condition on line 233 was always true

234 errors.append("PID check failed") 

235 

236 pid_name = PID_NAMES.get(pid_value, f"UNKNOWN(0x{pid_value:X})") 

237 

238 # Extract payload based on PID type 

239 payload_bits = data_bits[8:] 

240 payload_bytes = [] 

241 annotations = { 

242 "transaction_num": trans_num, 

243 "pid_value": pid_value, 

244 "pid_name": pid_name, 

245 } 

246 

247 # Token packets: OUT, IN, SOF, SETUP 

248 if pid_value in [0b0001, 0b1001, 0b1101]: # OUT, IN, SETUP 248 ↛ 249line 248 didn't jump to line 249 because the condition on line 248 was never true

249 if len(payload_bits) >= 16: # 11-bit (addr+endp) + 5-bit CRC 

250 addr_endp = self._bits_to_value(payload_bits[:11]) 

251 address = addr_endp & 0x7F 

252 endpoint = (addr_endp >> 7) & 0x0F 

253 crc5 = self._bits_to_value(payload_bits[11:16]) 

254 

255 # Validate CRC5 

256 expected_crc5 = self._crc5(addr_endp) 

257 if crc5 != expected_crc5: 

258 errors.append("CRC5 error") 

259 

260 annotations["address"] = address 

261 annotations["endpoint"] = endpoint 

262 

263 elif pid_value == 0b0101: # SOF 263 ↛ 264line 263 didn't jump to line 264 because the condition on line 263 was never true

264 if len(payload_bits) >= 16: # 11-bit frame number + 5-bit CRC 

265 frame_num = self._bits_to_value(payload_bits[:11]) 

266 crc5 = self._bits_to_value(payload_bits[11:16]) 

267 annotations["frame_number"] = frame_num 

268 

269 # Data packets: DATA0, DATA1, DATA2, MDATA 

270 elif pid_value in [0b0011, 0b1011, 0b0111, 0b1111]: 270 ↛ 271line 270 didn't jump to line 271 because the condition on line 270 was never true

271 if len(payload_bits) >= 16: # At least CRC16 

272 # Data payload + CRC16 

273 data_bit_count = len(payload_bits) - 16 

274 if data_bit_count >= 0: 

275 for i in range(0, data_bit_count, 8): 

276 if i + 8 <= data_bit_count: 

277 byte_val = self._bits_to_byte(payload_bits[i : i + 8]) 

278 payload_bytes.append(byte_val) 

279 

280 self._bits_to_value(payload_bits[-16:]) 

281 annotations["data_length"] = len(payload_bytes) 

282 

283 # Calculate timing 

284 start_time = nrzi_start / sample_rate 

285 end_time = (nrzi_start + len(packet_bits) * bit_period) / sample_rate 

286 

287 # Add annotation 

288 self.put_annotation( 

289 start_time, 

290 end_time, 

291 AnnotationLevel.PACKETS, 

292 f"{pid_name}", 

293 ) 

294 

295 # Create packet 

296 packet = ProtocolPacket( 

297 timestamp=start_time, 

298 protocol="usb", 

299 data=bytes(payload_bytes), 

300 annotations=annotations, 

301 errors=errors, 

302 ) 

303 

304 yield packet 

305 

306 trans_num += 1 

307 idx = int(nrzi_start + len(packet_bits) * bit_period) 

308 

309 def _find_sync_pattern( 

310 self, 

311 signal: NDArray[np.bool_], 

312 start_idx: int, 

313 bit_period: float, 

314 ) -> int | None: 

315 """Find USB SYNC pattern (KJKJKJKK in differential). 

316 

317 Args: 

318 signal: Differential signal. 

319 start_idx: Start search index. 

320 bit_period: Bit period in samples. 

321 

322 Returns: 

323 Index of SYNC start, or None if not found. 

324 """ 

325 # SYNC pattern in NRZI is: KJKJKJKK 

326 # After NRZI decoding: alternating pattern ending in two same bits 

327 # Simplified: look for transition patterns 

328 

329 min_transitions = 6 

330 idx = start_idx 

331 

332 while idx < len(signal) - int(8 * bit_period): 

333 # Sample at bit centers 

334 trans_count = 0 

335 prev_val = signal[idx] 

336 

337 for i in range(8): 

338 sample_idx = int(idx + i * bit_period) 

339 if sample_idx < len(signal): 339 ↛ 337line 339 didn't jump to line 337 because the condition on line 339 was always true

340 curr_val = signal[sample_idx] 

341 if curr_val != prev_val: 

342 trans_count += 1 

343 prev_val = curr_val 

344 

345 if trans_count >= min_transitions: 

346 return idx 

347 

348 idx += int(bit_period / 4) # Scan at quarter-bit resolution 

349 

350 return None 

351 

352 def _extract_packet_bits( 

353 self, 

354 signal: NDArray[np.bool_], 

355 start_idx: int, 

356 bit_period: float, 

357 se0: NDArray[np.bool_], 

358 ) -> tuple[list[int], list[str]]: 

359 """Extract and decode packet bits with NRZI and unstuffing. 

360 

361 Args: 

362 signal: NRZI-encoded differential signal. 

363 start_idx: Packet start index. 

364 bit_period: Bit period in samples. 

365 se0: SE0 detection array. 

366 

367 Returns: 

368 (bits, errors) tuple. 

369 """ 

370 bits = [] 

371 errors = [] # type: ignore[var-annotated] 

372 idx = start_idx 

373 prev_val = signal[idx] if idx < len(signal) else False 

374 stuff_count = 0 

375 

376 max_bits = 1024 # Prevent infinite loops 

377 

378 for _ in range(max_bits): 378 ↛ 408line 378 didn't jump to line 408 because the loop on line 378 didn't complete

379 sample_idx = int(idx) 

380 if sample_idx >= len(signal): 380 ↛ 381line 380 didn't jump to line 381 because the condition on line 380 was never true

381 break 

382 

383 # Check for EOP (SE0) 

384 if se0[sample_idx]: 

385 break 

386 

387 curr_val = signal[sample_idx] 

388 

389 # NRZI decode: no transition = 1, transition = 0 

390 if curr_val == prev_val: 

391 bit = 1 

392 stuff_count += 1 

393 else: 

394 bit = 0 

395 stuff_count = 0 

396 

397 # Bit unstuffing: remove stuff bit after six consecutive 1s 

398 if stuff_count == 6: 398 ↛ 400line 398 didn't jump to line 400 because the condition on line 398 was never true

399 # Next bit should be a stuff bit (0) 

400 stuff_count = 0 

401 # Skip this bit 

402 else: 

403 bits.append(bit) 

404 

405 prev_val = curr_val 

406 idx += bit_period # type: ignore[assignment] 

407 

408 return bits, errors 

409 

410 def _bits_to_byte(self, bits: list[int]) -> int: 

411 """Convert 8 bits to byte (LSB first). 

412 

413 Args: 

414 bits: List of 8 bits. 

415 

416 Returns: 

417 Byte value. 

418 """ 

419 value = 0 

420 for i in range(min(8, len(bits))): 

421 value |= bits[i] << i 

422 return value 

423 

424 def _bits_to_value(self, bits: list[int]) -> int: 

425 """Convert bits to integer (LSB first). 

426 

427 Args: 

428 bits: List of bits. 

429 

430 Returns: 

431 Integer value. 

432 """ 

433 value = 0 

434 for i, bit in enumerate(bits): 

435 value |= bit << i 

436 return value 

437 

438 def _crc5(self, data: int) -> int: 

439 """Compute USB CRC5. 

440 

441 Args: 

442 data: 11-bit data value. 

443 

444 Returns: 

445 5-bit CRC. 

446 """ 

447 # CRC-5-USB polynomial: x^5 + x^2 + 1 (0x05) 

448 crc = 0x1F 

449 for i in range(11): 

450 bit = (data >> i) & 1 

451 if (crc & 1) ^ bit: 

452 crc = ((crc >> 1) ^ 0x14) & 0x1F 

453 else: 

454 crc >>= 1 

455 return crc ^ 0x1F 

456 

457 

458def decode_usb( 

459 dp: NDArray[np.bool_], 

460 dm: NDArray[np.bool_], 

461 sample_rate: float = 1.0, 

462 speed: Literal["low", "full"] = "full", 

463) -> list[ProtocolPacket]: 

464 """Convenience function to decode USB packets. 

465 

466 Args: 

467 dp: D+ signal. 

468 dm: D- signal. 

469 sample_rate: Sample rate in Hz. 

470 speed: USB speed ("low" or "full"). 

471 

472 Returns: 

473 List of decoded USB packets. 

474 

475 Example: 

476 >>> packets = decode_usb(dp, dm, sample_rate=100e6, speed="full") 

477 >>> for pkt in packets: 

478 ... print(f"PID: {pkt.annotations['pid_name']}") 

479 """ 

480 decoder = USBDecoder(speed=speed) 

481 return list(decoder.decode(dp=dp, dm=dm, sample_rate=sample_rate)) 

482 

483 

484__all__ = ["PID_NAMES", "USBPID", "USBDecoder", "USBSpeed", "decode_usb"]