Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mth5 \ mth5 \ timeseries \ scipy_filters.py: 86%

194 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-27 20:09 -0800

1# -*- coding: utf-8 -*- 

2""" 

3Scipy filter wrappers for xarray. 

4 

5This module provides xarray-compatible wrappers for scipy.signal filtering 

6functions, enabling efficient filtering operations on labeled N-dimensional 

7arrays with automatic dimension handling. 

8 

9Notes 

10----- 

11- Adapted from xr-scipy: https://github.com/fujiisoup/xr-scipy 

12- Filters can be applied along any dimension with automatic sampling rate detection. 

13- Supports both IIR and FIR filter types with forward/backward filtering. 

14 

15Examples 

16-------- 

17Apply a bandpass filter to an xarray DataArray:: 

18 

19 >>> import xarray as xr 

20 >>> import numpy as np 

21 >>> data = xr.DataArray(np.random.randn(1000), dims=['time']) 

22 >>> data['time'] = np.arange(1000) / 100.0 # 100 Hz 

23 >>> filtered = data.sps_filters.bandpass(10, 40) # 10-40 Hz 

24 

25""" 

26 

27from __future__ import annotations 

28 

29import warnings 

30from fractions import Fraction 

31from typing import Any 

32 

33import numpy as np 

34import pandas as pd 

35import scipy.signal 

36import xarray as xr 

37from loguru import logger 

38 

39 

40# ============================================================================= 

41# Imports 

42# ============================================================================= 

43 

44 

45try: 

46 from scipy.signal import sosfiltfilt 

47except ImportError: 

48 sosfiltfilt = None 

49# ============================================================================= 

50 

51 

52def _firwin_ba(*args, **kwargs): 

53 if not kwargs.get("pass_zero"): 

54 args = (args[0] + 1,) + args[1:] # numtaps must be odd 

55 return scipy.signal.firwin(*args, **kwargs), np.array([1]) 

56 

57 

# Filter-design callables keyed by impulse-response type; each is called
# with (order, normalized critical frequencies, **design kwargs).
_BA_FUNCS = dict(
    iir=scipy.signal.iirfilter,
    fir=_firwin_ba,
)

62 

63_ORDER_DEFAULTS = { 

64 "iir": 8, 

65 "fir": 29, 

66} 

67 

68 

69### Warnings 

class UnevenSamplingWarning(Warning):
    """Warning category for coordinates whose sampling step is not uniform."""

72 

73 

class FilteringNaNWarning(Warning):
    """Warning category for filtering data that contains NaN values."""

76 

77 

class DecimationWarning(Warning):
    """Warning category reserved for decimation-related messages."""

80 

81 

# Always show the module's own warning categories, even when the same
# message is emitted repeatedly from the same location.
warnings.filterwarnings(action="always", category=UnevenSamplingWarning)
warnings.filterwarnings(action="always", category=FilteringNaNWarning)
warnings.filterwarnings(action="always", category=DecimationWarning)

85 

86 

def get_maybe_only_dim(darray: xr.DataArray | xr.Dataset, dim: str | None) -> str:
    """
    Resolve the dimension to operate along.

    An explicitly given `dim` is returned untouched. When `dim` is None the
    object must have exactly one dimension, whose name is returned.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Object whose dimensions are inspected when `dim` is None.
    dim : str | None
        Dimension name, or None to auto-detect for 1-D objects.

    Returns
    -------
    str
        The dimension name.

    Raises
    ------
    ValueError
        If `dim` is None and the object does not have exactly one dimension.
    """
    # Guard clauses: explicit dimension first, then the ambiguous case.
    if dim is not None:
        return dim
    if len(darray.dims) != 1:
        raise ValueError("Specify the dimension")

    if isinstance(darray, xr.DataArray):
        return str(darray.dims[0])
    if isinstance(darray, xr.Dataset):
        only_name = list(darray.sizes.keys())[0]
        return str(only_name)
    # NOTE: any other 1-D object falls through and yields None, as before.

121 

122 

def get_maybe_last_dim_axis(
    darray: xr.DataArray | xr.Dataset, dim: str | None = None
) -> tuple[str, int]:
    """
    Return a dimension name together with its axis index.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Input object; it is queried through ``ndim``/``dims`` or
        ``get_axis_num``.
    dim : str | None, default None
        Dimension name. When None, the last dimension is used.

    Returns
    -------
    tuple[str, int]
        ``(dimension name, axis index)``.
    """
    if dim is not None:
        axis_index = darray.get_axis_num(dim)
        return str(dim), axis_index

    # No dimension requested: fall back to the trailing axis.
    last_axis = darray.ndim - 1
    return str(darray.dims[last_axis]), last_axis

