Coverage for src / tracekit / filtering / base.py: 99%

207 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Base filter classes for TraceKit filtering module. 

2 

3Provides abstract base classes for IIR and FIR filter implementations 

4with common interface for filter application and introspection. 

5""" 

6 

7from __future__ import annotations 

8 

9from abc import ABC, abstractmethod 

10from dataclasses import dataclass 

11from typing import TYPE_CHECKING, Literal 

12 

13import numpy as np 

14from scipy import signal 

15 

16from tracekit.core.exceptions import AnalysisError 

17from tracekit.core.types import WaveformTrace 

18 

19if TYPE_CHECKING: 

20 from numpy.typing import NDArray 

21 

22 

@dataclass
class FilterResult:
    """Bundle returned by ``Filter.apply(..., return_details=True)``.

    Only ``trace`` is mandatory; the three introspection fields default to
    ``None`` and are populated by the filter that produced the result.

    Attributes:
        trace: The filtered waveform.
        transfer_function: Complex frequency response H(f), if computed.
        impulse_response: Impulse response h[n], if computed.
        group_delay: Group delay in samples, if computed.
    """

    # Filtered output; always present.
    trace: WaveformTrace
    # Optional introspection data, left as None when not requested.
    transfer_function: NDArray[np.complex128] | None = None
    impulse_response: NDArray[np.float64] | None = None
    group_delay: NDArray[np.float64] | None = None

38 

39 

class Filter(ABC):
    """Abstract base class for all filters.

    Defines the common interface for filter application and introspection.
    All filter implementations must inherit from this class and provide
    ``is_stable``, ``order``, ``apply`` and the frequency/impulse/step
    response methods. ``get_transfer_function`` and ``get_group_delay`` have
    default implementations built on ``get_frequency_response``.

    Attributes:
        sample_rate: Sample rate in Hz for digital filter design.
        is_stable: Whether the filter is stable (for IIR filters).
        order: Filter order.
    """

    def __init__(self, sample_rate: float | None = None) -> None:
        """Initialize filter.

        Args:
            sample_rate: Sample rate in Hz. May be None and assigned later;
                it is required by ``get_transfer_function``.
        """
        self._sample_rate = sample_rate
        # Subclasses flip this to True once coefficients are available.
        self._is_designed = False

    @property
    def sample_rate(self) -> float | None:
        """Sample rate in Hz."""
        return self._sample_rate

    @sample_rate.setter
    def sample_rate(self, value: float) -> None:
        """Set sample rate and mark filter for redesign.

        Assigning the value already stored is a no-op, so the design flag is
        only cleared when the rate actually changes.
        """
        if value != self._sample_rate:
            self._sample_rate = value
            self._is_designed = False

    @property
    @abstractmethod
    def is_stable(self) -> bool:
        """Check if filter is stable."""
        ...

    @property
    @abstractmethod
    def order(self) -> int:
        """Filter order."""
        ...

    @abstractmethod
    def apply(
        self,
        trace: WaveformTrace,
        *,
        return_details: bool = False,
    ) -> WaveformTrace | FilterResult:
        """Apply filter to a waveform trace.

        Args:
            trace: Input waveform trace.
            return_details: If True, return FilterResult with introspection data.

        Returns:
            Filtered trace, or FilterResult if return_details=True.
        """
        ...

    @abstractmethod
    def get_frequency_response(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.complex128]]:
        """Get frequency response of the filter.

        Args:
            worN: Frequencies at which to evaluate. If int, that many frequencies
                from 0 to pi (Nyquist). If array, specific normalized angular
                frequencies in rad/sample. If None, uses 512 points.

        Returns:
            Tuple of (frequencies, complex response H(f)).
        """
        ...

    @abstractmethod
    def get_impulse_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get impulse response of the filter.

        Args:
            n_samples: Number of samples in impulse response.

        Returns:
            Impulse response h[n].
        """
        ...

    @abstractmethod
    def get_step_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get step response of the filter.

        Args:
            n_samples: Number of samples in step response.

        Returns:
            Step response s[n].
        """
        ...

    def get_transfer_function(
        self,
        freqs: NDArray[np.float64] | None = None,
    ) -> NDArray[np.complex128]:
        """Get transfer function H(f) at specified frequencies.

        Args:
            freqs: Frequencies in Hz (array or any array-like sequence).
                If None, uses 512 points from 0 to Nyquist.

        Returns:
            Complex transfer function values.

        Raises:
            AnalysisError: If sample rate is not set.
        """
        if self._sample_rate is None:
            raise AnalysisError("Sample rate must be set to compute transfer function")

        if freqs is None:
            freqs = np.linspace(0, self._sample_rate / 2, 512)

        # Coerce so plain lists/tuples work too (float * list raises TypeError).
        freqs_hz = np.asarray(freqs, dtype=np.float64)

        # Convert Hz to normalized angular frequency (rad/sample).
        w = 2 * np.pi * freqs_hz / self._sample_rate
        _, h = self.get_frequency_response(w)
        return h

    def get_group_delay(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get group delay of the filter.

        Default implementation: negative forward difference of the unwrapped
        phase with respect to frequency. The last sample repeats the
        previous one because a forward difference has no value there.

        Args:
            worN: Frequencies at which to evaluate. If int, that many frequencies.
                If None, uses 512 points.

        Returns:
            Tuple of (frequencies, group delay in samples). With fewer than
            two frequency points the delay is all zeros; between two
            identical frequency points the delay is reported as 0.
        """
        if worN is None:
            worN = 512

        w, h = self.get_frequency_response(worN)
        gd = np.zeros_like(w, dtype=np.float64)

        # A phase derivative needs at least two points; also covers empty grids.
        if gd.size < 2:
            return w, gd

        phase = np.unwrap(np.angle(h))
        dw = np.diff(w)
        dphi = np.diff(phase)

        # Avoid division by zero: repeated frequency points keep the 0 default.
        np.divide(-dphi, dw, out=gd[:-1], where=dw != 0)
        gd[-1] = gd[-2]
        return w, gd

