Coverage for src / tracekit / triggering / pattern.py: 99%
106 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Pattern triggering for TraceKit.
3Provides digital pattern matching for multi-channel logic signals.
4Supports exact matches, wildcards, and edge conditions.
6Example:
7 >>> from tracekit.triggering.pattern import PatternTrigger, find_pattern
8 >>> # Find the bit sequence 1010 in time on a single channel
9 >>> trigger = PatternTrigger(pattern=[1, 0, 1, 0])
10 >>> events = trigger.find_events(trace)
11"""
13from __future__ import annotations
15from typing import TYPE_CHECKING, Literal
17import numpy as np
19from tracekit.core.exceptions import AnalysisError
20from tracekit.core.types import DigitalTrace, WaveformTrace
21from tracekit.triggering.base import (
22 Trigger,
23 TriggerEvent,
24 TriggerType,
25)
27if TYPE_CHECKING:
28 from numpy.typing import NDArray
class PatternTrigger(Trigger):
    """Pattern trigger for single-channel digital pattern matching.

    Detects where a digital signal (or an analog signal after thresholding)
    matches a specified pattern of high/low states. For parallel multi-channel
    matching, use ``MultiChannelPatternTrigger``.

    Two matching modes are supported:

    * ``"sequence"`` (default): the pattern is a run of consecutive samples;
      one event is emitted at the first sample of every matching window.
    * ``"exact"``: only the first pattern value is used; one event is emitted
      at every sample where the signal transitions into that state.

    Attributes:
        pattern: Pattern to match (list of 0, 1, or None for don't care).
        levels: Threshold level(s) for converting analog to digital.
        match_type: Type of match - "exact" or "sequence". Any value other
            than "sequence" is handled as "exact".
    """

    def __init__(
        self,
        pattern: list[int | None],
        levels: float | list[float] | None = None,
        match_type: Literal["exact", "sequence"] = "sequence",
    ) -> None:
        """Initialize pattern trigger.

        Args:
            pattern: Pattern to match. Values are 0, 1, or None (don't care).
                For single-channel sequence, this is the bit sequence.
                In "exact" mode only the first value is used.
            levels: Threshold level(s) for analog-to-digital conversion.
                If a list is given, only its first element is used.
                If None, uses the midpoint between the trace min and max.
            match_type: "exact" triggers on entry into the first pattern
                state, "sequence" finds the pattern as a sequence in time.

        Raises:
            AnalysisError: If pattern is empty or contains invalid values.
        """
        self.pattern = pattern
        self.levels = levels
        self.match_type = match_type

        # An empty pattern would "match" everywhere in sequence mode and has
        # no first bit for exact mode, so reject it up front.
        if not pattern:
            raise AnalysisError("Pattern must contain at least one value")

        # Validate pattern
        for val in pattern:
            if val is not None and val not in (0, 1):
                raise AnalysisError(f"Pattern values must be 0, 1, or None, got {val}")

    def find_events(
        self,
        trace: WaveformTrace | DigitalTrace,
    ) -> list[TriggerEvent]:
        """Find pattern matches in the trace.

        Args:
            trace: Input trace (single channel). Digital traces are used as-is;
                analog traces are thresholded with ``data >= level``.

        Returns:
            List of trigger events for each pattern match (empty for an
            empty trace).
        """
        # Nothing to match; also avoids np.min/np.max on an empty array.
        if len(trace.data) == 0:
            return []

        # Convert to digital if needed
        if isinstance(trace, DigitalTrace):
            digital = trace.data
        else:
            digital = trace.data >= self._get_level(trace)

        sample_period = trace.metadata.time_base
        if self.match_type == "sequence":
            return self._find_sequence_matches(digital, sample_period)
        return self._find_exact_matches(digital, sample_period)

    def _get_level(self, trace: WaveformTrace) -> float:
        """Return the threshold level used to digitize ``trace``.

        An explicit scalar wins; a list contributes its first element; None
        selects the midpoint between the trace minimum and maximum.
        """
        if isinstance(self.levels, int | float):
            return float(self.levels)
        if self.levels is None:
            return float((np.min(trace.data) + np.max(trace.data)) / 2)
        # Multi-channel style list - use first level for a single trace
        return float(self.levels[0])

    def _find_sequence_matches(
        self,
        digital: NDArray[np.bool_],
        sample_period: float,
    ) -> list[TriggerEvent]:
        """Find every window of consecutive samples that matches the pattern.

        Args:
            digital: Digitized samples; compared as int8 values against the
                0/1 entries of the pattern.
            sample_period: Time between samples.

        Returns:
            One event per matching window, stamped at the window's first
            sample, with ``duration`` spanning the whole pattern.
        """
        pattern_len = len(self.pattern)
        # Same int8 cast the per-window comparison has always used.
        samples = np.asarray(digital).astype(np.int8)
        if pattern_len == 0 or len(samples) < pattern_len:
            return []

        # Positions whose pattern value is None are excluded from matching.
        care = np.array([p is not None for p in self.pattern], dtype=bool)
        expected = np.array([0 if p is None else p for p in self.pattern], dtype=np.int8)

        # Shape (n_windows, pattern_len) strided view over `samples`.
        windows = np.lib.stride_tricks.sliding_window_view(samples, pattern_len)
        hits = np.all((windows == expected) | ~care, axis=1)

        duration = pattern_len * sample_period
        return [
            TriggerEvent(
                timestamp=int(i) * sample_period,
                sample_index=int(i),
                event_type=TriggerType.PATTERN_MATCH,
                duration=duration,
                data={"pattern": self.pattern},
            )
            for i in np.flatnonzero(hits)
        ]

    def _find_exact_matches(
        self,
        digital: NDArray[np.bool_],
        sample_period: float,
    ) -> list[TriggerEvent]:
        """Find samples where the signal enters the first pattern state.

        Args:
            digital: Digitized samples.
            sample_period: Time between samples.

        Returns:
            One event per transition into the expected state. Sample 0 never
            triggers, because an event needs a preceding non-matching sample.
            A don't-care first value yields no events.
        """
        if not self.pattern or self.pattern[0] is None:
            return []

        samples = np.asarray(digital)
        # Fewer than two samples cannot contain a transition.
        if len(samples) < 2:
            return []

        matches = samples == bool(self.pattern[0])
        entries = np.flatnonzero(matches[1:] & ~matches[:-1]) + 1

        return [
            TriggerEvent(
                timestamp=int(i) * sample_period,
                sample_index=int(i),
                event_type=TriggerType.PATTERN_MATCH,
                data={"pattern": self.pattern},
            )
            for i in entries
        ]
class MultiChannelPatternTrigger(Trigger):
    """Pattern trigger for multiple parallel channels.

    Triggers when all channels simultaneously match the specified pattern.
    One event is emitted each time the combined signal enters the matching
    state.

    Example:
        >>> trigger = MultiChannelPatternTrigger(
        ...     pattern=[1, 0, 1, None],  # Ch0=1, Ch1=0, Ch2=1, Ch3=don't care
        ...     levels=[1.5, 1.5, 1.5, 1.5]
        ... )
    """

    def __init__(
        self,
        pattern: list[int | None],
        levels: list[float] | None = None,
    ) -> None:
        """Initialize multi-channel pattern trigger.

        Args:
            pattern: Pattern for each channel (0, 1, or None for don't care).
            levels: Threshold level for each analog channel. If None, each
                analog channel uses the midpoint of its own min and max.

        Raises:
            AnalysisError: If pattern contains invalid values.
        """
        self.pattern = pattern
        self.levels = levels

        # Same validation as PatternTrigger, so both triggers accept the same
        # pattern vocabulary.
        for val in pattern:
            if val is not None and val not in (0, 1):
                raise AnalysisError(f"Pattern values must be 0, 1, or None, got {val}")

    def find_events(
        self,
        traces: list[WaveformTrace | DigitalTrace],  # type: ignore[override]
    ) -> list[TriggerEvent]:
        """Find pattern matches across multiple channels.

        Args:
            traces: List of traces (one per channel). Only the first
                ``min(len)`` samples common to all channels are examined, and
                the sample period is taken from the first trace.

        Returns:
            List of trigger events where the combined pattern starts to
            match. The state before sample 0 counts as "not matching", so a
            pattern already present at sample 0 triggers there.

        Raises:
            AnalysisError: If number of traces doesn't match pattern length.
        """
        if len(traces) != len(self.pattern):
            raise AnalysisError(
                f"Number of traces ({len(traces)}) must match pattern length ({len(self.pattern)})"
            )
        if not traces:
            return []

        # Convert all traces to digital
        digitals: list[NDArray[np.bool_]] = []
        for i, trace in enumerate(traces):
            if isinstance(trace, DigitalTrace):
                digitals.append(np.asarray(trace.data))
            elif self.levels is not None:
                digitals.append(trace.data >= self.levels[i])
            elif len(trace.data) == 0:
                # np.min/np.max reject empty arrays; an empty channel simply
                # yields no samples to compare.
                digitals.append(np.zeros(0, dtype=bool))
            else:
                midpoint = (np.min(trace.data) + np.max(trace.data)) / 2
                digitals.append(trace.data >= midpoint)

        sample_period = traces[0].metadata.time_base
        n_samples = min(len(d) for d in digitals)
        if n_samples == 0:
            return []

        # AND together each constrained channel's per-sample match.
        all_match = np.ones(n_samples, dtype=bool)
        for digital, pattern_val in zip(digitals, self.pattern, strict=True):
            if pattern_val is not None:
                all_match &= np.asarray(digital[:n_samples]) == bool(pattern_val)

        # Rising edges of the combined match, treating "before sample 0" as
        # not matching.
        previous = np.zeros(n_samples, dtype=bool)
        previous[1:] = all_match[:-1]
        starts = np.flatnonzero(all_match & ~previous)

        return [
            TriggerEvent(
                timestamp=int(i) * sample_period,
                sample_index=int(i),
                event_type=TriggerType.PATTERN_MATCH,
                data={"pattern": self.pattern},
            )
            for i in starts
        ]
def find_pattern(
    trace: WaveformTrace | DigitalTrace,
    pattern: list[int | None],
    *,
    level: float | None = None,
    return_indices: bool = False,
) -> NDArray[np.float64] | NDArray[np.int64]:
    """Locate every occurrence of a bit pattern in a trace.

    Thin convenience wrapper around ``PatternTrigger`` in its default
    "sequence" mode.

    Args:
        trace: Input trace. Analog traces are thresholded; digital traces are
            used as-is.
        pattern: Bits to look for; each entry is 0, 1, or None (don't care).
        level: Threshold for digitizing an analog trace. If None, the midpoint
            between the trace minimum and maximum is used.
        return_indices: When True, report sample indices instead of timestamps.

    Returns:
        An ``int64`` array of start indices when ``return_indices`` is True,
        otherwise a ``float64`` array of start timestamps.

    Raises:
        AnalysisError: Propagated from ``PatternTrigger`` for invalid patterns.

    Example:
        >>> # Find every place the signal reads low, high, high
        >>> starts = find_pattern(trace, [0, 1, 1])
    """
    matches = PatternTrigger(pattern=pattern, levels=level).find_events(trace)
    if not return_indices:
        return np.array([m.timestamp for m in matches], dtype=np.float64)
    return np.array([m.sample_index for m in matches], dtype=np.int64)
def find_bit_sequence(
    trace: WaveformTrace,
    bits: str,
    *,
    level: float | None = None,
) -> list[TriggerEvent]:
    """Find a specific bit sequence in a trace.

    Args:
        trace: Input waveform trace.
        bits: Bit string (e.g., "10101010", "1X0X" where X or x is don't care).
        level: Threshold level for digitization. If None, the midpoint between
            the trace minimum and maximum is used.

    Returns:
        List of trigger events for each match.

    Raises:
        AnalysisError: If ``bits`` is empty or contains a character other than
            0, 1, X or x.

    Example:
        >>> events = find_bit_sequence(trace, "10110")
        >>> events = find_bit_sequence(trace, "1XX0")  # X = don't care
    """
    # An empty pattern would otherwise "match" at every sample position.
    if not bits:
        raise AnalysisError("Bit string must not be empty")

    # Map each (upper-cased) symbol to its pattern value; None = don't care.
    symbol_values: dict[str, int | None] = {"0": 0, "1": 1, "X": None}

    pattern: list[int | None] = []
    for char in bits:
        key = char.upper()
        if key not in symbol_values:
            raise AnalysisError(f"Invalid bit character: {char}")
        pattern.append(symbol_values[key])

    trigger = PatternTrigger(pattern=pattern, levels=level)
    return trigger.find_events(trace)
# Names exported by ``from tracekit.triggering.pattern import *``.
__all__ = ["MultiChannelPatternTrigger", "PatternTrigger", "find_bit_sequence", "find_pattern"]