148 

149 

def get_sampling_step(
    darray: xr.DataArray | xr.Dataset, dim: str | None = None, rtol: float = 1e-3
) -> float:
    """
    Compute the average sampling step along a dimension.

    The step is ``(last - first) / (n - 1)`` of the coordinate. For
    ``datetime64``/``timedelta64`` coordinates with ns, us or ms resolution
    the result is scaled to seconds; plain numeric coordinates are returned
    in their own units.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Input object with a coordinate for `dim`.
    dim : str | None, default None
        Dimension name. Auto-detected for 1-D objects.
    rtol : float, default 1e-3
        Relative tolerance used to flag uneven sampling.

    Returns
    -------
    float
        Average sampling step.

    Raises
    ------
    ValueError
        If the coordinate has fewer than 2 samples, or `dim` is None and
        the object is not 1-D.

    Warnings
    --------
    UnevenSamplingWarning
        If the average step and the first step differ by more than `rtol`
        relative to the smaller of their magnitudes.

    Notes
    -----
    Differences are taken on the raw coordinate values and dispatched on
    their type. Dispatching on ``hasattr(value, "view")`` is wrong because
    every numpy array has ``view``; float coordinates were therefore
    bit-reinterpreted as ``int64`` and produced a meaningless step.

    Examples
    --------
    Get sampling step from a time series::

        >>> dt = get_sampling_step(data, dim='time')
        >>> sample_rate = 1.0 / dt
    """
    dim = get_maybe_only_dim(darray, dim)

    coord = darray.coords[dim]
    n_samples = len(coord)

    if n_samples < 2:
        raise ValueError(
            f"Cannot compute sampling step for coordinate with less than 2 samples. "
            f"Got {n_samples} samples."
        )

    # Ticks-per-second of the coordinate dtype (1 for anything that is not
    # an ns/us/ms resolved time type).
    type_str = coord.dtype.descr[0][1]
    if "ns" in type_str:
        t_scale = 1e9
    elif "us" in type_str:
        t_scale = 1e6
    elif "ms" in type_str:
        t_scale = 1e3
    else:
        t_scale = 1

    samples = coord.values
    span = samples[-1] - samples[0]
    first_step = samples[1] - samples[0]

    if isinstance(span, np.timedelta64):
        # Integer tick counts in the coordinate's own time unit.
        span_ticks = float(span.astype("int64"))
        step_ticks = float(first_step.astype("int64"))
    elif hasattr(span, "total_seconds"):
        # datetime.timedelta / pandas.Timedelta from object-dtype
        # coordinates; total_seconds() is already expressed in seconds.
        span_ticks = float(span.total_seconds())
        step_ticks = float(first_step.total_seconds())
        t_scale = 1
    else:
        # Plain numeric coordinate (int or float).
        span_ticks = float(span)
        step_ticks = float(first_step)

    dt_avg = (span_ticks / (n_samples - 1)) / t_scale
    dt_first = step_ticks / t_scale

    # Compare magnitudes so that descending coordinates (negative steps)
    # do not trigger the warning unconditionally.
    if abs(dt_avg - dt_first) > rtol * min(abs(dt_first), abs(dt_avg)):
        # show warning at caller level to see which signal it is related to
        warnings.warn(
            f"Average sampling {dt_avg:.3g} != first sampling step {dt_first:.3g}",
            UnevenSamplingWarning,
            stacklevel=2,
        )
    return dt_avg

245 

246 

