Coverage for src / tracekit / filtering / convenience.py: 93%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Convenience filtering functions for TraceKit. 

2 

3Provides simple one-call filter functions for common operations like 

4moving average, median filter, Savitzky-Golay smoothing, and matched 

5filtering. 

6 

7Example: 

8 >>> from tracekit.filtering.convenience import low_pass, moving_average 

9 >>> filtered = low_pass(trace, cutoff=1e6) 

10 >>> smoothed = moving_average(trace, window_size=11) 

11""" 

12 

13from __future__ import annotations 

14 

15from typing import TYPE_CHECKING, Any, Literal 

16 

17import numpy as np 

18from scipy import ndimage, signal 

19 

20from tracekit.core.exceptions import AnalysisError 

21from tracekit.core.types import WaveformTrace 

22from tracekit.filtering.design import ( 

23 BandPassFilter, 

24 BandStopFilter, 

25 HighPassFilter, 

26 LowPassFilter, 

27) 

28 

29if TYPE_CHECKING: 

30 from numpy.typing import NDArray 

31 

32 

def low_pass(
    trace: WaveformTrace,
    cutoff: float,
    *,
    order: int = 4,
    filter_type: Literal[
        "butterworth", "chebyshev1", "chebyshev2", "bessel", "elliptic"
    ] = "butterworth",
) -> WaveformTrace:
    """Low-pass filter ``trace`` in a single call.

    A :class:`LowPassFilter` is designed at the trace's own sample rate
    (``trace.metadata.sample_rate``) and then applied to the trace.

    Args:
        trace: Input waveform trace.
        cutoff: Cutoff frequency in Hz.
        order: Filter order (default 4).
        filter_type: Filter family (default ``"butterworth"``).

    Returns:
        Filtered waveform trace. If ``apply`` returns a wrapper object rather
        than a bare ``WaveformTrace``, the wrapper's ``.trace`` is returned.

    Example:
        >>> filtered = low_pass(trace, cutoff=1e6)
    """
    rate = trace.metadata.sample_rate
    lpf = LowPassFilter(
        cutoff=cutoff,
        sample_rate=rate,
        order=order,
        filter_type=filter_type,
    )
    output = lpf.apply(trace)
    # ``apply`` may hand back either a trace or an object carrying one in ``.trace``.
    return output if isinstance(output, WaveformTrace) else output.trace

66 

67 

def high_pass(
    trace: WaveformTrace,
    cutoff: float,
    *,
    order: int = 4,
    filter_type: Literal[
        "butterworth", "chebyshev1", "chebyshev2", "bessel", "elliptic"
    ] = "butterworth",
) -> WaveformTrace:
    """High-pass filter ``trace`` in a single call.

    Designs a :class:`HighPassFilter` at ``trace.metadata.sample_rate`` and
    applies it, which attenuates DC and content below ``cutoff``.

    Args:
        trace: Input waveform trace.
        cutoff: Cutoff frequency in Hz.
        order: Filter order (default 4).
        filter_type: Filter family (default ``"butterworth"``).

    Returns:
        Filtered waveform trace (unwrapped from ``.trace`` when ``apply``
        returns a non-``WaveformTrace`` result object).

    Example:
        >>> filtered = high_pass(trace, cutoff=100)  # Remove DC and low frequencies
    """
    hpf = HighPassFilter(
        cutoff=cutoff,
        sample_rate=trace.metadata.sample_rate,
        order=order,
        filter_type=filter_type,
    )
    filtered = hpf.apply(trace)
    if not isinstance(filtered, WaveformTrace):
        # Result wrapper: the actual trace lives on ``.trace``.
        return filtered.trace
    return filtered

101 

102 

def band_pass(
    trace: WaveformTrace,
    low: float,
    high: float,
    *,
    order: int = 4,
    filter_type: Literal[
        "butterworth", "chebyshev1", "chebyshev2", "bessel", "elliptic"
    ] = "butterworth",
) -> WaveformTrace:
    """Keep only the ``[low, high]`` Hz band of ``trace``.

    Builds a :class:`BandPassFilter` at the trace's sample rate and applies it.

    Args:
        trace: Input waveform trace.
        low: Lower cutoff frequency in Hz.
        high: Upper cutoff frequency in Hz.
        order: Filter order (default 4).
        filter_type: Filter family (default ``"butterworth"``).

    Returns:
        Filtered waveform trace; a result wrapper returned by ``apply`` is
        unwrapped via its ``.trace`` attribute.

    Example:
        >>> filtered = band_pass(trace, low=1e3, high=10e3)
    """
    rate = trace.metadata.sample_rate
    bpf = BandPassFilter(
        low=low,
        high=high,
        sample_rate=rate,
        order=order,
        filter_type=filter_type,
    )
    applied = bpf.apply(trace)
    return applied if isinstance(applied, WaveformTrace) else applied.trace

139 

140 

def band_stop(
    trace: WaveformTrace,
    low: float,
    high: float,
    *,
    order: int = 4,
    filter_type: Literal[
        "butterworth", "chebyshev1", "chebyshev2", "bessel", "elliptic"
    ] = "butterworth",
) -> WaveformTrace:
    """Reject the ``[low, high]`` Hz band of ``trace`` (band-stop / notch).

    Builds a :class:`BandStopFilter` at the trace's sample rate and applies it.

    Args:
        trace: Input waveform trace.
        low: Lower cutoff frequency in Hz.
        high: Upper cutoff frequency in Hz.
        order: Filter order (default 4).
        filter_type: Filter family (default ``"butterworth"``).

    Returns:
        Filtered waveform trace; if ``apply`` returns a wrapper object, its
        ``.trace`` attribute is returned instead.

    Example:
        >>> filtered = band_stop(trace, low=59, high=61)  # Remove 60 Hz
    """
    bsf = BandStopFilter(
        low=low,
        high=high,
        sample_rate=trace.metadata.sample_rate,
        order=order,
        filter_type=filter_type,
    )
    outcome = bsf.apply(trace)
    if isinstance(outcome, WaveformTrace):
        return outcome
    # Not a bare trace: unwrap the one carried by the result object.
    return outcome.trace

177 

178 

def notch_filter(
    trace: WaveformTrace,
    freq: float,
    *,
    q_factor: float = 30.0,
) -> WaveformTrace:
    """Apply narrow notch filter at specified frequency.

    Uses a zero-phase band-stop Butterworth filter (``scipy.signal.butter``
    with ``N=4``; for band-stop designs SciPy doubles this to an 8th-order
    response) applied with ``sosfiltfilt``. The stopband width comes from
    the Q factor: ``bandwidth (Hz) = freq / q_factor``.

    Args:
        trace: Input waveform trace.
        freq: Center frequency to notch out in Hz. Must be > 0 and below
            Nyquist.
        q_factor: Quality factor (higher = narrower notch). Must be > 0.
            Default 30.

    Returns:
        Filtered waveform trace with the input metadata.

    Raises:
        AnalysisError: If ``freq`` is not in (0, Nyquist), if ``q_factor`` is
            not positive, or if the stopband collapses (low edge >= high edge)
            after clamping to the usable frequency range.

    Example:
        >>> filtered = notch_filter(trace, freq=60, q_factor=30)  # Remove 60 Hz line noise
    """
    sample_rate = trace.metadata.sample_rate
    nyquist = sample_rate / 2

    if freq >= nyquist:
        raise AnalysisError(
            f"Notch frequency {freq} Hz must be less than Nyquist {sample_rate / 2} Hz"
        )
    # ``not x > 0`` also rejects NaN, which would otherwise slip through to scipy.
    if not freq > 0:
        raise AnalysisError(f"Notch frequency must be positive, got {freq} Hz")
    if not q_factor > 0:
        raise AnalysisError(f"Q factor must be positive, got {q_factor}")

    # Calculate bandwidth from Q factor: BW = f0 / Q
    bandwidth = freq / q_factor

    # Clamp the band edges: stay above 0 Hz and below Nyquist, because
    # butter() requires normalized edges strictly inside (0, 1).
    low = max(freq - bandwidth / 2, 0.1)
    high = min(freq + bandwidth / 2, nyquist - 1)

    # Clamping can invert the band (e.g. very low freq, freq near Nyquist, or a
    # low sample rate); report that clearly instead of an opaque scipy error.
    if not low < high:
        raise AnalysisError(
            f"Notch band for {freq} Hz (Q={q_factor}) is empty after clamping "
            f"to [0.1, {nyquist - 1}] Hz: low={low} Hz, high={high} Hz"
        )

    # Normalize edges to Nyquist for scipy.
    wn = [low / nyquist, high / nyquist]

    sos = signal.butter(4, wn, btype="bandstop", output="sos")

    # Forward-backward filtering gives zero phase shift.
    filtered_data = signal.sosfiltfilt(sos, trace.data)

    return WaveformTrace(
        data=filtered_data.astype(np.float64),
        metadata=trace.metadata,
    )

232 

233 

def moving_average(
    trace: WaveformTrace,
    window_size: int,
    *,
    mode: Literal["same", "valid", "full"] = "same",
) -> WaveformTrace:
    """Smooth ``trace`` with a boxcar (uniform-weight FIR) moving average.

    The samples are convolved with ``window_size`` equal weights that sum to
    one, via ``np.convolve``.

    Args:
        trace: Input waveform trace.
        window_size: Number of samples averaged per output point. An odd size
            is recommended for ``"same"`` mode so the window is centred; this
            is not enforced.
        mode: ``np.convolve`` mode. ``"same"`` keeps the input length
            (samples beyond the ends count as zero). ``"valid"`` and
            ``"full"`` change the length, but the input metadata is still
            attached unchanged.

    Returns:
        Averaged waveform trace.

    Raises:
        AnalysisError: If window_size is below 1 or longer than the data.

    Example:
        >>> smoothed = moving_average(trace, window_size=11)
    """
    if window_size < 1:
        raise AnalysisError(f"Window size must be positive, got {window_size}")

    n_samples = len(trace.data)
    if window_size > n_samples:
        raise AnalysisError(f"Window size {window_size} exceeds data length {n_samples}")

    weights = np.ones(window_size) / window_size
    averaged = np.convolve(trace.data, weights, mode=mode)

    return WaveformTrace(
        data=averaged.astype(np.float64),
        metadata=trace.metadata,
    )

271 

272 

def median_filter(
    trace: WaveformTrace,
    kernel_size: int,
) -> WaveformTrace:
    """Remove spikes/impulse noise with a sliding median.

    A non-linear filter: each output sample is the median of the
    ``kernel_size`` samples around it. This tends to keep edges sharp while
    discarding isolated outliers. It is computed with
    ``scipy.ndimage.median_filter`` using its default boundary handling.

    Args:
        trace: Input waveform trace.
        kernel_size: Window length in samples; must be a positive odd integer.

    Returns:
        Median-filtered waveform trace.

    Raises:
        AnalysisError: If kernel_size is below 1 or even.

    Example:
        >>> cleaned = median_filter(trace, kernel_size=5)  # Remove impulse noise
    """
    if kernel_size < 1:
        raise AnalysisError(f"Kernel size must be positive, got {kernel_size}")
    if not kernel_size % 2:
        raise AnalysisError(f"Kernel size must be odd, got {kernel_size}")

    cleaned = ndimage.median_filter(trace.data, size=kernel_size)
    return WaveformTrace(data=cleaned.astype(np.float64), metadata=trace.metadata)

306 

307 

def savgol_filter(
    trace: WaveformTrace,
    window_length: int,
    polyorder: int,
    *,
    deriv: int = 0,
    delta: float = 1.0,
) -> WaveformTrace:
    """Apply Savitzky-Golay smoothing filter.

    Fits a least-squares polynomial of order ``polyorder`` over a sliding
    window of ``window_length`` samples. This smooths data while preserving
    higher moments (peaks, etc.) better than a simple moving average.

    Args:
        trace: Input waveform trace.
        window_length: Length of filter window. Must be a positive odd
            integer, greater than ``polyorder`` and no longer than the data.
        polyorder: Order of polynomial used in fitting (must be >= 0).
        deriv: Derivative order (0 = smoothing, 1 = first derivative, etc.).
            Must be >= 0.
        delta: Sample spacing used to scale derivatives when ``deriv > 0``.
            The default 1.0 gives derivatives per *sample*, which matches the
            previous behavior. To get per-second units, pass the sample
            period, e.g. ``trace.metadata.time_base``; that is the spacing
            :func:`differentiate` uses.

    Returns:
        Filtered (or differentiated) waveform trace with the input metadata.

    Raises:
        AnalysisError: If ``window_length`` is even, not positive, longer than
            the data, or not greater than ``polyorder``; or if ``polyorder``
            or ``deriv`` is negative.

    Example:
        >>> smoothed = savgol_filter(trace, window_length=11, polyorder=3)
    """
    if window_length % 2 == 0:
        raise AnalysisError(f"Window length must be odd, got {window_length}")

    if window_length < 1:
        raise AnalysisError(f"Window length must be positive, got {window_length}")

    if polyorder < 0:
        raise AnalysisError(f"Polynomial order must be non-negative, got {polyorder}")

    if polyorder >= window_length:
        raise AnalysisError(
            f"Polynomial order {polyorder} must be less than window length {window_length}"
        )

    if deriv < 0:
        raise AnalysisError(f"Derivative order must be non-negative, got {deriv}")

    n_samples = len(trace.data)
    if window_length > n_samples:
        # scipy's default mode="interp" cannot fit a window longer than the data.
        raise AnalysisError(
            f"Window length {window_length} exceeds data length {n_samples}"
        )

    filtered_data = signal.savgol_filter(
        trace.data, window_length, polyorder, deriv=deriv, delta=delta
    )

    return WaveformTrace(
        data=filtered_data.astype(np.float64),
        metadata=trace.metadata,
    )

349 

350 

def matched_filter(
    trace: WaveformTrace,
    template: NDArray[np.floating[Any]],
    *,
    normalize: bool = True,
) -> WaveformTrace:
    """Apply matched filter for pulse detection.

    Correlates the input with a known pulse template to detect occurrences
    of that pulse shape. The correlation is computed as a convolution with
    the time-reversed template.

    Args:
        trace: Input waveform trace.
        template: Template pulse to match. Any 1-D real array-like is accepted
            (ndarray, list, tuple).
        normalize: If True, scale the template to unit energy (sum of squares
            equal to 1). An all-zero template is left unscaled.

    Returns:
        Matched filter output trace, the same length as ``trace.data``
        (``np.convolve`` with ``mode="same"``). Peaks indicate template
        matches.

    Raises:
        AnalysisError: If the template is empty, not one-dimensional, or
            longer than the data.

    Example:
        >>> # Detect a specific pulse shape
        >>> pulse_template = np.array([0, 0.5, 1.0, 0.5, 0])
        >>> match_output = matched_filter(trace, pulse_template)
        >>> # Find peaks in match_output for detection
    """
    # Coerce to an ndarray so sequences work too; plain lists would fail at ``h**2``.
    tmpl = np.asarray(template)

    if tmpl.size == 0:
        raise AnalysisError("Template cannot be empty")

    if tmpl.ndim != 1:
        raise AnalysisError(f"Template must be one-dimensional, got shape {tmpl.shape}")

    if tmpl.size > len(trace.data):
        raise AnalysisError(
            f"Template length {tmpl.size} exceeds data length {len(trace.data)}"
        )

    # Matched filter is correlation with the template, i.e. convolution with
    # the time-reversed template. The slice is a view; nothing below mutates it.
    h = tmpl[::-1]

    if normalize:
        energy = np.sum(h**2)
        if energy > 0:
            h = h / np.sqrt(energy)

    output = np.convolve(trace.data, h, mode="same")

    return WaveformTrace(
        data=output.astype(np.float64),
        metadata=trace.metadata,
    )

402 

403 

def exponential_moving_average(
    trace: WaveformTrace,
    alpha: float,
) -> WaveformTrace:
    """Smooth ``trace`` with a first-order exponential moving average (EMA).

    Implements the recursion ``y[n] = alpha * x[n] + (1 - alpha) * y[n-1]``
    as an IIR filter with ``scipy.signal.lfilter``. No initial state is
    supplied, so the recursion starts from rest (``y[-1] = 0``). The first
    output is therefore ``alpha * x[0]``, and the output ramps toward the
    signal level.

    Args:
        trace: Input waveform trace.
        alpha: Smoothing factor in (0, 1]. Larger values smooth less; 1
            passes the input through unchanged.

    Returns:
        EMA-filtered waveform trace.

    Raises:
        AnalysisError: If alpha lies outside (0, 1] (including NaN).

    Example:
        >>> smoothed = exponential_moving_average(trace, alpha=0.1)
    """
    if not (alpha > 0 and alpha <= 1):
        raise AnalysisError(f"Alpha must be in (0, 1], got {alpha}")

    retain = 1 - alpha  # weight carried over from the previous output
    # Transfer function H(z) = alpha / (1 - retain * z^-1)
    numerator = np.array([alpha])
    denominator = np.array([1.0, -retain])

    smoothed = signal.lfilter(numerator, denominator, trace.data)
    return WaveformTrace(data=smoothed.astype(np.float64), metadata=trace.metadata)

439 

440 

def gaussian_filter(
    trace: WaveformTrace,
    sigma: float,
) -> WaveformTrace:
    """Apply Gaussian smoothing filter.

    Smooths with a Gaussian kernel of the given standard deviation, using
    ``scipy.ndimage.gaussian_filter1d`` with its default boundary handling
    and truncation.

    Args:
        trace: Input waveform trace.
        sigma: Standard deviation of Gaussian kernel in samples. Must be > 0.

    Returns:
        Filtered waveform trace (float64 samples, input metadata).

    Raises:
        AnalysisError: If sigma is not positive (including NaN).

    Example:
        >>> smoothed = gaussian_filter(trace, sigma=3.0)
    """
    # ``not sigma > 0`` also rejects NaN, which would otherwise fail inside scipy.
    if not sigma > 0:
        raise AnalysisError(f"Sigma must be positive, got {sigma}")

    # gaussian_filter1d allocates its output with the *input* dtype by default,
    # so integer samples would be truncated before any later cast. Filter in
    # float64 from the start. This is a no-copy view when data is already float64.
    samples = np.asarray(trace.data, dtype=np.float64)
    filtered_data = ndimage.gaussian_filter1d(samples, sigma)

    return WaveformTrace(
        data=filtered_data,
        metadata=trace.metadata,
    )

471 

472 

def differentiate(
    trace: WaveformTrace,
    *,
    order: int = 1,
) -> WaveformTrace:
    """Compute numerical derivative of trace.

    Uses ``np.gradient`` with ``trace.metadata.time_base`` as the sample
    spacing. That gives central differences in the interior and one-sided
    differences at the ends. Higher orders apply ``np.gradient`` repeatedly
    rather than using a dedicated higher-order stencil.

    Args:
        trace: Input waveform trace (at least 2 samples).
        order: Derivative order (1 = first derivative, 2 = second, etc.).

    Returns:
        Differentiated waveform trace. Physical units change (V -> V/s, etc.),
        but note that the input metadata is attached unchanged.

    Raises:
        AnalysisError: If order is not positive or the trace has fewer than
            2 samples.

    Example:
        >>> velocity = differentiate(position_trace)
        >>> acceleration = differentiate(position_trace, order=2)
    """
    if order < 1:
        raise AnalysisError(f"Derivative order must be positive, got {order}")

    n_samples = len(trace.data)
    if n_samples < 2:
        # np.gradient needs at least edge_order + 1 = 2 points.
        raise AnalysisError(
            f"Differentiation requires at least 2 samples, got {n_samples}"
        )

    sample_period = trace.metadata.time_base

    # np.gradient always returns a new array, so the input is never mutated
    # and no defensive copy is needed.
    result = trace.data
    for _ in range(order):
        result = np.gradient(result, sample_period)

    return WaveformTrace(
        data=result.astype(np.float64),
        metadata=trace.metadata,
    )

509 

510 

def integrate(
    trace: WaveformTrace,
    *,
    method: Literal["cumtrapz", "cumsum"] = "cumtrapz",
    initial: float = 0.0,
) -> WaveformTrace:
    """Compute numerical integral of trace.

    The sample spacing is ``trace.metadata.time_base``. ``initial`` is the
    integration constant and is added to every output sample.

    Args:
        trace: Input waveform trace.
        method: Integration method.

            * ``"cumtrapz"``: cumulative trapezoidal rule. Output has the
              same length as the input, and the first sample equals
              ``initial``.
            * ``"cumsum"``: running rectangle sum,
              ``initial + dt * cumsum(x)``. Its first sample already includes
              ``x[0] * dt``.
        initial: Integration constant (value of the integral at t=0 for
            ``"cumtrapz"``).

    Returns:
        Integrated waveform trace. Physical units change (V -> V*s, etc.),
        but the input metadata is attached unchanged.

    Raises:
        AnalysisError: If method is not recognized.

    Example:
        >>> position = integrate(velocity_trace)
    """
    sample_period = trace.metadata.time_base

    if method == "cumtrapz":
        from scipy.integrate import cumulative_trapezoid

        # SciPy's ``initial`` argument only *prepends* a value; it does not
        # offset the running integral, and non-zero values are deprecated or
        # rejected in newer SciPy. Integrate from zero, then add the constant.
        result = cumulative_trapezoid(trace.data, dx=sample_period, initial=0) + initial
    elif method == "cumsum":
        result = np.cumsum(trace.data) * sample_period + initial
    else:
        raise AnalysisError(f"Unknown integration method: {method}")

    return WaveformTrace(
        data=result.astype(np.float64),
        metadata=trace.metadata,
    )

548 

549 

# Public API of this module: the names exported by
# ``from tracekit.filtering.convenience import *``, kept in alphabetical order.
__all__ = [
    "band_pass",
    "band_stop",
    "differentiate",
    "exponential_moving_average",
    "gaussian_filter",
    "high_pass",
    "integrate",
    "low_pass",
    "matched_filter",
    "median_filter",
    "moving_average",
    "notch_filter",
    "savgol_filter",
]