Coverage for src/tracekit/filtering/base.py: 99%
207 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""Base filter classes for the TraceKit filtering module.

Provides abstract base classes for IIR and FIR filter implementations
with a common interface for filter application and introspection, plus
the FilterResult container returned when introspection details are
requested.
"""
7from __future__ import annotations
9from abc import ABC, abstractmethod
10from dataclasses import dataclass
11from typing import TYPE_CHECKING, Literal
13import numpy as np
14from scipy import signal
16from tracekit.core.exceptions import AnalysisError
17from tracekit.core.types import WaveformTrace
19if TYPE_CHECKING:
20 from numpy.typing import NDArray
@dataclass()
class FilterResult:
    """Bundle of a filtered trace and optional filter introspection data.

    Only ``trace`` is mandatory; the remaining fields default to None and
    are populated when a filter is applied with ``return_details=True``.

    Attributes:
        trace: The filtered waveform trace.
        transfer_function: Complex frequency response H(f), if computed.
        impulse_response: Impulse response h[n], if computed.
        group_delay: Group delay in samples, if computed.
    """

    # Required output of the filtering operation.
    trace: WaveformTrace
    # Optional introspection arrays (None unless explicitly requested).
    transfer_function: NDArray[np.complex128] | None = None
    impulse_response: NDArray[np.float64] | None = None
    group_delay: NDArray[np.float64] | None = None
class Filter(ABC):
    """Abstract base class for all filters.

    Defines the common interface for filter application and introspection.
    All filter implementations must inherit from this class.

    Attributes:
        sample_rate: Sample rate in Hz for digital filter design.
        is_stable: Whether the filter is stable (for IIR filters).
        order: Filter order.
    """

    def __init__(self, sample_rate: float | None = None) -> None:
        """Initialize filter.

        Args:
            sample_rate: Sample rate in Hz. May be None initially; it must be
                set (e.g. via the ``sample_rate`` property) before calling
                Hz-based helpers such as ``get_transfer_function``.
        """
        self._sample_rate = sample_rate
        # Tracks whether coefficients are available; the sample_rate setter
        # resets it so a filter is flagged for redesign after a rate change.
        self._is_designed = False

    @property
    def sample_rate(self) -> float | None:
        """Sample rate in Hz (None if not yet set)."""
        return self._sample_rate

    @sample_rate.setter
    def sample_rate(self, value: float) -> None:
        """Set sample rate and mark filter for redesign if the value changed."""
        if value != self._sample_rate:
            self._sample_rate = value
            self._is_designed = False

    @property
    @abstractmethod
    def is_stable(self) -> bool:
        """Check if filter is stable."""
        ...

    @property
    @abstractmethod
    def order(self) -> int:
        """Filter order."""
        ...

    @abstractmethod
    def apply(
        self,
        trace: WaveformTrace,
        *,
        return_details: bool = False,
    ) -> WaveformTrace | FilterResult:
        """Apply filter to a waveform trace.

        Args:
            trace: Input waveform trace.
            return_details: If True, return FilterResult with introspection data.

        Returns:
            Filtered trace, or FilterResult if return_details=True.
        """
        ...

    @abstractmethod
    def get_frequency_response(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.complex128]]:
        """Get frequency response of the filter.

        Args:
            worN: Frequencies at which to evaluate. If int, that many
                frequencies from 0 up to pi (Nyquist). If array, specific
                normalized angular frequencies in radians/sample (pi is
                Nyquist). If None, uses 512 points.

        Returns:
            Tuple of (frequencies in radians/sample, complex response H).
        """
        ...

    @abstractmethod
    def get_impulse_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get impulse response of the filter.

        Args:
            n_samples: Number of samples in impulse response.

        Returns:
            Impulse response h[n].
        """
        ...

    @abstractmethod
    def get_step_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get step response of the filter.

        Args:
            n_samples: Number of samples in step response.

        Returns:
            Step response s[n].
        """
        ...

    def get_transfer_function(
        self,
        freqs: NDArray[np.float64] | None = None,
    ) -> NDArray[np.complex128]:
        """Get transfer function H(f) at specified frequencies.

        Args:
            freqs: Frequencies in Hz (any array-like). If None, uses 512
                points from 0 to Nyquist.

        Returns:
            Complex transfer function values.

        Raises:
            AnalysisError: If sample rate is not set.
        """
        if self._sample_rate is None:
            raise AnalysisError("Sample rate must be set to compute transfer function")

        if freqs is None:
            freqs_hz = np.linspace(0, self._sample_rate / 2, 512)
        else:
            # Accept lists/tuples as well as ndarrays.
            freqs_hz = np.asarray(freqs, dtype=np.float64)

        # Hz -> normalized angular frequency in radians/sample (pi == Nyquist).
        w = 2 * np.pi * freqs_hz / self._sample_rate
        _, h = self.get_frequency_response(w)
        return h

    def get_group_delay(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get group delay of the filter.

        Default implementation: numerically differentiates the unwrapped
        phase of ``get_frequency_response``. Subclasses may override with
        an exact method.

        Args:
            worN: Frequencies at which to evaluate. If int, that many
                frequencies. If None, uses 512 points.

        Returns:
            Tuple of (frequencies, group delay in samples). Fewer than two
            frequency points yield an all-zero delay (no derivative exists).
        """
        if worN is None:
            worN = 512
        w, h = self.get_frequency_response(worN)
        if w.size < 2:
            return w, np.zeros_like(w, dtype=np.float64)

        # Unwrap so 2*pi jumps in angle() don't appear as derivative spikes.
        phase = np.unwrap(np.angle(h))
        # Group delay = -d(phase)/d(omega). np.gradient uses central
        # differences in the interior (aligned with each w[i], unlike a
        # forward difference which is biased by half a step) and one-sided
        # differences at the ends; passing ``w`` handles non-uniform grids.
        gd = -np.gradient(phase, w)
        return w, np.asarray(gd, dtype=np.float64)