def frequency_filter(
    darray: xr.DataArray | xr.Dataset,
    f_crit: float | list[float] | tuple[float, float],
    order: int | None = None,
    irtype: str = "iir",
    filtfilt: bool = True,
    apply_kwargs: dict[str, Any] | None = None,
    in_nyq: bool = False,
    dim: str | None = None,
    **kwargs: Any,
) -> xr.DataArray | xr.Dataset:
    """
    Apply a frequency filter to an xarray object.

    Supports IIR (infinite impulse response) and FIR (finite impulse
    response) designs, applied either forward only or forward-backward
    for a zero-phase response.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to be filtered.
    f_crit : float | list[float] | tuple[float, float]
        Critical frequency or frequencies. Scalar for lowpass/highpass,
        two values for bandpass/bandstop. The input is copied and never
        modified.
    order : int | None, default None
        Filter order. If None, uses the default (8 for IIR, 29 for FIR).
    irtype : {'iir', 'fir'}, default 'iir'
        Impulse response type.
    filtfilt : bool, default True
        Apply the filter forwards and backwards.
    apply_kwargs : dict | None, default None
        Extra keyword arguments for the function that applies the filter.
    in_nyq : bool, default False
        If True, `f_crit` is already normalized by the Nyquist frequency.
    dim : str | None, default None
        Dimension along which to filter. Auto-detected for 1-D objects.
    **kwargs
        Passed to the filter design function (``scipy.signal.iirfilter``
        or the FIR wrapper around ``scipy.signal.firwin``).

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Filtered data.

    Raises
    ------
    ValueError
        If `irtype` is not 'iir' or 'fir', or if `dim` is ambiguous.

    Warnings
    --------
    FilteringNaNWarning
        If the input contains NaN values.

    Examples
    --------
    Apply a 4th-order IIR lowpass at 10 Hz::

        >>> filtered = frequency_filter(data, 10, order=4, btype='low')

    FIR bandpass filter from 5 to 15 Hz::

        >>> filtered = frequency_filter(data, [5, 15], irtype='fir', btype='band')
    """
    if irtype not in _BA_FUNCS:
        raise ValueError(
            "Wrong argument for irtype: {}, must be one of {}".format(
                irtype, _BA_FUNCS.keys()
            )
        )
    if order is None:
        order = _ORDER_DEFAULTS[irtype]
    if apply_kwargs is None:
        apply_kwargs = {}
    dim = get_maybe_only_dim(darray, dim)

    # np.array always copies; np.asarray would alias a float64 ndarray
    # passed by the caller and the in-place scaling below would corrupt it.
    f_crit_norm = np.array(f_crit, dtype=np.float64)
    if not in_nyq:  # normalize by Nyquist frequency
        f_crit_norm *= 2 * get_sampling_step(darray, dim)

    data_values = np.asarray(
        darray.to_array() if isinstance(darray, xr.Dataset) else darray
    )
    if np.any(np.isnan(data_values)):
        # only warn since simple forward-filter or FIR is valid
        warnings.warn(
            "data contains NaNs, filter will propagate them",
            FilteringNaNWarning,
            stacklevel=2,
        )

    # Pick the coefficient representation and the matching apply function;
    # second-order sections are preferred for IIR when available.
    if sosfiltfilt and irtype == "iir":
        sos = scipy.signal.iirfilter(order, f_crit_norm, output="sos", **kwargs)
        coefficients = (sos,)
        apply_func = sosfiltfilt if filtfilt else scipy.signal.sosfilt
    else:
        b, a = _BA_FUNCS[irtype](order, f_crit_norm, **kwargs)
        coefficients = (b, a)
        apply_func = scipy.signal.filtfilt if filtfilt else scipy.signal.lfilter

    # Coefficient arrays carry no core dimensions; the data is filtered
    # along `dim`.
    return xr.apply_ufunc(
        apply_func,
        *coefficients,
        darray,
        input_core_dims=[[] for _ in coefficients] + [[dim]],
        output_core_dims=[[dim]],
        kwargs=apply_kwargs,
    )

379 

380 

381def _update_ftype_kwargs(kwargs, iirvalue, firvalue): 

382 if kwargs.get("irtype", "iir") == "iir": 

383 kwargs.setdefault("btype", iirvalue) 

384 else: # fir 

385 kwargs.setdefault("pass_zero", firvalue) 

386 return kwargs 

387 

388 

def lowpass(
    darray: xr.DataArray | xr.Dataset, f_cutoff: float, *args: Any, **kwargs: Any
) -> xr.DataArray | xr.Dataset:
    """
    Lowpass-filter the data.

    Convenience wrapper around `frequency_filter` that defaults ``btype``
    to ``"lowpass"`` (IIR) or ``pass_zero`` to True (FIR).

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    f_cutoff : float
        Cutoff frequency in Hz.
    *args
        Forwarded positionally to `frequency_filter`:
        (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
    **kwargs
        Forwarded to `frequency_filter`. The filter family is detected from
        the ``irtype`` keyword only, so pass ``irtype`` by keyword when
        requesting an FIR filter.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Lowpass-filtered data.

    Examples
    --------
    Remove components above 50 Hz::

        >>> filtered = lowpass(data, 50)
    """
    design_kwargs = _update_ftype_kwargs(kwargs, "lowpass", True)
    return frequency_filter(darray, f_cutoff, *args, **design_kwargs)