200 

201 

class IIRFilter(Filter):
    """Infinite Impulse Response filter base class.

    Stores filter coefficients in Second-Order Sections (SOS) format
    for numerical stability, with optional B/A polynomial format.
    When both are supplied, filtering and response calculations use the
    SOS coefficients, while ``ba``, ``poles``, ``zeros`` and ``is_stable``
    use the explicit B/A pair.

    Attributes:
        sos: Second-order sections coefficients (preferred format).
        ba: Numerator/denominator polynomial coefficients (optional).
    """

    def __init__(
        self,
        sample_rate: float | None = None,
        sos: NDArray[np.float64] | None = None,
        ba: tuple[NDArray[np.float64], NDArray[np.float64]] | None = None,
    ) -> None:
        """Initialize IIR filter.

        Args:
            sample_rate: Sample rate in Hz.
            sos: Second-order sections array (n_sections, 6).
            ba: Tuple of (b, a) polynomial coefficients.
        """
        super().__init__(sample_rate)
        self._sos = sos
        self._ba = ba
        # Designed as soon as either coefficient representation is present.
        self._is_designed = sos is not None or ba is not None

    @property
    def sos(self) -> NDArray[np.float64] | None:
        """Second-order sections coefficients."""
        return self._sos

    @property
    def ba(self) -> tuple[NDArray[np.float64], NDArray[np.float64]] | None:
        """B/A polynomial coefficients.

        Returns the explicit pair if one was given; otherwise derives it
        from the SOS array; None when the filter has no coefficients.
        """
        if self._ba is not None:
            return self._ba
        if self._sos is not None:
            # Convert SOS to BA
            b, a = signal.sos2tf(self._sos)
            return (b, a)
        return None

    def _polynomial_roots(self, *, numerator: bool) -> NDArray[np.complex128]:
        """Roots of the numerator (zeros) or denominator (poles).

        An explicit B/A pair takes precedence. With SOS only, roots are
        found per section instead of expanding everything into one
        high-order polynomial, whose root finding is ill-conditioned.
        """
        if self._ba is not None:
            coeffs = self._ba[0] if numerator else self._ba[1]
            return np.roots(coeffs).astype(np.complex128)

        if self._sos is not None:
            sections = np.atleast_2d(np.asarray(self._sos, dtype=np.float64))
            # Columns 0-2 hold each section's numerator, 3-5 its denominator.
            cols = slice(0, 3) if numerator else slice(3, 6)
            per_section = [np.roots(section[cols]) for section in sections]
            if per_section:
                return np.concatenate(per_section).astype(np.complex128)

        return np.array([], dtype=np.complex128)

    @property
    def is_stable(self) -> bool:
        """Check if filter is stable (all poles strictly inside unit circle).

        A filter without coefficients has no poles and reports True.
        """
        return bool(np.all(np.abs(self.poles) < 1.0))

    @property
    def order(self) -> int:
        """Filter order.

        Two per SOS section when SOS is present, else ``len(a) - 1`` from
        the B/A pair, else 0. Note that the SOS count is an upper bound if
        a section has been zero-padded from a lower order.
        """
        if self._sos is not None:
            return 2 * len(self._sos)
        if self._ba is not None:
            return len(self._ba[1]) - 1
        return 0

    @property
    def poles(self) -> NDArray[np.complex128]:
        """Filter poles in z-domain (empty if not designed)."""
        return self._polynomial_roots(numerator=False)

    @property
    def zeros(self) -> NDArray[np.complex128]:
        """Filter zeros in z-domain (empty if not designed)."""
        return self._polynomial_roots(numerator=True)

    def apply(
        self,
        trace: WaveformTrace,
        *,
        return_details: bool = False,
        filtfilt: bool = True,
    ) -> WaveformTrace | FilterResult:
        """Apply IIR filter to waveform.

        Args:
            trace: Input waveform trace.
            return_details: If True, return FilterResult with introspection data.
                The introspection arrays use the default grid sizes of the
                response methods and describe a single pass of the filter.
            filtfilt: If True, use zero-phase filtering (forward-backward).
                If False, use causal filtering.

        Returns:
            Filtered trace, or FilterResult if return_details=True.

        Raises:
            AnalysisError: If filter not designed or is unstable.
        """
        if self._sos is None and self._ba is None:
            raise AnalysisError("Filter not designed - no coefficients available")

        if not self.is_stable:
            raise AnalysisError("Cannot apply unstable filter")

        # SOS is preferred whenever it is available.
        if self._sos is not None:
            if filtfilt:
                filtered_data = signal.sosfiltfilt(self._sos, trace.data)
            else:
                filtered_data = signal.sosfilt(self._sos, trace.data)
        else:
            b, a = self._ba  # type: ignore[misc]
            if filtfilt:
                filtered_data = signal.filtfilt(b, a, trace.data)
            else:
                filtered_data = signal.lfilter(b, a, trace.data)

        filtered_trace = WaveformTrace(
            data=filtered_data.astype(np.float64),
            metadata=trace.metadata,
        )

        if return_details:
            _w, h = self.get_frequency_response()
            impulse = self.get_impulse_response()
            _, gd = self.get_group_delay()
            return FilterResult(
                trace=filtered_trace,
                transfer_function=h,
                impulse_response=impulse,
                group_delay=gd,
            )

        return filtered_trace

    def get_frequency_response(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.complex128]]:
        """Get frequency response.

        Args:
            worN: Number of points or explicit frequencies; 512 points if None.

        Returns:
            Tuple of (frequencies, complex response).

        Raises:
            AnalysisError: If the filter has no coefficients.
        """
        if worN is None:
            worN = 512

        if self._sos is not None:
            w, h = signal.sosfreqz(self._sos, worN=worN)
        elif self._ba is not None:
            w, h = signal.freqz(self._ba[0], self._ba[1], worN=worN)
        else:
            raise AnalysisError("Filter not designed")

        return w.astype(np.float64), h.astype(np.complex128)

    def _filter_causal(self, excitation: NDArray[np.float64]) -> NDArray[np.float64]:
        """Run one causal pass of the filter over ``excitation``.

        Raises:
            AnalysisError: If the filter has no coefficients.
        """
        if self._sos is None and self._ba is None:
            raise AnalysisError("Filter not designed")

        # Nothing to filter; avoids handing a zero-length signal to scipy.
        if excitation.size == 0:
            return excitation.astype(np.float64)

        if self._sos is not None:
            return signal.sosfilt(self._sos, excitation).astype(np.float64)
        return signal.lfilter(self._ba[0], self._ba[1], excitation).astype(np.float64)

    def get_impulse_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get impulse response.

        Args:
            n_samples: Number of samples; 0 yields an empty array.

        Returns:
            Causal response to a unit impulse at sample 0.

        Raises:
            AnalysisError: If the filter has no coefficients.
        """
        impulse = np.zeros(n_samples)
        if impulse.size:
            impulse[0] = 1.0
        return self._filter_causal(impulse)

    def get_step_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get step response.

        Args:
            n_samples: Number of samples in the response.

        Returns:
            Causal response to a unit step starting at sample 0.

        Raises:
            AnalysisError: If the filter has no coefficients.
        """
        return self._filter_causal(np.ones(n_samples))