class IIRFilter(Filter):
    """Infinite Impulse Response filter base class.

    Stores filter coefficients in Second-Order Sections (SOS) format
    for numerical stability, with optional B/A polynomial format. When both
    forms are supplied, the SOS coefficients are used for filtering,
    responses, order and pole/zero analysis.

    Attributes:
        sos: Second-order sections coefficients (preferred format).
        ba: Numerator/denominator polynomial coefficients (optional).
    """

    def __init__(
        self,
        sample_rate: float | None = None,
        sos: NDArray[np.float64] | None = None,
        ba: tuple[NDArray[np.float64], NDArray[np.float64]] | None = None,
    ) -> None:
        """Initialize IIR filter.

        Args:
            sample_rate: Sample rate in Hz.
            sos: Second-order sections array (n_sections, 6).
            ba: Tuple of (b, a) polynomial coefficients.
        """
        super().__init__(sample_rate)
        self._sos = sos
        self._ba = ba
        # Supplying either coefficient form counts as "designed".
        self._is_designed = sos is not None or ba is not None

    @property
    def sos(self) -> NDArray[np.float64] | None:
        """Second-order sections coefficients (None if not provided)."""
        return self._sos

    @property
    def ba(self) -> tuple[NDArray[np.float64], NDArray[np.float64]] | None:
        """B/A polynomial coefficients.

        Returns the stored (b, a) pair if one was given, otherwise converts
        the SOS coefficients with ``scipy.signal.sos2tf`` (recomputed on
        every access). Returns None when no coefficients are available.
        """
        if self._ba is not None:
            return self._ba
        if self._sos is not None:
            b, a = signal.sos2tf(self._sos)
            return (b, a)
        return None

    @property
    def is_stable(self) -> bool:
        """Check if filter is stable (all poles strictly inside unit circle).

        A filter without coefficients is reported as stable.
        """
        if self._sos is None and self._ba is None:
            return True  # Not designed yet
        return bool(np.all(np.abs(self.poles) < 1.0))

    @staticmethod
    def _poly_degree(coeffs: NDArray[np.float64]) -> int:
        """Degree of a z^-1 polynomial, ignoring trailing zero coefficients."""
        nonzero = np.flatnonzero(np.asarray(coeffs))
        return int(nonzero[-1]) if nonzero.size else 0

    @property
    def order(self) -> int:
        """Filter order: max(numerator degree, denominator degree) in z^-1.

        For SOS, each section contributes its actual degree, so a padded
        first-order section (zero highest-order coefficients) counts as 1
        rather than 2. Returns 0 when no coefficients are available.
        """
        if self._sos is not None:
            sections = np.atleast_2d(np.asarray(self._sos))
            num_degree = sum(self._poly_degree(sec[:3]) for sec in sections)
            den_degree = sum(self._poly_degree(sec[3:]) for sec in sections)
            return max(num_degree, den_degree)
        if self._ba is not None:
            b, a = self._ba
            return max(self._poly_degree(b), self._poly_degree(a))
        return 0

    @property
    def poles(self) -> NDArray[np.complex128]:
        """Filter poles in z-domain (empty if not designed)."""
        if self._sos is not None:
            # Factor each biquad separately instead of expanding the whole
            # cascade into one polynomial (sos2tf) and calling np.roots,
            # which is ill-conditioned for high-order filters.
            _, p, _ = signal.sos2zpk(self._sos)
            return np.asarray(p, dtype=np.complex128)
        if self._ba is not None:
            return np.roots(self._ba[1]).astype(np.complex128)
        return np.array([], dtype=np.complex128)

    @property
    def zeros(self) -> NDArray[np.complex128]:
        """Filter zeros in z-domain (empty if not designed)."""
        if self._sos is not None:
            # Per-section factorization; see ``poles`` for rationale.
            z, _, _ = signal.sos2zpk(self._sos)
            return np.asarray(z, dtype=np.complex128)
        if self._ba is not None:
            return np.roots(self._ba[0]).astype(np.complex128)
        return np.array([], dtype=np.complex128)

    def apply(
        self,
        trace: WaveformTrace,
        *,
        return_details: bool = False,
        filtfilt: bool = True,
    ) -> WaveformTrace | FilterResult:
        """Apply IIR filter to waveform.

        Args:
            trace: Input waveform trace.
            return_details: If True, return FilterResult with introspection data.
            filtfilt: If True, use zero-phase filtering (forward-backward).
                If False, use causal filtering.

        Returns:
            Filtered trace, or FilterResult if return_details=True.

        Raises:
            AnalysisError: If filter not designed or is unstable.
        """
        if self._sos is None and self._ba is None:
            raise AnalysisError("Filter not designed - no coefficients available")

        if not self.is_stable:
            raise AnalysisError("Cannot apply unstable filter")

        # SOS is preferred when available (better numerical behaviour).
        if self._sos is not None:
            if filtfilt:
                filtered_data = signal.sosfiltfilt(self._sos, trace.data)
            else:
                filtered_data = signal.sosfilt(self._sos, trace.data)
        else:
            b, a = self._ba  # type: ignore[misc]
            if filtfilt:
                filtered_data = signal.filtfilt(b, a, trace.data)
            else:
                filtered_data = signal.lfilter(b, a, trace.data)

        filtered_trace = WaveformTrace(
            data=filtered_data.astype(np.float64),
            metadata=trace.metadata,
        )

        if return_details:
            _w, h = self.get_frequency_response()
            impulse = self.get_impulse_response()
            _, gd = self.get_group_delay()
            return FilterResult(
                trace=filtered_trace,
                transfer_function=h,
                impulse_response=impulse,
                group_delay=gd,
            )

        return filtered_trace

    def get_frequency_response(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.complex128]]:
        """Get frequency response via scipy ``sosfreqz``/``freqz``.

        Args:
            worN: Number of points or array of frequencies, forwarded to
                scipy. Defaults to 512 points.

        Returns:
            Tuple of (frequencies, complex response H).

        Raises:
            AnalysisError: If filter not designed.
        """
        if worN is None:
            worN = 512

        if self._sos is not None:
            w, h = signal.sosfreqz(self._sos, worN=worN)
        elif self._ba is not None:
            w, h = signal.freqz(self._ba[0], self._ba[1], worN=worN)
        else:
            raise AnalysisError("Filter not designed")

        return w.astype(np.float64), h.astype(np.complex128)

    def get_impulse_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get causal impulse response (filter applied to a unit impulse).

        Raises:
            AnalysisError: If filter not designed.
        """
        impulse = np.zeros(n_samples)
        impulse[0] = 1.0

        response: NDArray[np.float64]
        if self._sos is not None:
            response = signal.sosfilt(self._sos, impulse).astype(np.float64)
        elif self._ba is not None:
            response = signal.lfilter(self._ba[0], self._ba[1], impulse).astype(np.float64)
        else:
            raise AnalysisError("Filter not designed")

        return response

    def get_step_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get causal step response (filter applied to a unit step).

        Raises:
            AnalysisError: If filter not designed.
        """
        step = np.ones(n_samples)

        response: NDArray[np.float64]
        if self._sos is not None:
            response = signal.sosfilt(self._sos, step).astype(np.float64)
        elif self._ba is not None:
            response = signal.lfilter(self._ba[0], self._ba[1], step).astype(np.float64)
        else:
            raise AnalysisError("Filter not designed")

        return response