419 

420 

def highpass(
    darray: xr.DataArray | xr.Dataset, f_cutoff: float, *args: Any, **kwargs: Any
) -> xr.DataArray | xr.Dataset:
    """
    Highpass-filter the data.

    Convenience wrapper around `frequency_filter` that defaults ``btype``
    to ``"highpass"`` (IIR) or ``pass_zero`` to False (FIR).

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    f_cutoff : float
        Cutoff frequency in Hz.
    *args
        Forwarded positionally to `frequency_filter`:
        (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
    **kwargs
        Forwarded to `frequency_filter`. The filter family is detected from
        the ``irtype`` keyword only, so pass ``irtype`` by keyword when
        requesting an FIR filter.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Highpass-filtered data.

    Examples
    --------
    Remove components below 1 Hz::

        >>> filtered = highpass(data, 1.0)
    """
    design_kwargs = _update_ftype_kwargs(kwargs, "highpass", False)
    return frequency_filter(darray, f_cutoff, *args, **design_kwargs)

451 

452 

def bandpass(
    darray: xr.DataArray | xr.Dataset,
    f_low: float,
    f_high: float,
    *args: Any,
    **kwargs: Any,
) -> xr.DataArray | xr.Dataset:
    """
    Bandpass-filter the data.

    Convenience wrapper around `frequency_filter` that defaults ``btype``
    to ``"bandpass"`` (IIR) or ``pass_zero`` to False (FIR).

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    f_low : float
        Lower cutoff frequency in Hz.
    f_high : float
        Upper cutoff frequency in Hz.
    *args
        Forwarded positionally to `frequency_filter`:
        (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
    **kwargs
        Forwarded to `frequency_filter`. The filter family is detected from
        the ``irtype`` keyword only, so pass ``irtype`` by keyword when
        requesting an FIR filter.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Bandpass-filtered data.

    Examples
    --------
    Keep components between 10 and 50 Hz::

        >>> filtered = bandpass(data, 10, 50)
    """
    band_edges = [f_low, f_high]
    design_kwargs = _update_ftype_kwargs(kwargs, "bandpass", False)
    return frequency_filter(darray, band_edges, *args, **design_kwargs)

489 

490 

def bandstop(
    darray: xr.DataArray | xr.Dataset,
    f_low: float,
    f_high: float,
    *args: Any,
    **kwargs: Any,
) -> xr.DataArray | xr.Dataset:
    """
    Bandstop-filter (notch) the data.

    Convenience wrapper around `frequency_filter` that defaults ``btype``
    to ``"bandstop"`` (IIR) or ``pass_zero`` to True (FIR).

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    f_low : float
        Lower cutoff frequency in Hz.
    f_high : float
        Upper cutoff frequency in Hz.
    *args
        Forwarded positionally to `frequency_filter`:
        (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
    **kwargs
        Forwarded to `frequency_filter`. The filter family is detected from
        the ``irtype`` keyword only, so pass ``irtype`` by keyword when
        requesting an FIR filter.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Data with the band between `f_low` and `f_high` removed.

    Examples
    --------
    Remove 50-60 Hz powerline noise::

        >>> filtered = bandstop(data, 50, 60)
    """
    band_edges = [f_low, f_high]
    design_kwargs = _update_ftype_kwargs(kwargs, "bandstop", True)
    return frequency_filter(darray, band_edges, *args, **design_kwargs)

527 

528 

529# def notch( 

530# darray, 

531# notch_freq, 

532# notch_radius=0.5, 

533# frequency_radius=0.9, 

534# ripple=0.1, 

535# db_stop_limit=5.0, 

536# ): 

537# ford, wn = signal.cheb1ord(wp, ws, 1, dbstop) 

538# b, a = signal.cheby1(1, 0.5, wn, btype="bandstop") 

539# bx = signal.filtfilt(b, a, bx) 

540 

541 

