Coverage for src / tracekit / analyzers / protocols / usb.py: 77%
171 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""USB protocol decoder.

This module provides USB Low Speed (1.5 Mbps) and Full Speed (12 Mbps)
protocol decoding with NRZI decoding, bit unstuffing, and CRC validation.

Example:
    >>> from tracekit.analyzers.protocols.usb import USBDecoder
    >>> decoder = USBDecoder(speed="full")
    >>> for packet in decoder.decode(dp=dp, dm=dm, sample_rate=100e6):
    ...     print(f"PID: {packet.annotations['pid_name']}")

References:
    USB 2.0 Specification (usb.org)
15"""
17from __future__ import annotations
19from enum import Enum
20from typing import TYPE_CHECKING, Literal
22from tracekit.analyzers.protocols.base import (
23 AnnotationLevel,
24 ChannelDef,
25 OptionDef,
26 SyncDecoder,
27)
28from tracekit.core.types import DigitalTrace, ProtocolPacket
30if TYPE_CHECKING:
31 from collections.abc import Iterator
33 import numpy as np
34 from numpy.typing import NDArray
class USBSpeed(Enum):
    """Supported USB signalling rates.

    Each member's value is the nominal bit rate in bits per second.
    """

    LOW_SPEED = 1_500_000  # 1.5 Mbps
    FULL_SPEED = 12_000_000  # 12 Mbps
class USBPID(Enum):
    """USB Packet Identifiers (the low 4 bits of the PID byte)."""

    # --- Token PIDs ---
    OUT = 0b0001
    IN = 0b1001
    SOF = 0b0101
    SETUP = 0b1101

    # --- Data PIDs ---
    DATA0 = 0b0011
    DATA1 = 0b1011
    DATA2 = 0b0111
    MDATA = 0b1111

    # --- Handshake PIDs ---
    ACK = 0b0010
    NAK = 0b1010
    STALL = 0b1110
    NYET = 0b0110

    # --- Special PIDs ---
    # PRE and ERR share the same value, so ERR is an Enum alias of PRE:
    # USBPID.ERR is USBPID.PRE, and USBPID(0b1100).name == "PRE".
    PRE = 0b1100
    ERR = 0b1100
    SPLIT = 0b1000
    PING = 0b0100
# Display names keyed by the 4-bit PID value.
PID_NAMES: dict[int, str] = {
    # Token
    0b0001: "OUT",
    0b1001: "IN",
    0b0101: "SOF",
    0b1101: "SETUP",
    # Data
    0b0011: "DATA0",
    0b1011: "DATA1",
    0b0111: "DATA2",
    0b1111: "MDATA",
    # Handshake
    0b0010: "ACK",
    0b1010: "NAK",
    0b1110: "STALL",
    0b0110: "NYET",
    # Special (PRE and ERR share one value)
    0b1100: "PRE/ERR",
    0b1000: "SPLIT",
    0b0100: "PING",
}
class USBDecoder(SyncDecoder):
    """USB protocol decoder.

    Decodes USB Low Speed and Full Speed packets. The D+/D- pair is reduced
    to a J/K line state, NRZI-decoded, bit-unstuffed, and split into PID and
    payload fields. CRC5 (token and SOF packets) and CRC16 (data packets)
    are checked; mismatches are reported in the packet's ``errors`` list
    rather than dropping the packet.

    Attributes:
        id: "usb"
        name: "USB"
        channels: [dp, dm] (required)

    Example:
        >>> decoder = USBDecoder(speed="full")
        >>> for packet in decoder.decode(dp=dp, dm=dm, sample_rate=100e6):
        ...     print(f"PID: {packet.annotations['pid_name']}")
    """

    id = "usb"
    name = "USB"
    longname = "Universal Serial Bus"
    desc = "USB Low/Full Speed protocol decoder"

    channels = [  # noqa: RUF012
        ChannelDef("dp", "D+", "USB D+ signal", required=True),
        ChannelDef("dm", "D-", "USB D- signal", required=True),
    ]

    optional_channels = []  # noqa: RUF012

    options = [  # noqa: RUF012
        OptionDef("speed", "Speed", "USB speed", default="full", values=["low", "full"]),
    ]

    annotations = [  # noqa: RUF012
        ("sync", "SYNC field"),
        ("pid", "Packet ID"),
        ("data", "Data payload"),
        ("crc", "CRC field"),
        ("eop", "End of Packet"),
        ("error", "Error"),
    ]

    # PID groups, keyed by the 4-bit PID value.
    _TOKEN_PIDS = frozenset({0b0001, 0b1001, 0b1101})  # OUT, IN, SETUP
    _SOF_PID = 0b0101
    _DATA_PIDS = frozenset({0b0011, 0b1011, 0b0111, 0b1111})  # DATA0/1/2, MDATA

    _MIN_PACKET_BITS = 16  # SYNC (8) + PID (8)
    # Upper bound on sampled bits per packet so extraction always terminates.
    # NOTE: this is shorter than a maximum-size full-speed isochronous packet.
    _MAX_PACKET_BITS = 1024
    _MIN_SYNC_TRANSITIONS = 6

    def __init__(
        self,
        speed: Literal["low", "full"] = "full",
    ) -> None:
        """Initialize USB decoder.

        Args:
            speed: USB speed ("low" or "full"). Any value other than "low"
                selects full speed.
        """
        super().__init__(speed=speed)
        self._speed = USBSpeed.LOW_SPEED if speed == "low" else USBSpeed.FULL_SPEED

    def decode(  # type: ignore[override]
        self,
        trace: DigitalTrace | None = None,
        *,
        dp: NDArray[np.bool_] | None = None,
        dm: NDArray[np.bool_] | None = None,
        sample_rate: float = 1.0,
    ) -> Iterator[ProtocolPacket]:
        """Decode USB packets.

        Args:
            trace: Optional primary trace. Currently ignored by this decoder.
            dp: D+ signal.
            dm: D- signal.
            sample_rate: Sample rate in Hz. It should be several times the
                bus bit rate; if it is not, no packets are found.

        Yields:
            Decoded USB packets as ProtocolPacket objects. Nothing is
            yielded when ``dp`` or ``dm`` is missing.

        Example:
            >>> decoder = USBDecoder(speed="full")
            >>> for pkt in decoder.decode(dp=dp, dm=dm, sample_rate=100e6):
            ...     print(f"Address: {pkt.annotations.get('address', 'N/A')}")
        """
        if dp is None or dm is None:
            return

        n_samples = min(len(dp), len(dm))
        dp = dp[:n_samples]
        dm = dm[:n_samples]

        # Reduce the differential pair to one line state where J=1, K=0.
        #   J state: LS: D-=1, D+=0; FS: D+=1, D-=0
        #   K state: LS: D+=1, D-=0; FS: D-=1, D+=0
        #   SE0: D+=0, D-=0 (EOP / reset)
        #   SE1: D+=1, D-=1 (illegal)
        if self._speed == USBSpeed.LOW_SPEED:
            line_state = ~dp & dm
        else:
            line_state = dp & ~dm
        se0 = ~dp & ~dm

        bit_period = sample_rate / self._speed.value  # samples per bit
        # Always move forward by at least one sample; with fewer than one
        # sample per bit int(bit_period) is 0 and the loop would never end.
        min_advance = max(1, int(bit_period))

        trans_num = 0
        idx = 0
        while idx < n_samples:
            sync_start = self._find_sync_pattern(line_state, idx, bit_period)
            if sync_start is None:
                break

            packet_bits, bit_errors = self._extract_packet_bits(
                line_state, sync_start, bit_period, se0
            )

            if len(packet_bits) < self._MIN_PACKET_BITS:
                # Too short to hold SYNC + PID: false detection, retry later.
                idx = sync_start + min_advance
                continue

            # Skip the SYNC field (first 8 bits); the PID byte follows.
            data_bits = packet_bits[8:]
            pid_byte = self._bits_to_byte(data_bits[:8])
            pid_value = pid_byte & 0x0F
            pid_check = (pid_byte >> 4) & 0x0F

            errors = list(bit_errors)
            # The upper nibble must be the bitwise complement of the lower one.
            if pid_value ^ pid_check != 0x0F:
                errors.append("PID check failed")

            pid_name = PID_NAMES.get(pid_value, f"UNKNOWN(0x{pid_value:X})")
            annotations = {
                "transaction_num": trans_num,
                "pid_value": pid_value,
                "pid_name": pid_name,
            }
            payload_bytes = self._parse_payload(pid_value, data_bits[8:], annotations, errors)

            # packet_bits no longer contains stuff bits, so count them back in
            # to get the real extent of the packet on the wire.
            packet_end = sync_start + self._count_wire_bits(packet_bits) * bit_period
            start_time = sync_start / sample_rate
            end_time = packet_end / sample_rate

            self.put_annotation(
                start_time,
                end_time,
                AnnotationLevel.PACKETS,
                f"{pid_name}",
            )

            yield ProtocolPacket(
                timestamp=start_time,
                protocol="usb",
                data=bytes(payload_bytes),
                annotations=annotations,
                errors=errors,
            )

            trans_num += 1
            idx = max(sync_start + min_advance, int(packet_end))

    def _parse_payload(
        self,
        pid_value: int,
        payload_bits: list[int],
        annotations: dict[str, object],
        errors: list[str],
    ) -> list[int]:
        """Decode the fields that follow the PID byte.

        Adds PID-specific entries to ``annotations`` and CRC failures to
        ``errors`` (both are modified in place).

        Args:
            pid_value: 4-bit PID value.
            payload_bits: Unstuffed bits following the PID byte.
            annotations: Packet annotation dict to extend.
            errors: Packet error list to extend.

        Returns:
            Data payload bytes (empty for non-data packets).
        """
        payload_bytes: list[int] = []

        if pid_value in self._TOKEN_PIDS:
            if len(payload_bits) >= 16:  # 11-bit (addr+endp) + 5-bit CRC
                addr_endp = self._bits_to_value(payload_bits[:11])
                if self._bits_to_value(payload_bits[11:16]) != self._crc5(addr_endp):
                    errors.append("CRC5 error")
                annotations["address"] = addr_endp & 0x7F
                annotations["endpoint"] = (addr_endp >> 7) & 0x0F

        elif pid_value == self._SOF_PID:
            if len(payload_bits) >= 16:  # 11-bit frame number + 5-bit CRC
                frame_num = self._bits_to_value(payload_bits[:11])
                if self._bits_to_value(payload_bits[11:16]) != self._crc5(frame_num):
                    errors.append("CRC5 error")
                annotations["frame_number"] = frame_num

        elif pid_value in self._DATA_PIDS:
            if len(payload_bits) >= 16:  # at least the CRC16
                data_bit_count = len(payload_bits) - 16
                whole_bits = data_bit_count - data_bit_count % 8  # ignore a partial byte
                payload_bytes = [
                    self._bits_to_byte(payload_bits[i : i + 8]) for i in range(0, whole_bits, 8)
                ]
                if self._bits_to_value(payload_bits[-16:]) != self._crc16(payload_bytes):
                    errors.append("CRC16 error")
                annotations["data_length"] = len(payload_bytes)

        return payload_bytes

    def _count_wire_bits(self, bits: list[int]) -> int:
        """Return how many bit times ``bits`` occupied before unstuffing.

        Every run of six consecutive 1s is followed on the wire by one stuff
        bit that the extractor removed.

        Args:
            bits: Unstuffed packet bits.

        Returns:
            ``len(bits)`` plus the number of removed stuff bits.
        """
        stuffed = 0
        ones_run = 0
        for bit in bits:
            if bit:
                ones_run += 1
                if ones_run == 6:
                    stuffed += 1
                    ones_run = 0
            else:
                ones_run = 0
        return len(bits) + stuffed

    def _find_sync_pattern(
        self,
        signal: NDArray[np.bool_],
        start_idx: int,
        bit_period: float,
    ) -> int | None:
        """Find USB SYNC pattern (KJKJKJKK in differential).

        A window of eight samples spaced one bit period apart is slid over
        the signal; a window with at least six level changes is taken as a
        SYNC candidate. The result is then moved forward to the first K
        sample so that it points at the start of the SYNC field.

        Args:
            signal: Differential signal (J=1, K=0).
            start_idx: Start search index.
            bit_period: Bit period in samples.

        Returns:
            Index of SYNC start, or None if not found.
        """
        n_samples = len(signal)
        window = int(8 * bit_period)
        # Quarter-bit scan resolution, but never 0: a zero step would spin
        # forever whenever there are fewer than four samples per bit.
        step = max(1, int(bit_period / 4))

        idx = start_idx
        while idx < n_samples - window:
            transitions = 0
            prev_val = signal[idx]
            for i in range(1, 8):
                sample_idx = int(idx + i * bit_period)
                if sample_idx >= n_samples:
                    break
                curr_val = signal[sample_idx]
                if curr_val != prev_val:
                    transitions += 1
                prev_val = curr_val

            if transitions >= self._MIN_SYNC_TRANSITIONS:
                # The window can match up to two bit times ahead of the
                # packet while the bus is still idle (J). SYNC begins at the
                # first K sample, so skip the leading idle samples.
                sync_idx = idx
                limit = min(n_samples, idx + window + 1)
                while sync_idx < limit and signal[sync_idx]:
                    sync_idx += 1
                return sync_idx if sync_idx < limit else idx

            idx += step

        return None

    def _extract_packet_bits(
        self,
        signal: NDArray[np.bool_],
        start_idx: int,
        bit_period: float,
        se0: NDArray[np.bool_],
    ) -> tuple[list[int], list[str]]:
        """Extract and decode packet bits with NRZI and unstuffing.

        Bits are sampled at bit centres starting from ``start_idx`` until an
        SE0 sample, the end of the signal, or the per-packet bit limit.

        Args:
            signal: NRZI-encoded differential signal (J=1, K=0).
            start_idx: Packet start index (first sample of SYNC).
            bit_period: Bit period in samples.
            se0: SE0 detection array.

        Returns:
            (bits, errors) tuple. ``bits`` has stuff bits removed; ``errors``
            lists bit-stuffing violations.
        """
        bits: list[int] = []
        errors: list[str] = []
        n_samples = len(signal)

        position = start_idx + bit_period / 2  # centre of the first bit
        # A packet starts from the idle J state, so the first K of SYNC is a
        # transition and decodes as 0.
        prev_val = True
        ones_run = 0
        expect_stuff_bit = False

        for _ in range(self._MAX_PACKET_BITS):
            sample_idx = int(position)
            if sample_idx >= n_samples:
                break
            if se0[sample_idx]:  # EOP
                break

            curr_val = signal[sample_idx]
            # NRZI decode: no transition = 1, transition = 0
            bit = 1 if curr_val == prev_val else 0
            prev_val = curr_val
            position += bit_period

            if expect_stuff_bit:
                # The bit after six consecutive 1s is a stuff bit: it carries
                # no data and must be 0.
                expect_stuff_bit = False
                ones_run = 0
                if bit:
                    errors.append("Bit stuffing error")
                continue

            bits.append(bit)
            if bit:
                ones_run += 1
                if ones_run == 6:
                    expect_stuff_bit = True
            else:
                ones_run = 0

        return bits, errors

    def _bits_to_byte(self, bits: list[int]) -> int:
        """Convert 8 bits to byte (LSB first).

        Args:
            bits: List of bits; only the first 8 are used.

        Returns:
            Byte value.
        """
        return self._bits_to_value(bits[:8])

    def _bits_to_value(self, bits: list[int]) -> int:
        """Convert bits to integer (LSB first).

        Args:
            bits: List of bits.

        Returns:
            Integer value.
        """
        value = 0
        for position, bit in enumerate(bits):
            value |= bit << position
        return value

    def _crc5(self, data: int) -> int:
        """Compute USB CRC5.

        The result is in the same LSB-first form that ``_bits_to_value``
        produces from the received CRC bits, so the two compare directly.

        Args:
            data: 11-bit data value.

        Returns:
            5-bit CRC.
        """
        # CRC-5-USB polynomial x^5 + x^2 + 1 (0x05), bit-reflected to 0x14.
        crc = 0x1F
        for i in range(11):
            bit = (data >> i) & 1
            feedback = (crc & 1) ^ bit
            crc >>= 1
            if feedback:
                crc ^= 0x14
        return crc ^ 0x1F

    def _crc16(self, data: list[int]) -> int:
        """Compute USB CRC16 over the data payload bytes.

        The result is in the same LSB-first form that ``_bits_to_value``
        produces from the received CRC bits, so the two compare directly.

        Args:
            data: Payload bytes.

        Returns:
            16-bit CRC (0x0000 for an empty payload).
        """
        # CRC-16 polynomial x^16 + x^15 + x^2 + 1 (0x8005), reflected to 0xA001.
        crc = 0xFFFF
        for byte in data:
            crc ^= byte
            for _ in range(8):
                crc = (crc >> 1) ^ 0xA001 if crc & 1 else crc >> 1
        return crc ^ 0xFFFF
def decode_usb(
    dp: NDArray[np.bool_],
    dm: NDArray[np.bool_],
    sample_rate: float = 1.0,
    speed: Literal["low", "full"] = "full",
) -> list[ProtocolPacket]:
    """Decode USB packets in one call.

    Builds a ``USBDecoder`` for the requested speed, runs it over the two
    signals, and collects every yielded packet into a list.

    Args:
        dp: D+ signal.
        dm: D- signal.
        sample_rate: Sample rate in Hz.
        speed: USB speed ("low" or "full").

    Returns:
        List of decoded USB packets.

    Example:
        >>> packets = decode_usb(dp, dm, sample_rate=100e6, speed="full")
        >>> for pkt in packets:
        ...     print(f"PID: {pkt.annotations['pid_name']}")
    """
    packet_stream = USBDecoder(speed=speed).decode(dp=dp, dm=dm, sample_rate=sample_rate)
    return [*packet_stream]
# Public API of this module.
__all__ = [
    "PID_NAMES",
    "USBPID",
    "USBDecoder",
    "USBSpeed",
    "decode_usb",
]