class FIRFilter(Filter):
    """Finite Impulse Response filter base class.

    Stores filter coefficients as a single array of tap weights.
    FIR filters are always stable and can achieve linear phase.

    Attributes:
        coeffs: Filter tap coefficients.
    """

    def __init__(
        self,
        sample_rate: float | None = None,
        coeffs: NDArray[np.float64] | None = None,
    ) -> None:
        """Initialize FIR filter.

        Args:
            sample_rate: Sample rate in Hz.
            coeffs: Filter coefficients (tap weights).
        """
        super().__init__(sample_rate)
        self._coeffs = coeffs
        # Supplying coefficients up front counts as "designed".
        self._is_designed = coeffs is not None

    @property
    def coeffs(self) -> NDArray[np.float64] | None:
        """Filter coefficients (None if not designed)."""
        return self._coeffs

    @coeffs.setter
    def coeffs(self, value: NDArray[np.float64]) -> None:
        """Set filter coefficients and mark the filter as designed."""
        self._coeffs = value
        self._is_designed = True

    @property
    def is_stable(self) -> bool:
        """FIR filters are always stable."""
        return True

    @property
    def order(self) -> int:
        """Filter order (number of taps - 1), or 0 if not designed."""
        if self._coeffs is not None:
            return len(self._coeffs) - 1
        return 0

    @property
    def is_linear_phase(self) -> bool:
        """Check if filter has linear phase (symmetric or antisymmetric coefficients)."""
        if self._coeffs is None:
            return False
        mirrored = self._coeffs[::-1]
        symmetric = np.allclose(self._coeffs, mirrored)
        antisymmetric = np.allclose(self._coeffs, -mirrored)
        return bool(symmetric or antisymmetric)

    def apply(
        self,
        trace: WaveformTrace,
        *,
        return_details: bool = False,
        mode: Literal["full", "same", "valid"] = "same",
    ) -> WaveformTrace | FilterResult:
        """Apply FIR filter to waveform.

        Args:
            trace: Input waveform trace.
            return_details: If True, return FilterResult with introspection data.
            mode: Convolution mode. "same" always returns exactly
                ``len(trace.data)`` samples (the centred part of the full
                convolution); "full" and "valid" follow ``np.convolve``.

        Returns:
            Filtered trace, or FilterResult if return_details=True.

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed - no coefficients available")

        if mode == "same":
            # np.convolve(..., "same") returns max(len(data), len(coeffs))
            # samples, so a trace shorter than the tap count would *grow*.
            # Slice the centred len(data) window from the full convolution
            # instead; for len(data) >= len(coeffs) this equals numpy's
            # "same" output exactly.
            n_data = len(trace.data)
            full = np.convolve(trace.data, self._coeffs, mode="full")
            start = (len(self._coeffs) - 1) // 2
            filtered_data = full[start : start + n_data]
        else:
            filtered_data = np.convolve(trace.data, self._coeffs, mode=mode)

        filtered_trace = WaveformTrace(
            data=filtered_data.astype(np.float64),
            metadata=trace.metadata,
        )

        if return_details:
            _w, h = self.get_frequency_response()
            impulse = self.get_impulse_response()
            _, gd = self.get_group_delay()
            return FilterResult(
                trace=filtered_trace,
                transfer_function=h,
                impulse_response=impulse,
                group_delay=gd,
            )

        return filtered_trace

    def get_frequency_response(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.complex128]]:
        """Get frequency response via scipy ``freqz`` (denominator 1).

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed")

        if worN is None:
            worN = 512

        w, h = signal.freqz(self._coeffs, 1, worN=worN)
        return w.astype(np.float64), h.astype(np.complex128)

    def get_impulse_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get impulse response (the coefficients, truncated or zero-padded).

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed")

        if len(self._coeffs) >= n_samples:
            return self._coeffs[:n_samples].astype(np.float64)

        response = np.zeros(n_samples)
        response[: len(self._coeffs)] = self._coeffs
        return response.astype(np.float64)

    def get_step_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get causal step response (first n_samples of step * coeffs).

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed")

        step = np.ones(n_samples)
        response = np.convolve(step, self._coeffs, mode="full")[:n_samples]
        return response.astype(np.float64)

    def get_group_delay(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get group delay via ``scipy.signal.group_delay``.

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed")

        if worN is None:
            worN = 512

        w, gd = signal.group_delay((self._coeffs, 1), w=worN)
        return w.astype(np.float64), gd.astype(np.float64)
# Public API of this module (ASCII-sorted).
__all__ = ["FIRFilter", "Filter", "FilterResult", "IIRFilter"]