def decimate(
    darray: xr.Dataset | xr.DataArray,
    target_sample_rate: float,
    n_order: int = 8,
    dim: str | None = None,
) -> xr.DataArray | xr.Dataset:
    """
    Decimate data with a Chebyshev anti-alias filter followed by downsampling.

    A Chebyshev type I lowpass of order `n_order` (0.05 dB ripple, cutoff
    at ``0.8 / q`` of Nyquist) is applied forward-backward with
    ``sosfiltfilt``; then every ``q``-th sample along `dim` is kept, where
    ``q`` is the rounded ratio of the current to the target sample rate.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to decimate.
    target_sample_rate : float
        Target sample rate in samples per second (or per space unit).
        Must be positive and must not round to a decimation factor below 1.
    n_order : int, default 8
        Order of the Chebyshev type I filter.
    dim : str | None, default None
        Dimension to decimate along. Auto-detected for 1-D objects.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Decimated data.

    Raises
    ------
    ValueError
        If `dim` is None and the object is not 1-D, if
        `target_sample_rate` is not positive, or if the resulting
        decimation factor is smaller than 1.
    ImportError
        If ``sosfiltfilt`` could not be imported from ``scipy.signal``.

    Warnings
    --------
    UserWarning
        If the decimation factor is larger than 13.

    Notes
    -----
    If sample_rate / target_sample_rate > 13, call decimate multiple times
    to avoid aliasing artifacts.

    Examples
    --------
    Decimate from 100 Hz to 10 Hz::

        >>> decimated = decimate(data, target_sample_rate=10.0)
    """
    dim = get_maybe_only_dim(darray, dim)

    if target_sample_rate <= 0:
        raise ValueError(
            f"target_sample_rate must be positive, got {target_sample_rate}"
        )

    dt = get_sampling_step(darray, dim)
    q = int(np.rint(1 / (dt * target_sample_rate)))

    # A factor of 0 would otherwise surface as a ZeroDivisionError in the
    # cutoff computation below.
    if q < 1:
        raise ValueError(
            f"Decimation factor must be at least 1, got {q}; "
            f"target_sample_rate={target_sample_rate} is too high for a "
            f"sampling step of {dt}."
        )

    if q > 13:
        warnings.warn(
            f"Decimation factor is larger than 13 ({q}), the resulting "
            "decimated array maybe incorrect. Suggest calling decimate "
            "multiple times.",
            stacklevel=2,
        )

    if sosfiltfilt is None:
        raise ImportError("sosfiltfilt not available in scipy.signal")

    sos = scipy.signal.cheby1(n_order, 0.05, 0.8 / q, output="sos")

    ret = xr.apply_ufunc(
        sosfiltfilt,
        sos,
        darray,
        input_core_dims=[[], [dim]],
        output_core_dims=[[dim]],
        kwargs={},
    )

    # Keep every q-th sample along the decimated dimension.
    return ret.isel(**{dim: slice(None, None, q)})

618 

619 

