Coverage for src / tracekit / utils / autodetect.py: 94%

114 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Auto-detection utilities for signal analysis. 

2 

3This module provides utilities for automatic detection of signal 

4parameters such as baud rate, logic levels, and protocol types. 

5 

6 

7Example: 

8 >>> from tracekit.utils.autodetect import detect_baud_rate 

9 >>> baudrate = detect_baud_rate(trace) 

10 >>> print(f"Detected baud rate: {baudrate}") 

11 

12References: 

13 Standard baud rates and UART specifications. 

14""" 

15 

16from __future__ import annotations 

17 

18from typing import TYPE_CHECKING, Literal 

19 

20import numpy as np 

21 

22from tracekit.core.types import DigitalTrace, WaveformTrace 

23 

24if TYPE_CHECKING: 

25 from numpy.typing import NDArray 

26 

# Candidate symbol rates that a measured rate is snapped to
# (RS-232, UART, CAN, etc.), listed in ascending order.
STANDARD_BAUD_RATES: tuple[int, ...] = (
    300,
    600,
    1_200,
    2_400,
    4_800,
    9_600,
    14_400,
    19_200,
    28_800,
    38_400,
    57_600,
    76_800,
    115_200,
    230_400,
    250_000,  # CAN common
    460_800,
    500_000,  # CAN common
    576_000,
    921_600,
    1_000_000,  # 1 Mbps
    1_500_000,
    2_000_000,
    3_000_000,
    4_000_000,
)

54 

55 

def detect_baud_rate(
    trace: WaveformTrace | DigitalTrace,
    *,
    threshold: float | Literal["auto"] = "auto",
    method: Literal["pulse_width", "edge_timing", "autocorr"] = "pulse_width",
    tolerance: float = 0.05,
    return_confidence: bool = False,
) -> int | tuple[int, float]:
    """Detect baud rate from signal timing.

    Estimates the bit period with the selected method, converts it to a
    symbol rate and maps that to the nearest entry of
    ``STANDARD_BAUD_RATES``.

    Args:
        trace: Input trace (analog or digital). Analog traces are first
            converted with ``to_digital`` using ``threshold``.
        threshold: Threshold for analog to digital conversion.
        method: Detection method:
            - "pulse_width": Minimum pulse width (default)
            - "edge_timing": Edge-to-edge timing analysis
            - "autocorr": Autocorrelation peak detection
        tolerance: Relative error at which the confidence drops to zero
            (default 5%). It only affects the confidence score; the nearest
            standard rate is returned even when the error exceeds it.
        return_confidence: If True, also return confidence score.

    Returns:
        Detected baud rate (nearest standard), or tuple of (rate, confidence)
        if return_confidence=True. When no usable bit period is found the
        rate is 0 (and the confidence 0.0).

    Raises:
        ValueError: If unknown detection method specified.

    Example:
        >>> baudrate = detect_baud_rate(trace)
        >>> print(f"Detected: {baudrate} bps")

        >>> baudrate, confidence = detect_baud_rate(trace, return_confidence=True)
        >>> print(f"Detected: {baudrate} bps ({confidence:.0%} confidence)")

    References:
        RS-232 Standard Baud Rates
    """
    # Get digital representation
    if isinstance(trace, WaveformTrace):
        from tracekit.analyzers.digital.extraction import to_digital

        data = to_digital(trace, threshold=threshold).data
    else:
        data = trace.data

    sample_rate = trace.metadata.sample_rate

    if method == "pulse_width":
        bit_period = _detect_via_pulse_width(data, sample_rate)
    elif method == "edge_timing":
        bit_period = _detect_via_edge_timing(data, sample_rate)
    elif method == "autocorr":
        bit_period = _detect_via_autocorrelation(data, sample_rate)
    else:
        raise ValueError(f"Unknown method: {method}")

    # A zero, negative, NaN or infinite period carries no rate information;
    # report "not detected" instead of snapping it to some standard rate.
    if not np.isfinite(bit_period) or bit_period <= 0:
        return (0, 0.0) if return_confidence else 0

    # Convert to baud rate
    measured_rate = 1.0 / bit_period

    def relative_error(std_rate: int) -> float:
        return abs(measured_rate - std_rate) / std_rate

    # min() keeps the first of equally good candidates, i.e. the lower rate.
    best_rate = min(STANDARD_BAUD_RATES, key=relative_error)
    best_error = relative_error(best_rate)

    # Confidence falls linearly from 1.0 (exact match) to 0.0 (error equal to
    # the tolerance). A zero tolerance only accepts an exact match and must
    # not be used as a divisor.
    if tolerance > 0 and best_error <= tolerance:
        confidence = max(0.0, 1.0 - best_error / tolerance)
    elif tolerance == 0 and best_error == 0:
        confidence = 1.0
    else:
        confidence = 0.0

    if return_confidence:
        return best_rate, confidence

    return best_rate

141 

142 

def _detect_via_pulse_width(data: NDArray[np.bool_], sample_rate: float) -> float:
    """Detect bit period from minimum pulse width.

    The signal is run-length encoded; the shortest run that is longer than a
    single sample is taken as one bit, and the median of all runs within 1.5x
    of it is used as the estimate.

    Args:
        data: Digital signal data.
        sample_rate: Sample rate in Hz.

    Returns:
        Estimated bit period in seconds, or 0.0 when the data is empty or
        every run is only one sample wide.
    """
    samples = np.asarray(data)
    if samples.size == 0:
        return 0.0

    # Run-length encode: a run starts at index 0 and wherever the value
    # differs from its predecessor; the last run ends at the end of the data.
    change_points = np.flatnonzero(samples[1:] != samples[:-1]) + 1
    boundaries = np.concatenate(([0], change_points, [samples.size]))
    pulse_widths = np.diff(boundaries).astype(np.float64)

    # Single-sample runs are treated as noise and ignored throughout, so
    # they can neither define the minimum nor drag the median down.
    real_pulses = pulse_widths[pulse_widths > 1]
    if real_pulses.size == 0:
        return 0.0

    # The minimum pulse width corresponds to a single bit (always >= 2 here).
    min_pulse = float(real_pulses.min())

    # Median of the cluster around the minimum for robustness against jitter;
    # the cluster always contains the minimum itself, so it is never empty.
    small_pulses = real_pulses[real_pulses <= min_pulse * 1.5]
    bit_samples = np.median(small_pulses)

    return float(bit_samples / sample_rate)

185 

186 

def _detect_via_edge_timing(data: NDArray[np.bool_], sample_rate: float) -> float:
    """Detect bit period from edge-to-edge timing.

    Counts how often each edge-to-edge interval occurs and returns the
    shortest interval whose count reaches 30% of the most frequent one.
    Only intervals up to ``min(2 * shortest, median)`` are considered.

    Args:
        data: Digital signal data.
        sample_rate: Sample rate in Hz.

    Returns:
        Estimated bit period in seconds, or 0.0 when the data contains fewer
        than two edges.
    """
    # Find all edges
    levels = np.asarray(data).astype(np.int8)
    edge_indices = np.flatnonzero(np.diff(levels))

    if edge_indices.size < 2:
        return 0.0

    # Edge intervals in whole samples; at least one exists and all are >= 1.
    intervals = np.diff(edge_indices)

    min_interval = int(intervals.min())
    max_check = min(min_interval * 2, float(np.median(intervals)))
    upper = int(max_check)  # floor; never below min_interval

    # Exact per-value counts: counts[k] is the number of intervals equal to
    # k + 1. Half-open histogram bins are avoided on purpose, because the
    # closed right edge of the last bin would merge the two largest values
    # and shift the reported interval down by one sample.
    candidates = intervals[intervals <= upper]
    counts = np.bincount(candidates, minlength=upper + 1)[1:]

    # First significant peak. The shortest interval is always counted, so
    # the maximum is positive and at least one entry passes the threshold.
    significant = np.flatnonzero(counts >= counts.max() * 0.3)
    bit_samples = significant[0] + 1  # +1 converts index back to interval

    return float(bit_samples / sample_rate)

233 

234 

def _detect_via_autocorrelation(data: NDArray[np.bool_], sample_rate: float) -> float:
    """Detect bit period via autocorrelation.

    Correlates the leading ``2 * max_lag`` samples of the DC-free signal with
    themselves and returns the lag of the first local maximum whose
    normalized value exceeds 0.3.

    Args:
        data: Digital signal data.
        sample_rate: Sample rate in Hz.

    Returns:
        Estimated bit period in seconds, or 0.0 when the trace is too short,
        the analysed window has no signal energy, or no significant peak is
        found.
    """
    # Convert to float for correlation
    signal = np.asarray(data).astype(np.float64) * 2 - 1  # Map to [-1, 1]
    n = len(signal)
    if n == 0:
        return 0.0

    # Remove DC
    signal = signal - np.mean(signal)

    # Limit to reasonable range: one period of the slowest standard rate
    # (300 baud), and never more than half of the trace.
    max_lag = min(n // 2, int(sample_rate / 300))
    if max_lag < 1:
        # Nothing to correlate; np.correlate rejects empty input.
        return 0.0

    # NOTE(review): np.correlate is a direct (non-FFT) correlation, so the
    # cost grows quadratically with the window length.
    window = signal[: max_lag * 2]
    autocorr = np.correlate(window, window, mode="full")
    autocorr = autocorr[len(autocorr) // 2 :]  # Keep positive lags

    # Normalize by the lag-0 energy; a flat window has none, and dividing
    # by it would only produce NaNs.
    zero_lag = autocorr[0]
    if not zero_lag > 0:
        return 0.0
    autocorr = autocorr / zero_lag

    # Skip initial samples to avoid lag-0 region
    min_lag = max(2, max_lag // 100)

    # Local maxima above the significance threshold; index k of ``interior``
    # corresponds to lag k + 1.
    interior = autocorr[1:-1]
    is_peak = (interior > autocorr[:-2]) & (interior > autocorr[2:]) & (interior > 0.3)
    peak_lags = np.flatnonzero(is_peak) + 1
    peak_lags = peak_lags[peak_lags >= min_lag]

    if peak_lags.size == 0:
        return 0.0

    # First significant peak is likely the bit period
    return float(peak_lags[0] / sample_rate)

279 

280 

def detect_logic_family(
    trace: WaveformTrace,
    *,
    return_confidence: bool = False,
) -> str | tuple[str, float]:
    """Guess the logic family that best matches a trace's voltage levels.

    The 10th and 90th percentiles of the samples serve as the low and high
    levels, and the high level scaled by 1.1 as a supply-voltage estimate.
    Every entry of ``LOGIC_FAMILIES`` is scored on how close these values are
    to its ``VOL_max``, ``VOH_min`` and ``VCC``; the highest score wins.

    Args:
        trace: Input analog trace.
        return_confidence: If True, also return the winning score.

    Returns:
        Logic family name (e.g., "TTL", "LVCMOS_3V3"), or tuple of
        (family, confidence) if return_confidence=True. "TTL" with a score
        of 0.0 is returned when no family scores above zero.
    """
    from tracekit.analyzers.digital.extraction import LOGIC_FAMILIES

    samples = trace.data

    # Percentiles rather than min/max for the two logic levels.
    low_level = float(np.percentile(samples, 10))
    high_level = float(np.percentile(samples, 90))

    # Supply estimate: high level plus a 10% margin.
    supply_estimate = high_level * 1.1

    winner = "TTL"
    winner_score = 0.0

    for name, spec in LOGIC_FAMILIES.items():
        supply = spec["VCC"]
        low_limit = spec["VOL_max"]
        high_limit = spec["VOH_min"]

        # Each closeness term is clamped to [0, 1]; the level terms reach
        # zero at a 0.5 V mismatch, the supply term at a 100% mismatch.
        low_closeness = 1.0 - min(1.0, abs(low_level - low_limit) / 0.5)
        high_closeness = 1.0 - min(1.0, abs(high_level - high_limit) / 0.5)
        supply_closeness = 1.0 - min(1.0, abs(supply_estimate - supply) / supply)

        candidate_score = (low_closeness + high_closeness + supply_closeness) / 3

        # Strict comparison: on a tie the earlier family is kept.
        if candidate_score > winner_score:
            winner, winner_score = name, candidate_score

    if return_confidence:
        return winner, winner_score

    return winner

332 

333 

# Names exported by ``from tracekit.utils.autodetect import *``; the
# ``_detect_via_*`` helpers are intentionally left out.
__all__ = [
    "STANDARD_BAUD_RATES",
    "detect_baud_rate",
    "detect_logic_family",
]