Coverage for src / tracekit / filtering / convenience.py: 93%
106 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""Convenience filtering functions for TraceKit.

Provides simple one-call filter functions for common operations like
moving average, median filter, Savitzky-Golay smoothing, and matched
filtering.

Example:
    >>> from tracekit.filtering.convenience import low_pass, moving_average
    >>> filtered = low_pass(trace, cutoff=1e6)
    >>> smoothed = moving_average(trace, window_size=11)
"""
13from __future__ import annotations
15from typing import TYPE_CHECKING, Any, Literal
17import numpy as np
18from scipy import ndimage, signal
20from tracekit.core.exceptions import AnalysisError
21from tracekit.core.types import WaveformTrace
22from tracekit.filtering.design import (
23 BandPassFilter,
24 BandStopFilter,
25 HighPassFilter,
26 LowPassFilter,
27)
29if TYPE_CHECKING:
30 from numpy.typing import NDArray
def low_pass(
    trace: WaveformTrace,
    cutoff: float,
    *,
    order: int = 4,
    filter_type: Literal[
        "butterworth", "chebyshev1", "chebyshev2", "bessel", "elliptic"
    ] = "butterworth",
) -> WaveformTrace:
    """Low-pass filter a trace in a single call.

    Builds a ``LowPassFilter`` from the given parameters and the trace's own
    sample rate, then applies it to the trace.

    Args:
        trace: Waveform to filter.
        cutoff: Cutoff frequency in Hz.
        order: Filter order (default 4).
        filter_type: Filter family to design (default "butterworth").

    Returns:
        The filtered waveform trace.

    Example:
        >>> filtered = low_pass(trace, cutoff=1e6)
    """
    design = LowPassFilter(
        cutoff=cutoff,
        sample_rate=trace.metadata.sample_rate,
        order=order,
        filter_type=filter_type,
    )
    outcome = design.apply(trace)
    # ``apply`` may hand back something other than a bare trace; in that case
    # the trace is read from the result's ``trace`` attribute.
    return outcome if isinstance(outcome, WaveformTrace) else outcome.trace
def high_pass(
    trace: WaveformTrace,
    cutoff: float,
    *,
    order: int = 4,
    filter_type: Literal[
        "butterworth", "chebyshev1", "chebyshev2", "bessel", "elliptic"
    ] = "butterworth",
) -> WaveformTrace:
    """High-pass filter a trace in a single call.

    Builds a ``HighPassFilter`` from the given parameters and the trace's own
    sample rate, then applies it to the trace.

    Args:
        trace: Waveform to filter.
        cutoff: Cutoff frequency in Hz.
        order: Filter order (default 4).
        filter_type: Filter family to design (default "butterworth").

    Returns:
        The filtered waveform trace.

    Example:
        >>> filtered = high_pass(trace, cutoff=100)  # Remove DC and low frequencies
    """
    design = HighPassFilter(
        cutoff=cutoff,
        sample_rate=trace.metadata.sample_rate,
        order=order,
        filter_type=filter_type,
    )
    outcome = design.apply(trace)
    # ``apply`` may hand back something other than a bare trace; in that case
    # the trace is read from the result's ``trace`` attribute.
    return outcome if isinstance(outcome, WaveformTrace) else outcome.trace
def band_pass(
    trace: WaveformTrace,
    low: float,
    high: float,
    *,
    order: int = 4,
    filter_type: Literal[
        "butterworth", "chebyshev1", "chebyshev2", "bessel", "elliptic"
    ] = "butterworth",
) -> WaveformTrace:
    """Band-pass filter a trace in a single call.

    Builds a ``BandPassFilter`` from the given band edges and the trace's own
    sample rate, then applies it to the trace.

    Args:
        trace: Waveform to filter.
        low: Lower cutoff frequency in Hz.
        high: Upper cutoff frequency in Hz.
        order: Filter order (default 4).
        filter_type: Filter family to design (default "butterworth").

    Returns:
        The filtered waveform trace.

    Example:
        >>> filtered = band_pass(trace, low=1e3, high=10e3)
    """
    design = BandPassFilter(
        low=low,
        high=high,
        sample_rate=trace.metadata.sample_rate,
        order=order,
        filter_type=filter_type,
    )
    outcome = design.apply(trace)
    # ``apply`` may hand back something other than a bare trace; in that case
    # the trace is read from the result's ``trace`` attribute.
    return outcome if isinstance(outcome, WaveformTrace) else outcome.trace
def band_stop(
    trace: WaveformTrace,
    low: float,
    high: float,
    *,
    order: int = 4,
    filter_type: Literal[
        "butterworth", "chebyshev1", "chebyshev2", "bessel", "elliptic"
    ] = "butterworth",
) -> WaveformTrace:
    """Band-stop (notch) filter a trace in a single call.

    Builds a ``BandStopFilter`` from the given band edges and the trace's own
    sample rate, then applies it to the trace.

    Args:
        trace: Waveform to filter.
        low: Lower cutoff frequency in Hz.
        high: Upper cutoff frequency in Hz.
        order: Filter order (default 4).
        filter_type: Filter family to design (default "butterworth").

    Returns:
        The filtered waveform trace.

    Example:
        >>> filtered = band_stop(trace, low=59, high=61)  # Remove 60 Hz
    """
    design = BandStopFilter(
        low=low,
        high=high,
        sample_rate=trace.metadata.sample_rate,
        order=order,
        filter_type=filter_type,
    )
    outcome = design.apply(trace)
    # ``apply`` may hand back something other than a bare trace; in that case
    # the trace is read from the result's ``trace`` attribute.
    return outcome if isinstance(outcome, WaveformTrace) else outcome.trace
def notch_filter(
    trace: WaveformTrace,
    freq: float,
    *,
    q_factor: float = 30.0,
) -> WaveformTrace:
    """Apply narrow notch filter at specified frequency.

    Uses a 4th-order band-stop Butterworth filter, applied forwards and
    backwards (zero phase), whose bandwidth is determined by the Q factor:
    Bandwidth (Hz) = freq / Q.

    The stop band is clamped to stay at or above 0.1 Hz and at least 1 Hz
    below Nyquist.

    Args:
        trace: Input waveform trace.
        freq: Center frequency to notch out in Hz. Must be positive and below
            the Nyquist frequency.
        q_factor: Quality factor (higher = narrower notch). Must be positive.
            Default 30.

    Returns:
        Filtered waveform trace carrying the input trace's metadata.

    Raises:
        AnalysisError: If the notch frequency is not positive or is at or
            above the Nyquist frequency, if ``q_factor`` is not positive, or
            if the clamped stop band is empty.

    Example:
        >>> filtered = notch_filter(trace, freq=60, q_factor=30)  # Remove 60 Hz line noise
    """
    sample_rate = trace.metadata.sample_rate
    nyquist = sample_rate / 2

    if freq >= nyquist:
        raise AnalysisError(
            f"Notch frequency {freq} Hz must be less than Nyquist {nyquist} Hz"
        )
    if freq <= 0:
        raise AnalysisError(f"Notch frequency must be positive, got {freq} Hz")
    # A zero Q would divide by zero below and a negative Q would invert the
    # band edges, so both are rejected up front with a clear message.
    if q_factor <= 0:
        raise AnalysisError(f"Q factor must be positive, got {q_factor}")

    # Calculate bandwidth from Q factor: BW = f0 / Q
    bandwidth = freq / q_factor

    # Stop-band edges centred on freq, clamped away from DC and Nyquist.
    low = max(freq - bandwidth / 2, 0.1)  # Avoid zero frequency
    high = min(freq + bandwidth / 2, nyquist - 1)  # Stay below Nyquist

    # The clamps can collapse or invert the band (very low notch frequency,
    # or a notch within 1 Hz of Nyquist). ``not low < high`` also catches NaN.
    if not low < high:
        raise AnalysisError(
            f"Notch band [{low}, {high}] Hz is empty for freq={freq} Hz, "
            f"q_factor={q_factor} at sample rate {sample_rate} Hz"
        )

    # Band edges normalised to Nyquist for the filter design.
    wn = [low / nyquist, high / nyquist]

    # Second-order sections keep the narrow band-stop design numerically stable.
    sos = signal.butter(4, wn, btype="bandstop", output="sos")

    # Forward-backward filtering gives zero phase distortion.
    filtered_data = signal.sosfiltfilt(sos, trace.data)

    return WaveformTrace(
        data=filtered_data.astype(np.float64),
        metadata=trace.metadata,
    )
def moving_average(
    trace: WaveformTrace,
    window_size: int,
    *,
    mode: Literal["same", "valid", "full"] = "same",
) -> WaveformTrace:
    """Smooth a trace with a uniform-weight (boxcar) moving average.

    The trace data is convolved with ``window_size`` equal weights that sum
    to one.

    Args:
        trace: Input waveform trace.
        window_size: Number of samples in the averaging window. An odd size
            is recommended for "same" mode; oddness is not checked here.
        mode: Convolution mode passed to ``np.convolve``; "same" preserves
            the input length.

    Returns:
        Filtered waveform trace carrying the input trace's metadata.

    Raises:
        AnalysisError: If window_size is not positive or exceeds data length.

    Example:
        >>> smoothed = moving_average(trace, window_size=11)
    """
    if window_size < 1:
        raise AnalysisError(f"Window size must be positive, got {window_size}")

    n_samples = len(trace.data)
    if window_size > n_samples:
        raise AnalysisError(f"Window size {window_size} exceeds data length {n_samples}")

    # Every tap carries the same weight, 1 / window_size.
    weights = np.full(window_size, 1.0 / window_size)
    averaged = np.convolve(trace.data, weights, mode=mode)

    return WaveformTrace(data=averaged.astype(np.float64), metadata=trace.metadata)
def median_filter(
    trace: WaveformTrace,
    kernel_size: int,
) -> WaveformTrace:
    """Remove spike/impulse noise with a sliding median.

    A non-linear filter: each output sample is the median of the
    ``kernel_size`` samples around it, computed by
    ``scipy.ndimage.median_filter``.

    Args:
        trace: Input waveform trace.
        kernel_size: Size of the median window in samples (must be odd).

    Returns:
        Filtered waveform trace carrying the input trace's metadata.

    Raises:
        AnalysisError: If kernel_size is not positive or not odd.

    Example:
        >>> cleaned = median_filter(trace, kernel_size=5)  # Remove impulse noise
    """
    if kernel_size < 1:
        raise AnalysisError(f"Kernel size must be positive, got {kernel_size}")

    is_even = kernel_size % 2 == 0
    if is_even:
        raise AnalysisError(f"Kernel size must be odd, got {kernel_size}")

    despiked = ndimage.median_filter(trace.data, size=kernel_size)

    return WaveformTrace(data=despiked.astype(np.float64), metadata=trace.metadata)
def savgol_filter(
    trace: WaveformTrace,
    window_length: int,
    polyorder: int,
    *,
    deriv: int = 0,
) -> WaveformTrace:
    """Apply Savitzky-Golay smoothing filter.

    Smooths data while preserving higher moments (peaks, etc.) better
    than simple moving average.

    Args:
        trace: Input waveform trace.
        window_length: Length of filter window. Must be a positive odd number,
            greater than ``polyorder`` and no longer than the trace data.
        polyorder: Order of polynomial used in fitting (non-negative).
        deriv: Derivative order (0 = smoothing, 1 = first derivative, etc.).

    Returns:
        Filtered waveform trace carrying the input trace's metadata.

    Raises:
        AnalysisError: If window_length is not a positive odd number, if
            polyorder is negative or not less than window_length, or if
            window_length exceeds the data length.

    Example:
        >>> smoothed = savgol_filter(trace, window_length=11, polyorder=3)
    """
    if window_length % 2 == 0:
        raise AnalysisError(f"Window length must be odd, got {window_length}")

    # Negative odd lengths pass the parity test above, so reject them here
    # instead of reporting a misleading polynomial-order error.
    if window_length < 1:
        raise AnalysisError(f"Window length must be positive, got {window_length}")

    # Reject a negative order here rather than leaving it to fail inside SciPy.
    if polyorder < 0:
        raise AnalysisError(f"Polynomial order must be non-negative, got {polyorder}")

    if polyorder >= window_length:
        raise AnalysisError(
            f"Polynomial order {polyorder} must be less than window length {window_length}"
        )

    # SciPy's default edge handling (mode="interp") requires the window to fit
    # inside the data; check it here, as the other window-based filters do.
    n_samples = len(trace.data)
    if window_length > n_samples:
        raise AnalysisError(f"Window length {window_length} exceeds data length {n_samples}")

    filtered_data = signal.savgol_filter(trace.data, window_length, polyorder, deriv=deriv)

    return WaveformTrace(
        data=filtered_data.astype(np.float64),
        metadata=trace.metadata,
    )
def matched_filter(
    trace: WaveformTrace,
    template: NDArray[np.floating[Any]],
    *,
    normalize: bool = True,
) -> WaveformTrace:
    """Run a matched filter over a trace to locate a known pulse shape.

    The trace is convolved with the time-reversed template, which is the
    same as cross-correlating the trace with the template itself.

    Args:
        trace: Input waveform trace.
        template: Template pulse to match.
        normalize: If True, scale the template to unit energy first. An
            all-zero template is left unscaled.

    Returns:
        Matched filter output trace (same length as the input, since
        templates longer than the data are rejected). Peaks indicate
        template matches.

    Raises:
        AnalysisError: If template is empty or exceeds data length.

    Example:
        >>> # Detect a specific pulse shape
        >>> pulse_template = np.array([0, 0.5, 1.0, 0.5, 0])
        >>> match_output = matched_filter(trace, pulse_template)
        >>> # Find peaks in match_output for detection
    """
    n_template = len(template)
    if n_template == 0:
        raise AnalysisError("Template cannot be empty")

    n_data = len(trace.data)
    if n_template > n_data:
        raise AnalysisError(
            f"Template length {n_template} exceeds data length {n_data}"
        )

    # Impulse response of the matched filter: the template reversed in time.
    kernel = template[::-1].copy()

    if normalize:
        energy = np.sum(kernel**2)
        # Skip scaling when the energy is zero to avoid dividing by zero.
        if energy > 0:
            kernel = kernel / np.sqrt(energy)

    # Convolving with the reversed template equals correlating with the template.
    response = np.convolve(trace.data, kernel, mode="same")

    return WaveformTrace(data=response.astype(np.float64), metadata=trace.metadata)
def exponential_moving_average(
    trace: WaveformTrace,
    alpha: float,
) -> WaveformTrace:
    """Smooth a trace with an exponential moving average (EMA).

    Implemented as a first-order IIR filter:
    ``y[n] = alpha * x[n] + (1 - alpha) * y[n-1]``.

    Args:
        trace: Input waveform trace.
        alpha: Smoothing factor (0 < alpha <= 1). Higher = less smoothing.

    Returns:
        Filtered waveform trace carrying the input trace's metadata.

    Raises:
        AnalysisError: If alpha is not in range (0, 1].

    Example:
        >>> smoothed = exponential_moving_average(trace, alpha=0.1)
    """
    alpha_in_range = 0 < alpha <= 1
    if not alpha_in_range:
        raise AnalysisError(f"Alpha must be in (0, 1], got {alpha}")

    # Transfer function H(z) = alpha / (1 - (1 - alpha) * z^-1):
    # one feed-forward tap and one feedback tap.
    numerator = np.array([alpha])
    denominator = np.array([1.0, -(1 - alpha)])

    smoothed = signal.lfilter(numerator, denominator, trace.data)

    return WaveformTrace(data=smoothed.astype(np.float64), metadata=trace.metadata)
def gaussian_filter(
    trace: WaveformTrace,
    sigma: float,
) -> WaveformTrace:
    """Smooth a trace with a Gaussian kernel.

    Delegates to ``scipy.ndimage.gaussian_filter1d``.

    Args:
        trace: Input waveform trace.
        sigma: Standard deviation of Gaussian kernel in samples.

    Returns:
        Filtered waveform trace carrying the input trace's metadata.

    Raises:
        AnalysisError: If sigma is not positive.

    Example:
        >>> smoothed = gaussian_filter(trace, sigma=3.0)
    """
    if sigma <= 0:
        raise AnalysisError(f"Sigma must be positive, got {sigma}")

    blurred = ndimage.gaussian_filter1d(trace.data, sigma)
    as_float64 = blurred.astype(np.float64)

    return WaveformTrace(data=as_float64, metadata=trace.metadata)
def differentiate(
    trace: WaveformTrace,
    *,
    order: int = 1,
) -> WaveformTrace:
    """Numerically differentiate a trace with respect to time.

    Applies ``np.gradient`` ``order`` times, using the trace's ``time_base``
    as the sample spacing.

    Args:
        trace: Input waveform trace.
        order: Derivative order (1 = first derivative, 2 = second, etc.).

    Returns:
        Differentiated waveform trace. Units change (V -> V/s, etc.); the
        metadata is passed through unchanged.

    Raises:
        AnalysisError: If order is not positive.

    Example:
        >>> velocity = differentiate(position_trace)
        >>> acceleration = differentiate(position_trace, order=2)
    """
    if order < 1:
        raise AnalysisError(f"Derivative order must be positive, got {order}")

    dt = trace.metadata.time_base
    derivative = trace.data.copy()

    # Higher orders are obtained by differentiating the previous result again.
    passes_left = order
    while passes_left > 0:
        derivative = np.gradient(derivative, dt)
        passes_left -= 1

    return WaveformTrace(data=derivative.astype(np.float64), metadata=trace.metadata)
def integrate(
    trace: WaveformTrace,
    *,
    method: Literal["cumtrapz", "cumsum"] = "cumtrapz",
    initial: float = 0.0,
) -> WaveformTrace:
    """Compute numerical integral of trace.

    The trace's ``time_base`` is used as the sample spacing. ``initial`` is
    the integration constant and is added to every output sample in both
    methods.

    Args:
        trace: Input waveform trace.
        method: Integration method - "cumtrapz" (cumulative trapezoidal rule,
            first output sample equals ``initial``) or "cumsum" (running sum
            scaled by the sample period).
        initial: Initial value at t=0 (constant of integration).

    Returns:
        Integrated waveform trace with the same number of samples as the
        input. Units change (V -> V*s, etc.); the metadata is passed through
        unchanged.

    Raises:
        AnalysisError: If method is not recognized.

    Example:
        >>> position = integrate(velocity_trace)
    """
    sample_period = trace.metadata.time_base

    if method == "cumtrapz":
        from scipy.integrate import cumulative_trapezoid

        # SciPy's own ``initial`` only prepends a value to the result (and
        # newer releases accept nothing but 0/None there). Integrate from
        # zero and add the constant afterwards so the whole curve is offset,
        # matching the "cumsum" branch.
        result = cumulative_trapezoid(trace.data, dx=sample_period, initial=0) + initial
    elif method == "cumsum":
        result = np.cumsum(trace.data) * sample_period + initial
    else:
        raise AnalysisError(f"Unknown integration method: {method}")

    return WaveformTrace(
        data=result.astype(np.float64),
        metadata=trace.metadata,
    )
# Public API of this module: the names exported by a star-import,
# kept in alphabetical order.
__all__ = [
    "band_pass",
    "band_stop",
    "differentiate",
    "exponential_moving_average",
    "gaussian_filter",
    "high_pass",
    "integrate",
    "low_pass",
    "matched_filter",
    "median_filter",
    "moving_average",
    "notch_filter",
    "savgol_filter",
]