Coverage for src / tracekit / triggering / pattern.py: 99%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Pattern triggering for TraceKit. 

2 

3Provides digital pattern matching for multi-channel logic signals. 

4Supports exact matches, wildcards, and edge conditions. 

5 

6Example: 

7 >>> from tracekit.triggering.pattern import PatternTrigger, find_pattern 

8 >>> # Find pattern 1010 on 4 channels 

9 >>> trigger = PatternTrigger(pattern=[1, 0, 1, 0]) 

10 >>> events = trigger.find_events(trace) 

11""" 

12 

13from __future__ import annotations 

14 

15from typing import TYPE_CHECKING, Literal 

16 

17import numpy as np 

18 

19from tracekit.core.exceptions import AnalysisError 

20from tracekit.core.types import DigitalTrace, WaveformTrace 

21from tracekit.triggering.base import ( 

22 Trigger, 

23 TriggerEvent, 

24 TriggerType, 

25) 

26 

27if TYPE_CHECKING: 

28 from numpy.typing import NDArray 

29 

30 

class PatternTrigger(Trigger):
    """Pattern trigger for digital pattern matching on a single trace.

    Detects where a digital signal (or an analog signal digitized against a
    threshold) matches a specified pattern of high/low states.

    Matching modes (``match_type``):

    * ``"sequence"`` (default): the pattern is a run of consecutive
      *samples*. An event is emitted at every start index whose samples
      match the pattern (``None`` entries are ignored). Overlapping matches
      are all reported.
    * ``"exact"``: only the first pattern value is used. An event is emitted
      each time the signal goes from not matching that value to matching
      it. A match already present at sample 0 produces no event.

    Any ``match_type`` other than ``"sequence"`` is handled as ``"exact"``.

    Attributes:
        pattern: Pattern to match (list of 0, 1, or None for don't care).
        levels: Threshold level(s) for converting analog to digital.
        match_type: Type of match - "exact" or "sequence".
    """

    def __init__(
        self,
        pattern: list[int | None],
        levels: float | list[float] | None = None,
        match_type: Literal["exact", "sequence"] = "sequence",
    ) -> None:
        """Initialize pattern trigger.

        Args:
            pattern: Pattern to match. Values are 0, 1, or None (don't care).
                For single-channel sequence matching this is the bit
                sequence, one value per sample. In "exact" mode only the
                first value is used.
            levels: Threshold level(s) for analog-to-digital conversion.
                A scalar is used directly. For a sequence, the first entry
                is used. If None, the midpoint between the trace's minimum
                and maximum is used.
            match_type: "exact" reports transitions into the state given by
                ``pattern[0]``. "sequence" finds the pattern as a sequence
                of samples in time.

        Raises:
            AnalysisError: If pattern is empty or contains values other
                than 0, 1, or None.
        """
        if len(pattern) == 0:
            # An empty pattern would "match" at every index in sequence mode
            # and fail with IndexError in exact mode.
            raise AnalysisError("Pattern must contain at least one value")
        for val in pattern:
            if val is not None and val not in (0, 1):
                raise AnalysisError(f"Pattern values must be 0, 1, or None, got {val}")

        # Keep a private copy so later mutation of the caller's list cannot
        # sneak unvalidated values into the trigger.
        self.pattern = list(pattern)
        self.levels = levels
        self.match_type = match_type

    def find_events(
        self,
        trace: WaveformTrace | DigitalTrace,
    ) -> list[TriggerEvent]:
        """Find pattern matches in the trace.

        Args:
            trace: Input trace (single channel). ``DigitalTrace`` data is
                used as-is. ``WaveformTrace`` data is thresholded with
                ``data >= level``.

        Returns:
            List of trigger events for each pattern match. The list is empty
            when the trace has no samples.

        Raises:
            AnalysisError: If ``levels`` is an empty sequence.
        """
        if len(trace.data) == 0:
            # Nothing to match. This also avoids min()/max() of an empty
            # array when the threshold is derived from the data.
            return []

        if isinstance(trace, DigitalTrace):
            digital = trace.data
        else:
            digital = trace.data >= self._get_level(trace)

        sample_period = trace.metadata.time_base
        if self.match_type == "sequence":
            return self._find_sequence_matches(digital, sample_period)
        return self._find_exact_matches(digital, sample_period)

    def _get_level(self, trace: WaveformTrace) -> float:
        """Return the threshold used to digitize ``trace``.

        Raises:
            AnalysisError: If ``levels`` is an empty sequence.
        """
        if self.levels is None:
            # Midpoint of the signal range. Convert to Python floats before
            # adding so integer sample dtypes (e.g. uint8/int16 codes)
            # cannot wrap around.
            return (float(np.min(trace.data)) + float(np.max(trace.data))) / 2.0

        # ravel() turns a scalar (Python or NumPy) into a 1-element array
        # and leaves a per-channel sequence as-is.
        flat_levels = np.ravel(self.levels)
        if flat_levels.size == 0:
            raise AnalysisError("levels must not be an empty sequence")
        # With per-channel levels, a single trace uses the first one.
        return float(flat_levels[0])

    def _find_sequence_matches(
        self,
        digital: NDArray[np.bool_],
        sample_period: float,
    ) -> list[TriggerEvent]:
        """Find every start index where the pattern occurs as a run of samples.

        Window ``i`` matches when ``digital[i + j] == pattern[j]`` for every
        ``j`` whose pattern value is not None. Overlapping windows are all
        reported, and each event's duration spans the full pattern length.
        """
        # Compare on an int8 view, as the original per-sample check did.
        bits = np.asarray(digital).astype(np.int8)
        pattern_len = len(self.pattern)
        n_windows = len(bits) - pattern_len + 1
        if n_windows <= 0:
            # The pattern is longer than the trace, so no complete window fits.
            return []

        # Vectorized sliding comparison: AND together one shifted slice per
        # cared-about pattern position instead of a Python double loop.
        matches = np.ones(n_windows, dtype=bool)
        for offset, expected in enumerate(self.pattern):
            if expected is not None:
                matches &= bits[offset : offset + n_windows] == expected

        duration = pattern_len * sample_period
        return [
            TriggerEvent(
                timestamp=start * sample_period,
                sample_index=start,
                event_type=TriggerType.PATTERN_MATCH,
                duration=duration,
                data={"pattern": self.pattern},
            )
            for start in np.flatnonzero(matches).tolist()
        ]

    def _find_exact_matches(
        self,
        digital: NDArray[np.bool_],
        sample_period: float,
    ) -> list[TriggerEvent]:
        """Report transitions into the state given by ``pattern[0]``.

        Only the first pattern value is used. Sample 0 establishes the
        initial state, so an event is emitted at index ``i >= 1`` whenever
        sample ``i`` matches and sample ``i - 1`` did not. A None (don't
        care) value matches every sample, so there is never a transition
        and no events are produced.
        """
        pattern_val = self.pattern[0]
        if pattern_val is None or len(digital) == 0:
            return []

        in_state = np.asarray(digital) == bool(pattern_val)
        # diff == 1 marks a False -> True step between samples i-1 and i.
        rising = np.flatnonzero(np.diff(in_state.astype(np.int8)) == 1) + 1

        return [
            TriggerEvent(
                timestamp=index * sample_period,
                sample_index=index,
                event_type=TriggerType.PATTERN_MATCH,
                data={"pattern": self.pattern},
            )
            for index in rising.tolist()
        ]

182 

183 

class MultiChannelPatternTrigger(Trigger):
    """Pattern trigger for multiple parallel channels.

    Triggers when all channels simultaneously match the specified pattern
    (``None`` entries are don't-care). An event is emitted on every
    transition from "not all matching" to "all matching". The state before
    the first sample counts as not matching, so a pattern that already
    holds at sample 0 produces an event at sample 0.

    Example:
        >>> trigger = MultiChannelPatternTrigger(
        ...     pattern=[1, 0, 1, None],  # Ch0=1, Ch1=0, Ch2=1, Ch3=don't care
        ...     levels=[1.5, 1.5, 1.5, 1.5]
        ... )
    """

    def __init__(
        self,
        pattern: list[int | None],
        levels: list[float] | None = None,
    ) -> None:
        """Initialize multi-channel pattern trigger.

        Args:
            pattern: Pattern for each channel (0, 1, or None for don't care).
            levels: Threshold level for each analog channel, indexed by
                channel position. If None, each analog channel uses the
                midpoint between its own minimum and maximum.

        Raises:
            AnalysisError: If pattern is empty or contains values other
                than 0, 1, or None.
        """
        if len(pattern) == 0:
            raise AnalysisError("Pattern must contain at least one value")
        # Same validation as the single-channel PatternTrigger.
        for val in pattern:
            if val is not None and val not in (0, 1):
                raise AnalysisError(f"Pattern values must be 0, 1, or None, got {val}")

        self.pattern = list(pattern)
        self.levels = levels

    def _digitize(
        self,
        channel: int,
        trace: WaveformTrace | DigitalTrace,
    ) -> NDArray[np.bool_]:
        """Return the digital samples of one channel (thresholding analog data)."""
        if isinstance(trace, DigitalTrace):
            return np.asarray(trace.data)

        data = np.asarray(trace.data)
        if self.levels is not None:
            if channel >= len(self.levels):
                raise AnalysisError(
                    f"No threshold level for channel {channel}: "
                    f"{len(self.levels)} levels given for {len(self.pattern)} channels"
                )
            level = self.levels[channel]
        elif data.size == 0:
            # min()/max() are undefined on empty data, and there is nothing
            # to match anyway.
            return np.zeros(0, dtype=bool)
        else:
            # Compute the midpoint in Python floats so integer dtypes cannot
            # overflow.
            level = (float(np.min(data)) + float(np.max(data))) / 2.0
        return data >= level

    def find_events(
        self,
        traces: list[WaveformTrace | DigitalTrace],  # type: ignore[override]
    ) -> list[TriggerEvent]:
        """Find pattern matches across multiple channels.

        Only the first ``min(len(channel))`` samples, which all channels
        share, are examined. Timestamps use the time base of ``traces[0]``
        (channels are assumed to share a sample rate).

        Args:
            traces: List of traces (one per channel).

        Returns:
            List of trigger events where pattern matches.

        Raises:
            AnalysisError: If number of traces doesn't match pattern length,
                or if ``levels`` has no entry for an analog channel.
        """
        if len(traces) != len(self.pattern):
            raise AnalysisError(
                f"Number of traces ({len(traces)}) must match pattern length ({len(self.pattern)})"
            )

        digitals = [self._digitize(channel, trace) for channel, trace in enumerate(traces)]

        sample_period = traces[0].metadata.time_base
        n_samples = min(len(d) for d in digitals)

        # Per-sample "all channels match" mask, built with one vectorized
        # comparison per cared-about channel.
        all_match = np.ones(n_samples, dtype=bool)
        for digital, pattern_val in zip(digitals, self.pattern, strict=True):
            if pattern_val is not None:
                all_match &= digital[:n_samples] == bool(pattern_val)

        # prepend=0 treats the state before sample 0 as "not matching", so a
        # match at sample 0 counts as a rising edge.
        starts = np.flatnonzero(np.diff(all_match.astype(np.int8), prepend=0) == 1)

        return [
            TriggerEvent(
                timestamp=index * sample_period,
                sample_index=index,
                event_type=TriggerType.PATTERN_MATCH,
                data={"pattern": self.pattern},
            )
            for index in starts.tolist()
        ]

268 

269 

def find_pattern(
    trace: WaveformTrace | DigitalTrace,
    pattern: list[int | None],
    *,
    level: float | None = None,
    return_indices: bool = False,
) -> NDArray[np.float64] | NDArray[np.int64]:
    """Locate every position where a bit pattern occurs in a trace.

    This is a convenience wrapper around :class:`PatternTrigger` in its
    default ``"sequence"`` mode. Each pattern element is compared against
    one *sample* of the (digitized) trace, and overlapping matches are all
    reported.

    Args:
        trace: Input trace.
        pattern: Bit pattern to find (0, 1, None for don't care).
        level: Threshold level. If None, uses the midpoint of the signal
            range.
        return_indices: If True, return sample indices instead of timestamps.

    Returns:
        Array of match start positions: ``float64`` timestamps by default,
        or ``int64`` sample indices when ``return_indices`` is True.

    Raises:
        AnalysisError: If the pattern contains values other than 0, 1, or
            None (raised by :class:`PatternTrigger`).

    Example:
        >>> # Find a low sample followed by two high samples
        >>> starts = find_pattern(trace, [0, 1, 1])
    """
    matches = PatternTrigger(pattern=pattern, levels=level).find_events(trace)

    if not return_indices:
        return np.array([match.timestamp for match in matches], dtype=np.float64)
    return np.array([match.sample_index for match in matches], dtype=np.int64)

298 

299 

def find_bit_sequence(
    trace: WaveformTrace | DigitalTrace,
    bits: str,
    *,
    level: float | None = None,
) -> list[TriggerEvent]:
    """Find a specific bit sequence in a trace.

    The bit string is converted to a pattern and matched with a
    :class:`PatternTrigger` in its default ``"sequence"`` mode. Each
    character is compared against one *sample* of the digitized trace, and
    overlapping matches are all reported.

    Args:
        trace: Input trace. Analog traces are digitized against ``level``.
            Digital traces are used as-is.
        bits: Bit string (e.g., "10101010", "1X0X" where X or x is don't care).
        level: Threshold level for digitization. If None, the midpoint of
            the signal range is used.

    Returns:
        List of trigger events for each match.

    Raises:
        AnalysisError: If ``bits`` is empty or contains a character other
            than "0", "1", "X" or "x".

    Example:
        >>> events = find_bit_sequence(trace, "10110")
        >>> events = find_bit_sequence(trace, "1XX0")  # X = don't care
    """
    if not bits:
        # An empty pattern would otherwise "match" at every index.
        raise AnalysisError("Bit string must not be empty")

    # Convert string to pattern list
    pattern: list[int | None] = []
    for char in bits:
        if char == "0":
            pattern.append(0)
        elif char == "1":
            pattern.append(1)
        elif char.upper() == "X":
            pattern.append(None)
        else:
            raise AnalysisError(f"Invalid bit character: {char}")

    trigger = PatternTrigger(pattern=pattern, levels=level)
    return trigger.find_events(trace)

337 

338 

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "MultiChannelPatternTrigger",
    "PatternTrigger",
    "find_bit_sequence",
    "find_pattern",
]