Coverage for src / tracekit / utils / autodetect.py: 94%
114 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Auto-detection utilities for signal analysis.

This module provides utilities for automatic detection of signal
parameters such as baud rate and logic family (inferred from logic levels).

Example:
    >>> from tracekit.utils.autodetect import detect_baud_rate
    >>> baudrate = detect_baud_rate(trace)
    >>> print(f"Detected baud rate: {baudrate}")

References:
    Standard baud rates and UART specifications.
14"""
16from __future__ import annotations
18from typing import TYPE_CHECKING, Literal
20import numpy as np
22from tracekit.core.types import DigitalTrace, WaveformTrace
24if TYPE_CHECKING:
25 from numpy.typing import NDArray
# Standard baud rates (RS-232, UART, CAN, etc.), listed in ascending order.
# detect_baud_rate() snaps a measured symbol rate to the nearest entry here.
STANDARD_BAUD_RATES: tuple[int, ...] = (
    300, 600, 1200, 2400, 4800, 9600,
    14400, 19200, 28800, 38400, 57600, 76800,
    115200, 230400,
    250000,  # CAN common
    460800,
    500000,  # CAN common
    576000, 921600,
    1_000_000,  # 1 Mbps
    1_500_000, 2_000_000, 3_000_000, 4_000_000,
)
def detect_baud_rate(
    trace: WaveformTrace | DigitalTrace,
    *,
    threshold: float | Literal["auto"] = "auto",
    method: Literal["pulse_width", "edge_timing", "autocorr"] = "pulse_width",
    tolerance: float = 0.05,
    return_confidence: bool = False,
) -> int | tuple[int, float]:
    """Detect baud rate from signal timing.

    Estimates the bit period with the selected method, converts it to a
    symbol rate, then maps that rate to the nearest entry of
    ``STANDARD_BAUD_RATES`` (nearest by relative error).

    Args:
        trace: Input trace (analog or digital). Analog traces are first
            converted to digital with ``to_digital``.
        threshold: Threshold for analog to digital conversion. Ignored for
            digital traces.
        method: Detection method:
            - "pulse_width": Minimum pulse width (default)
            - "edge_timing": Edge-to-edge timing analysis
            - "autocorr": Autocorrelation peak detection
        tolerance: Relative tolerance for matching to a standard rate
            (default 5%). Confidence falls linearly from 1.0 at an exact
            match to 0.0 at this relative error. With ``tolerance <= 0``
            only an exact match yields a non-zero confidence (1.0).
        return_confidence: If True, also return confidence score.

    Returns:
        Detected baud rate (nearest standard), or tuple of (rate, confidence)
        if return_confidence=True. A rate of 0 (confidence 0.0) means no
        usable bit period could be measured. The nearest standard rate is
        returned even when it lies outside ``tolerance``; its confidence is
        then 0.0.

    Raises:
        ValueError: If unknown detection method specified.

    Example:
        >>> baudrate = detect_baud_rate(trace)
        >>> print(f"Detected: {baudrate} bps")

        >>> baudrate, confidence = detect_baud_rate(trace, return_confidence=True)
        >>> print(f"Detected: {baudrate} bps ({confidence:.0%} confidence)")

    References:
        RS-232 Standard Baud Rates
    """
    detectors = {
        "pulse_width": _detect_via_pulse_width,
        "edge_timing": _detect_via_edge_timing,
        "autocorr": _detect_via_autocorrelation,
    }
    # Validate before doing any (potentially expensive) digitization.
    if method not in detectors:
        raise ValueError(f"Unknown method: {method}")

    # Get digital representation
    if isinstance(trace, WaveformTrace):
        from tracekit.analyzers.digital.extraction import to_digital

        data = to_digital(trace, threshold=threshold).data
    else:
        data = trace.data

    sample_rate = trace.metadata.sample_rate
    bit_period = detectors[method](data, sample_rate)

    # Helpers report failure as 0.0; NaN/inf (e.g. a zero sample rate) are
    # equally unusable and must not be mapped onto a standard rate.
    if not np.isfinite(bit_period) or bit_period <= 0:
        return (0, 0.0) if return_confidence else 0

    measured_rate = 1.0 / bit_period

    # Nearest standard rate by relative error; min() keeps the first (lowest)
    # rate on ties.
    best_rate = min(
        STANDARD_BAUD_RATES,
        key=lambda std_rate: abs(measured_rate - std_rate) / std_rate,
    )
    best_error = abs(measured_rate - best_rate) / best_rate

    # Linear confidence; guard tolerance <= 0 to avoid dividing by zero.
    if tolerance > 0:
        confidence = max(0.0, 1.0 - best_error / tolerance)
    else:
        confidence = 1.0 if best_error == 0 else 0.0

    if return_confidence:
        return best_rate, confidence

    return best_rate
def _detect_via_pulse_width(data: NDArray[np.bool_], sample_rate: float) -> float:
    """Detect bit period from the shortest non-noise pulse width.

    Splits the signal into runs of identical consecutive samples. Runs of a
    single sample are treated as noise and ignored. The bit period is the
    median width of the remaining runs that are at most 1.5x the shortest
    remaining run, which tolerates small edge jitter.

    Args:
        data: Digital signal data.
        sample_rate: Sample rate in Hz.

    Returns:
        Estimated bit period in seconds, or 0.0 if ``data`` is empty or has
        no run longer than one sample.
    """
    samples = np.asarray(data)
    if samples.size == 0:
        return 0.0

    # Run boundaries: start, every index where the value changes, and end.
    change_points = np.flatnonzero(samples[1:] != samples[:-1]) + 1
    boundaries = np.concatenate(([0], change_points, [samples.size]))
    pulse_widths = np.diff(boundaries).astype(np.float64)

    # Filter out very short (single-sample) pulses as noise.
    candidates = pulse_widths[pulse_widths > 1]
    if candidates.size == 0:
        return 0.0

    # The minimum pulse width corresponds to a single bit; take the median
    # of pulses close to it (noise already excluded) for robustness.
    min_pulse = candidates.min()
    small_pulses = candidates[candidates <= min_pulse * 1.5]
    bit_samples = np.median(small_pulses)  # never empty: contains min_pulse

    return float(bit_samples / sample_rate)
def _detect_via_edge_timing(data: NDArray[np.bool_], sample_rate: float) -> float:
    """Estimate the bit period from the spacing between signal edges.

    Edge-to-edge gaps should be integer multiples of the bit period. The gaps
    are histogrammed with one-sample-wide bins starting at 1 and extending to
    about ``min(2 * shortest_gap, median_gap)`` samples; the first bin whose
    count reaches 30% of the tallest bin is taken as one bit.

    Args:
        data: Digital signal data.
        sample_rate: Sample rate in Hz.

    Returns:
        Estimated bit period in seconds, or 0.0 when fewer than two edges
        are present.
    """
    edge_positions = np.flatnonzero(np.diff(data.astype(np.int8)))
    if edge_positions.size < 2:
        return 0.0

    gaps = np.diff(edge_positions).astype(np.float64)
    shortest_gap = np.min(gaps)

    # Only gaps up to the search limit land in the histogram.
    search_limit = min(shortest_gap * 2, np.median(gaps))
    counts, _ = np.histogram(gaps, bins=np.arange(1, search_limit + 1))

    if counts.size == 0 or counts.max() == 0:
        # Degenerate histogram (fewer than two bin edges): fall back to the
        # shortest observed gap.
        return float(shortest_gap / sample_rate)

    # Bin i starts at i + 1 samples; the first sufficiently tall bin wins.
    first_peak = np.flatnonzero(counts >= counts.max() * 0.3)[0]
    return float((first_peak + 1) / sample_rate)
def _detect_via_autocorrelation(data: NDArray[np.bool_], sample_rate: float) -> float:
    """Detect bit period via autocorrelation.

    Correlates an initial window of the signal with itself and returns the
    lag of the first local maximum (after a short exclusion zone around lag
    0) whose normalized value exceeds 0.3.

    Args:
        data: Digital signal data.
        sample_rate: Sample rate in Hz.

    Returns:
        Estimated bit period in seconds, or 0.0 if the signal is too short
        (or ``sample_rate`` too low) to analyze, the analyzed window is
        constant, or no significant peak is found.
    """
    n = len(data)
    # Limit lags to sample_rate / 300 samples (300 is the slowest standard
    # rate) and to half the capture length.
    max_lag = min(n // 2, int(sample_rate / 300))
    if max_lag <= 0:
        # np.correlate rejects empty input; there is nothing to analyze.
        return 0.0

    # Only the first 2 * max_lag samples are analyzed; map {0, 1} -> {-1, +1}.
    window = data[: max_lag * 2].astype(np.float64) * 2 - 1
    # Remove the DC component of the analyzed window itself, so a residual
    # offset does not inflate the normalized correlation values.
    window -= window.mean()

    autocorr = np.correlate(window, window, mode="full")
    autocorr = autocorr[len(autocorr) // 2 :]  # Keep lags >= 0

    energy = autocorr[0]
    if energy <= 0:
        # Constant window: no transitions, and normalizing would divide by 0.
        return 0.0
    autocorr = autocorr / energy

    # Skip initial lags to avoid the lag-0 region.
    min_lag = max(2, max_lag // 100)

    # Vectorized local-maximum test; each candidate lag needs both neighbours.
    stop = len(autocorr) - 1
    centre = autocorr[min_lag:stop]
    is_peak = (
        (centre > autocorr[min_lag - 1 : stop - 1])
        & (centre > autocorr[min_lag + 1 : stop + 1])
        & (centre > 0.3)  # Significance threshold
    )
    peak_offsets = np.flatnonzero(is_peak)
    if peak_offsets.size == 0:
        return 0.0

    # First significant peak is taken as the bit period.
    # NOTE(review): for strictly periodic bit patterns (e.g. alternating
    # 0/1) the first positive peak is at the pattern period (two bit times),
    # not one bit time; confirm this is acceptable for callers.
    bit_samples = min_lag + int(peak_offsets[0])
    return float(bit_samples / sample_rate)
def detect_logic_family(
    trace: WaveformTrace,
    *,
    return_confidence: bool = False,
) -> str | tuple[str, float]:
    """Detect logic family from signal levels.

    Uses the 10th and 90th percentiles of the samples as the low and high
    levels and estimates VCC as 1.1x the high level. Each entry of
    ``LOGIC_FAMILIES`` is scored by how close its ``VOL_max``, ``VOH_min``
    and ``VCC`` are to those values, and the best-scoring family is chosen.
    "TTL" is returned if no family scores above zero.

    Args:
        trace: Input analog trace.
        return_confidence: If True, also return confidence score.

    Returns:
        Logic family name (e.g., "TTL", "LVCMOS_3V3"), or tuple of
        (family, confidence) if return_confidence=True. The confidence is
        the winning score, in the range [0, 1].
    """
    from tracekit.analyzers.digital.extraction import LOGIC_FAMILIES

    def closeness(measured: float, expected: float, scale: float) -> float:
        # 1.0 for an exact match, falling linearly to 0.0 at ``scale`` away.
        return 1.0 - min(1.0, abs(measured - expected) / scale)

    samples = trace.data
    v_low, v_high = (float(np.percentile(samples, q)) for q in (10, 90))
    v_cc_est = v_high * 1.1  # VCC is assumed to sit ~10% above the high level

    best_family, best_score = "TTL", 0.0
    for family, levels in LOGIC_FAMILIES.items():
        vcc = levels["VCC"]
        score = (
            closeness(v_low, levels["VOL_max"], 0.5)
            + closeness(v_high, levels["VOH_min"], 0.5)
            + closeness(v_cc_est, vcc, vcc)
        ) / 3
        if score > best_score:
            best_family, best_score = family, score

    return (best_family, best_score) if return_confidence else best_family
# Public API of this module (helpers prefixed with "_" are private).
__all__ = ["STANDARD_BAUD_RATES", "detect_baud_rate", "detect_logic_family"]