def resample_poly(
    darray: xr.DataArray | xr.Dataset,
    new_sample_rate: float,
    dim: str | None = None,
    pad_type: str = "mean",
) -> xr.DataArray | xr.Dataset:
    """
    Resample using polyphase filtering.

    The ratio of new to old sample rate is approximated by a rational
    number (up/down) and passed to ``scipy.signal.resample_poly``. The
    coordinate along `dim` is rebuilt afterwards because its length changes.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to resample. It is cast to float before resampling.
    new_sample_rate : float
        Target sample rate.
    dim : str | None, default None
        Dimension to resample along. Auto-detected for 1-D objects.
    pad_type : str, default 'mean'
        Passed to ``scipy.signal.resample_poly`` as ``padtype``.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Resampled data with a new coordinate along `dim`.

    Raises
    ------
    ValueError
        If `dim` is None and the object is not 1-D.

    Notes
    -----
    - When the old rate is not an integer multiple of the new rate, a
      message is logged through ``logger.warning`` (not a Python warning)
      and the coordinate is rebuilt: with ``pandas.date_range`` for a
      dimension named ``"time"``, with ``numpy.linspace`` otherwise.
    - The end time is only computed for the ``"time"`` dimension. Adding a
      ``numpy.timedelta64`` to a float coordinate raises a ``TypeError``.
    - The dimension and sampling step are resolved once, so an uneven
      sampling warning is emitted only once per call.

    Examples
    --------
    Resample to 50 Hz::

        >>> resampled = resample_poly(data, new_sample_rate=50.0)
    """
    dim = get_maybe_only_dim(darray, dim)
    dt = get_sampling_step(darray, dim)
    old_sample_rate = 1.0 / dt

    fraction = Fraction(new_sample_rate / old_sample_rate).limit_denominator()

    # The length along `dim` changes, so it must be excluded from
    # alignment in apply_ufunc; the coordinate is rebuilt below.
    ret = xr.apply_ufunc(
        scipy.signal.resample_poly,
        darray.astype(float),
        fraction.numerator,
        fraction.denominator,
        input_core_dims=[[dim], [], []],
        output_core_dims=[[dim]],
        exclude_dims={dim},
        kwargs={"padtype": pad_type},
    )

    old_coord = darray[dim].values
    new_step = 1 / (dt * new_sample_rate)
    if new_step % 1 == 0:
        q = int(np.rint(new_step))
        # directly downsample without AAF on dimension
        # this only works if q is an integer, otherwise
        # the index gets messed up from fractional spacing
        new_dim = old_coord[slice(None, None, q)]
    else:
        logger.warning(
            "New sample rate is not an even number of original sample rate. "
            f"The ratio is {new_step}. Use the new dimensions with caution."
        )
        if dim == "time":
            # need to reset the end time
            end_time = old_coord[0] + np.timedelta64(
                int(np.rint(((ret[dim].size - 1) / new_sample_rate) * 1e9)), "ns"
            )
            new_dim = pd.date_range(
                old_coord[0],
                end_time,
                periods=ret[dim].size,
            )
        else:
            end_index = (
                int(np.rint((ret[dim].size - (darray[dim].size / new_step)))) - 1
            )
            new_dim = np.linspace(old_coord[0], old_coord[end_index], ret[dim].size)

    # check to make sure the dimension size is the same as the new array
    n_samples_data = len(ret[dim])
    n_samples_axis = len(new_dim)
    if n_samples_data != n_samples_axis:
        logger.warning(
            f"conflicting axes sizes {n_samples_data} data and {n_samples_axis}"
            " axes after resampling"
        )
        logger.info(f"trimming {dim} axis from {n_samples_axis} to {n_samples_data}")
        new_dim = new_dim[:n_samples_data]

    ret[dim] = new_dim

    return ret

735 

736 

def savgol_filter(
    darray: xr.DataArray | xr.Dataset,
    window_length: int,
    polyorder: int,
    deriv: int = 0,
    delta: float | None = None,
    dim: str | None = None,
    mode: str = "interp",
    cval: float = 0.0,
) -> xr.DataArray | xr.Dataset:
    """
    Apply a Savitzky-Golay filter along one dimension.

    Delegates to ``scipy.signal.savgol_filter`` through
    ``xarray.apply_ufunc``.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    window_length : int
        Length of the filter window. An even value is increased by one.
        When `delta` is None the value is first divided by the detected
        sampling step and rounded.
    polyorder : int
        Order of the fitted polynomial.
    deriv : int, default 0
        Order of the derivative to compute (0 = no differentiation).
    delta : float | None, default None
        Sample spacing. When None, taken from `get_sampling_step`.
    dim : str | None, default None
        Dimension along which to filter. Auto-detected for 1-D objects.
    mode : str, default 'interp'
        Extension mode passed to scipy.
    cval : float, default 0.0
        Fill value passed to scipy.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Filtered data.

    Raises
    ------
    ValueError
        If `dim` is None for multi-dimensional objects.

    Examples
    --------
    Smooth with 11-point window, 2nd-order polynomial::

        >>> smoothed = savgol_filter(data, window_length=11, polyorder=2)

    Compute first derivative::

        >>> deriv1 = savgol_filter(data, 11, 2, deriv=1, delta=0.01)
    """
    dim = get_maybe_only_dim(darray, dim)

    if delta is None:
        delta = get_sampling_step(darray, dim)
        # NOTE(review): the rescaling is applied only on this auto-detected
        # path; the nesting was reconstructed from a listing whose
        # indentation was lost -- confirm against the original file.
        window_length = int(np.rint(window_length / delta))

    # scipy requires an odd window length
    window_length += 1 if window_length % 2 == 0 else 0

    scipy_kwargs = {
        "window_length": window_length,
        "polyorder": polyorder,
        "deriv": deriv,
        "delta": delta,
        "mode": mode,
        "cval": cval,
    }
    return xr.apply_ufunc(
        scipy.signal.savgol_filter,
        darray,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        kwargs=scipy_kwargs,
    )