397 

398 

class FIRFilter(Filter):
    """Finite Impulse Response filter base class.

    Stores filter coefficients as a single array of tap weights.
    FIR filters are always stable and can achieve linear phase.

    Attributes:
        coeffs: Filter tap coefficients.
    """

    def __init__(
        self,
        sample_rate: float | None = None,
        coeffs: NDArray[np.float64] | None = None,
    ) -> None:
        """Initialize FIR filter.

        Args:
            sample_rate: Sample rate in Hz.
            coeffs: Filter coefficients (tap weights).
        """
        super().__init__(sample_rate)
        self._coeffs = coeffs
        self._is_designed = coeffs is not None

    @property
    def coeffs(self) -> NDArray[np.float64] | None:
        """Filter coefficients."""
        return self._coeffs

    @coeffs.setter
    def coeffs(self, value: NDArray[np.float64]) -> None:
        """Set filter coefficients and mark the filter as designed."""
        self._coeffs = value
        self._is_designed = True

    @property
    def is_stable(self) -> bool:
        """FIR filters are always stable."""
        return True

    @property
    def order(self) -> int:
        """Filter order (number of taps - 1), or 0 without coefficients."""
        if self._coeffs is not None:
            return len(self._coeffs) - 1
        return 0

    @property
    def is_linear_phase(self) -> bool:
        """Check if filter has linear phase (symmetric or antisymmetric coefficients).

        Returns False when no coefficients are set. Symmetry is tested with
        ``np.allclose`` against the reversed tap array.
        """
        if self._coeffs is None:
            return False
        reversed_taps = self._coeffs[::-1]
        symmetric = np.allclose(self._coeffs, reversed_taps)
        antisymmetric = np.allclose(self._coeffs, -reversed_taps)
        return bool(symmetric or antisymmetric)

    def apply(
        self,
        trace: WaveformTrace,
        *,
        return_details: bool = False,
        mode: Literal["full", "same", "valid"] = "same",
    ) -> WaveformTrace | FilterResult:
        """Apply FIR filter to waveform.

        Args:
            trace: Input waveform trace.
            return_details: If True, return FilterResult with introspection data.
            mode: Convolution mode - "same" preserves the trace length, also
                when the trace is shorter than the number of taps.

        Returns:
            Filtered trace, or FilterResult if return_details=True.

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed - no coefficients available")

        data = trace.data
        n_taps = len(self._coeffs)

        if mode == "same" and n_taps > len(data):
            # np.convolve(..., "same") returns max(len(data), n_taps) samples,
            # which would lengthen a short trace. Take the centred window of
            # the full convolution instead so the output matches the input.
            start = (n_taps - 1) // 2
            full = np.convolve(data, self._coeffs, mode="full")
            filtered_data = full[start : start + len(data)]
        else:
            filtered_data = np.convolve(data, self._coeffs, mode=mode)

        filtered_trace = WaveformTrace(
            data=filtered_data.astype(np.float64),
            metadata=trace.metadata,
        )

        if return_details:
            _w, h = self.get_frequency_response()
            impulse = self.get_impulse_response()
            _, gd = self.get_group_delay()
            return FilterResult(
                trace=filtered_trace,
                transfer_function=h,
                impulse_response=impulse,
                group_delay=gd,
            )

        return filtered_trace

    def get_frequency_response(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.complex128]]:
        """Get frequency response.

        Args:
            worN: Number of points or explicit frequencies; 512 points if None.

        Returns:
            Tuple of (frequencies, complex response).

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed")

        if worN is None:
            worN = 512

        w, h = signal.freqz(self._coeffs, 1, worN=worN)
        return w.astype(np.float64), h.astype(np.complex128)

    def get_impulse_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get impulse response (just the coefficients, zero-padded).

        Args:
            n_samples: Length of the returned array; taps beyond it are cut.

        Returns:
            The taps truncated or zero-padded to ``n_samples`` values.

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed")

        if len(self._coeffs) >= n_samples:
            return self._coeffs[:n_samples].astype(np.float64)

        response = np.zeros(n_samples)
        response[: len(self._coeffs)] = self._coeffs
        return response.astype(np.float64)

    def get_step_response(
        self,
        n_samples: int = 256,
    ) -> NDArray[np.float64]:
        """Get step response.

        Args:
            n_samples: Number of samples in the response.

        Returns:
            First ``n_samples`` values of a unit step convolved with the taps.

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed")

        step = np.ones(n_samples)
        response = np.convolve(step, self._coeffs, mode="full")[:n_samples]
        return response.astype(np.float64)

    def get_group_delay(
        self,
        worN: int | NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get group delay.

        Overrides the base-class phase-difference estimate with
        ``scipy.signal.group_delay`` evaluated on the taps.

        Args:
            worN: Number of points or explicit frequencies; 512 points if None.

        Returns:
            Tuple of (frequencies, group delay in samples).

        Raises:
            AnalysisError: If filter not designed.
        """
        if self._coeffs is None:
            raise AnalysisError("Filter not designed")

        if worN is None:
            worN = 512

        w, gd = signal.group_delay((self._coeffs, 1), w=worN)
        return w.astype(np.float64), gd.astype(np.float64)

556 

557 

# Public API of this module.
__all__ = ["FIRFilter", "Filter", "FilterResult", "IIRFilter"]

563]