814 

815 

def detrend(
    darray: xr.DataArray | xr.Dataset,
    dim: str | None = None,
    trend_type: str = "linear",
) -> xr.DataArray | xr.Dataset:
    """
    Remove a trend along one dimension.

    Delegates to ``scipy.signal.detrend`` through ``xarray.apply_ufunc``.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to detrend.
    dim : str | None, default None
        Dimension along which to detrend. Auto-detected for 1-D objects.
    trend_type : {'linear', 'constant'}, default 'linear'
        Passed to scipy as ``type``.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Detrended data.

    Raises
    ------
    ValueError
        If `dim` is None and the object is not 1-D.

    Examples
    --------
    Remove linear trend::

        >>> detrended = detrend(data, trend_type='linear')

    Remove mean (DC component)::

        >>> demeaned = detrend(data, trend_type='constant')
    """
    target_dim = get_maybe_only_dim(darray, dim)
    core_dims = [[target_dim]]

    return xr.apply_ufunc(
        scipy.signal.detrend,
        darray,
        input_core_dims=core_dims,
        output_core_dims=core_dims,
        kwargs={"type": trend_type},
    )

864 

865 

@xr.register_dataarray_accessor("sps_filters")
@xr.register_dataset_accessor("sps_filters")
class FilterAccessor:
    """
    xarray accessor exposing the module's filtering functions.

    Registered under ``.sps_filters`` for both DataArray and Dataset
    objects; every method delegates to the module-level function of the
    same purpose, passing the wrapped object as the first argument.

    Attributes
    ----------
    darray : xarray.DataArray | xarray.Dataset
        The wrapped xarray object.

    Examples
    --------
    Apply filters via accessor::

        >>> data.sps_filters.low(10)  # lowpass at 10 Hz
        >>> data.sps_filters.bandpass(5, 15)  # bandpass 5-15 Hz
    """

    def __init__(self, darray: xr.DataArray | xr.Dataset) -> None:
        self.darray: xr.DataArray | xr.Dataset = darray

    @property
    def dt(self) -> float:
        """Sampling step from `get_sampling_step` with no explicit dimension."""
        return get_sampling_step(self.darray)

    @property
    def fs(self) -> float:
        """Sampling frequency, the reciprocal of `dt`."""
        return 1.0 / self.dt

    @property
    def dx(self) -> np.ndarray:
        """Sampling steps of every dimension, in dimension order."""
        steps = [get_sampling_step(self.darray, str(name)) for name in self.darray.dims]
        return np.array(steps)

    # NOTE: the arguments are coded explicitly for tab-completion to work,
    # using a decorator wrapper with *args would not expose them
    def low(
        self, f_cutoff: float, *args: Any, **kwargs: Any
    ) -> xr.DataArray | xr.Dataset:
        """
        Lowpass-filter the wrapped object.

        Parameters
        ----------
        f_cutoff : float
            Cutoff frequency in Hz.
        *args
            Forwarded to `lowpass`:
            (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
        **kwargs
            Forwarded to `lowpass`.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Lowpass-filtered data.
        """
        return lowpass(self.darray, f_cutoff, *args, **kwargs)

    def high(
        self, f_cutoff: float, *args: Any, **kwargs: Any
    ) -> xr.DataArray | xr.Dataset:
        """
        Highpass-filter the wrapped object.

        Parameters
        ----------
        f_cutoff : float
            Cutoff frequency in Hz.
        *args
            Forwarded to `highpass`:
            (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
        **kwargs
            Forwarded to `highpass`.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Highpass-filtered data.
        """
        return highpass(self.darray, f_cutoff, *args, **kwargs)

    def bandpass(
        self, f_low: float, f_high: float, *args: Any, **kwargs: Any
    ) -> xr.DataArray | xr.Dataset:
        """
        Bandpass-filter the wrapped object.

        Parameters
        ----------
        f_low : float
            Lower cutoff frequency in Hz.
        f_high : float
            Upper cutoff frequency in Hz.
        *args
            Forwarded to `bandpass`:
            (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
        **kwargs
            Forwarded to `bandpass`.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Bandpass-filtered data.
        """
        return bandpass(self.darray, f_low, f_high, *args, **kwargs)

    def bandstop(
        self, f_low: float, f_high: float, *args: Any, **kwargs: Any
    ) -> xr.DataArray | xr.Dataset:
        """
        Bandstop-filter the wrapped object.

        Parameters
        ----------
        f_low : float
            Lower cutoff frequency in Hz.
        f_high : float
            Upper cutoff frequency in Hz.
        *args
            Forwarded to `bandstop`:
            (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
        **kwargs
            Forwarded to `bandstop`.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Bandstop-filtered data.
        """
        return bandstop(self.darray, f_low, f_high, *args, **kwargs)

    def freq(
        self,
        f_crit: float | list[float] | tuple[float, float],
        order: int | None = None,
        irtype: str = "iir",
        filtfilt: bool = True,
        apply_kwargs: dict[str, Any] | None = None,
        in_nyq: bool = False,
        dim: str | None = None,
        **kwargs: Any,
    ) -> xr.DataArray | xr.Dataset:
        """
        Apply the general frequency filter to the wrapped object.

        Also available by calling the accessor itself.

        Parameters
        ----------
        f_crit : float | list[float] | tuple[float, float]
            Critical frequency or frequencies in Hz.
        order : int | None, default None
            Filter order.
        irtype : {'iir', 'fir'}, default 'iir'
            Impulse response type.
        filtfilt : bool, default True
            Apply forward-backward filtering.
        apply_kwargs : dict | None, default None
            Extra keyword arguments for the applying function.
        in_nyq : bool, default False
            If True, `f_crit` is normalized by the Nyquist frequency.
        dim : str | None, default None
            Dimension along which to filter.
        **kwargs
            Forwarded to the filter design function.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Filtered data.
        """
        return frequency_filter(
            self.darray,
            f_crit,
            order=order,
            irtype=irtype,
            filtfilt=filtfilt,
            apply_kwargs=apply_kwargs,
            in_nyq=in_nyq,
            dim=dim,
            **kwargs,
        )

    __call__ = freq

    def savgol(
        self,
        window_length: int,
        polyorder: int,
        deriv: int = 0,
        delta: float | None = None,
        dim: str | None = None,
        mode: str = "interp",
        cval: float = 0.0,
    ) -> xr.DataArray | xr.Dataset:
        """
        Apply a Savitzky-Golay filter to the wrapped object.

        Parameters
        ----------
        window_length : int
            Filter window length.
        polyorder : int
            Polynomial order.
        deriv : int, default 0
            Derivative order.
        delta : float | None, default None
            Sample spacing.
        dim : str | None, default None
            Dimension to filter along.
        mode : str, default 'interp'
            Extension mode.
        cval : float, default 0.0
            Constant fill value.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Filtered data.
        """
        return savgol_filter(
            self.darray,
            window_length,
            polyorder,
            deriv=deriv,
            delta=delta,
            dim=dim,
            mode=mode,
            cval=cval,
        )

    def decimate(
        self, target_sample_rate: float, n_order: int = 8, dim: str | None = None
    ) -> xr.DataArray | xr.Dataset:
        """
        Decimate the wrapped object.

        Parameters
        ----------
        target_sample_rate : float
            Target sample rate.
        n_order : int, default 8
            Chebyshev filter order.
        dim : str | None, default None
            Dimension to decimate along.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Decimated data.
        """
        return decimate(self.darray, target_sample_rate, n_order=n_order, dim=dim)

    def detrend(
        self, trend_type: str = "linear", dim: str | None = None
    ) -> xr.DataArray | xr.Dataset:
        """
        Remove a trend from the wrapped object.

        Parameters
        ----------
        trend_type : {'linear', 'constant'}, default 'linear'
            Type of trend to remove.
        dim : str | None, default None
            Dimension to detrend along.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Detrended data.
        """
        return detrend(self.darray, dim=dim, trend_type=trend_type)

    def resample_poly(
        self, target_sample_rate: float, pad_type: str = "mean", dim: str | None = None
    ) -> xr.DataArray | xr.Dataset:
        """
        Resample the wrapped object using polyphase filtering.

        Parameters
        ----------
        target_sample_rate : float
            Target sample rate.
        pad_type : str, default 'mean'
            Padding type for resampling.
        dim : str | None, default None
            Dimension to resample along.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Resampled data.
        """
        return resample_poly(
            self.darray, target_sample_rate, dim=dim, pad_type=pad_type
        )