Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mth5\mth5\timeseries\scipy_filters.py: 86%
194 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-27 20:09 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-27 20:09 -0800
1# -*- coding: utf-8 -*-
2"""
3Scipy filter wrappers for xarray.
5This module provides xarray-compatible wrappers for scipy.signal filtering
6functions, enabling efficient filtering operations on labeled N-dimensional
7arrays with automatic dimension handling.
9Notes
10-----
11- Adapted from xr-scipy: https://github.com/fujiisoup/xr-scipy
12- Filters can be applied along any dimension with automatic sampling rate detection.
13- Supports both IIR and FIR filter types with forward/backward filtering.
15Examples
16--------
17Apply a bandpass filter to an xarray DataArray::
19 >>> import xarray as xr
20 >>> import numpy as np
21 >>> data = xr.DataArray(np.random.randn(1000), dims=['time'])
22 >>> data['time'] = np.arange(1000) / 100.0 # 100 Hz
23 >>> filtered = data.sps_filters.bandpass(10, 40) # 10-40 Hz
25"""
27from __future__ import annotations
29import warnings
30from fractions import Fraction
31from typing import Any
33import numpy as np
34import pandas as pd
35import scipy.signal
36import xarray as xr
37from loguru import logger
40# =============================================================================
41# Imports
42# =============================================================================
45try:
46 from scipy.signal import sosfiltfilt
47except ImportError:
48 sosfiltfilt = None
49# =============================================================================
52def _firwin_ba(*args, **kwargs):
53 if not kwargs.get("pass_zero"):
54 args = (args[0] + 1,) + args[1:] # numtaps must be odd
55 return scipy.signal.firwin(*args, **kwargs), np.array([1])
# Filter-design function for each supported impulse-response type; each one
# is called as func(order, normalized_cutoff, **kwargs) and yields (b, a).
_BA_FUNCS = dict(
    iir=scipy.signal.iirfilter,
    fir=_firwin_ba,
)

# Order used by ``frequency_filter`` when the caller does not pass one.
_ORDER_DEFAULTS = dict(
    iir=8,
    fir=29,
)
# Warnings
class UnevenSamplingWarning(Warning):
    """Issued when a coordinate's average step differs from its first step."""


class FilteringNaNWarning(Warning):
    """Issued when data handed to a frequency filter contains NaN values."""


class DecimationWarning(Warning):
    """Warning category for decimation-related problems."""


# Report these every time they occur instead of once per call site.
for _category in (UnevenSamplingWarning, FilteringNaNWarning, DecimationWarning):
    warnings.filterwarnings("always", category=_category)
del _category
def get_maybe_only_dim(darray: xr.DataArray | xr.Dataset, dim: str | None) -> str:
    """
    Resolve which dimension an operation should act on.

    An explicit `dim` is returned untouched. When `dim` is None, the object
    must have exactly one dimension, and that dimension's name is returned.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Object whose dimensions are inspected.
    dim : str | None
        Requested dimension, or None to use the single available one.

    Returns
    -------
    str
        The resolved dimension name.

    Raises
    ------
    ValueError
        If `dim` is None and `darray` does not have exactly one dimension.
    """
    if dim is not None:
        return dim
    if len(darray.dims) != 1:
        raise ValueError("Specify the dimension")
    if isinstance(darray, xr.DataArray):
        return str(darray.dims[0])
    if isinstance(darray, xr.Dataset):
        return str(next(iter(darray.sizes)))
    return None
def get_maybe_last_dim_axis(
    darray: xr.DataArray | xr.Dataset, dim: str | None = None
) -> tuple[str, int]:
    """
    Return a (dimension name, axis index) pair.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Object providing ``ndim``, ``dims`` and ``get_axis_num``.
    dim : str | None, default None
        Dimension to look up; None selects the last dimension.

    Returns
    -------
    tuple[str, int]
        The dimension name (as ``str``) and its axis number.
    """
    if dim is not None:
        axis_index = darray.get_axis_num(dim)
        return str(dim), axis_index
    axis_index = darray.ndim - 1
    return str(darray.dims[axis_index]), axis_index
def _difference_as_float(diff: Any) -> tuple[float, bool]:
    """Convert a coordinate difference to ``(magnitude, is_time)``.

    timedelta64 differences are returned in nanoseconds whatever resolution
    they arrive in, Python/pandas timedelta objects are converted to
    nanoseconds via ``total_seconds()``, and numeric differences are returned
    unchanged in coordinate units with ``is_time=False``.
    """
    value = np.asarray(getattr(diff, "values", diff))
    if value.dtype.kind == "m":
        # normalise the unit first so the integer count is always nanoseconds
        return float(value.astype("timedelta64[ns]").astype("int64")), True
    if value.dtype.kind == "O":
        item = value.item()
        if hasattr(item, "total_seconds"):
            return item.total_seconds() * 1e9, True
        return float(item), False
    # plain numeric difference: never reinterpret the bits (no .view here)
    return float(value), False


def get_sampling_step(
    darray: xr.DataArray | xr.Dataset, dim: str | None = None, rtol: float = 1e-3
) -> float:
    """
    Compute the sampling step along a dimension.

    Datetime/timedelta coordinates (of any numpy resolution, or object arrays
    of Python/pandas timestamps) yield a step in seconds; numeric coordinates
    yield a step in coordinate units. Issues a warning if sampling is not
    uniform (average vs first step mismatch).

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Input array with coordinates.
    dim : str | None, default None
        Dimension name. Auto-detected for 1-D arrays.
    rtol : float, default 1e-3
        Relative tolerance for detecting uneven sampling.

    Returns
    -------
    float
        Average sampling step (seconds for time coordinates).

    Raises
    ------
    ValueError
        If the coordinate has fewer than 2 samples.

    Warnings
    --------
    UnevenSamplingWarning
        If average sampling step differs from first step by more than rtol.

    Examples
    --------
    Get sampling step from a time series::

        >>> dt = get_sampling_step(data, dim='time')
        >>> sample_rate = 1.0 / dt
    """
    dim = get_maybe_only_dim(darray, dim)

    coord = darray.coords[dim]

    if len(coord) < 2:
        raise ValueError(
            f"Cannot compute sampling step for coordinate with less than 2 samples. "
            f"Got {len(coord)} samples."
        )

    span, is_time = _difference_as_float(coord[-1] - coord[0])
    first, _ = _difference_as_float(coord[1] - coord[0])
    # time-like differences are in nanoseconds; convert them to seconds
    t_scale = 1e9 if is_time else 1

    dt_avg = (span / (len(coord) - 1)) / t_scale
    dt_first = first / t_scale

    # abs() so decreasing coordinates are not always flagged as uneven
    if abs(dt_avg - dt_first) > rtol * min(abs(dt_first), abs(dt_avg)):
        # show warning at caller level to see which signal it is related to
        warnings.warn(
            f"Average sampling {dt_avg:.3g} != first sampling step {dt_first:.3g}",
            UnevenSamplingWarning,
            stacklevel=2,
        )
    return dt_avg
def frequency_filter(
    darray: xr.DataArray | xr.Dataset,
    f_crit: float | list[float] | tuple[float, float],
    order: int | None = None,
    irtype: str = "iir",
    filtfilt: bool = True,
    apply_kwargs: dict[str, Any] | None = None,
    in_nyq: bool = False,
    dim: str | None = None,
    **kwargs: Any,
) -> xr.DataArray | xr.Dataset:
    """
    Apply a frequency filter to an xarray.

    Supports IIR (infinite impulse response) and FIR (finite impulse response)
    filters with optional forward-backward filtering for zero-phase response.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to be filtered.
    f_crit : float | list[float] | tuple[float, float]
        Critical frequency or frequencies (in Hz). Scalar for
        lowpass/highpass, 2-element for bandpass/bandstop. The caller's
        object is never modified.
    order : int | None, default None
        Filter order. If None, uses default (8 for IIR, 29 for FIR). For FIR
        it is forwarded to ``_firwin_ba`` as the tap count.
    irtype : {'iir', 'fir'}, default 'iir'
        Impulse response type: 'iir' or 'fir'.
    filtfilt : bool, default True
        Apply filter forwards and backwards for zero-phase response.
    apply_kwargs : dict | None, default None
        Additional kwargs passed to the filter function
        (sosfiltfilt/sosfilt/filtfilt/lfilter).
    in_nyq : bool, default False
        If True, `f_crit` values are already normalized by Nyquist frequency.
    dim : str | None, default None
        Dimension along which to filter. Auto-detected for 1-D arrays.
    **kwargs
        Passed to filter design function (scipy.signal.iirfilter or firwin),
        e.g. ``btype`` for IIR or ``pass_zero`` for FIR.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Filtered data with same structure as input.

    Raises
    ------
    ValueError
        If `irtype` is not 'iir' or 'fir', or if `dim` is ambiguous.

    Warnings
    --------
    FilteringNaNWarning
        If input contains NaN values.

    Examples
    --------
    Apply a 4th-order IIR lowpass at 10 Hz::

        >>> filtered = frequency_filter(data, 10, order=4, btype='low')

    FIR bandpass filter from 5 to 15 Hz::

        >>> filtered = frequency_filter(data, [5, 15], irtype='fir', pass_zero=False)
    """
    if irtype not in _BA_FUNCS:
        raise ValueError(
            "Wrong argument for irtype: {}, must be one of {}".format(
                irtype, _BA_FUNCS.keys()
            )
        )
    if order is None:
        order = _ORDER_DEFAULTS[irtype]
    if apply_kwargs is None:
        apply_kwargs = {}
    dim = get_maybe_only_dim(darray, dim)

    # np.array (not np.asarray) copies, so a float64 ndarray supplied by the
    # caller is not rescaled in place by the normalization below.
    f_crit_norm = np.array(f_crit, dtype=np.float64)
    if not in_nyq:  # normalize by Nyquist frequency
        f_crit_norm *= 2 * get_sampling_step(darray, dim)

    values = darray.to_array() if isinstance(darray, xr.Dataset) else darray
    if np.any(np.isnan(np.asarray(values))):
        # only warn since simple forward-filter or FIR is valid
        warnings.warn(
            "data contains NaNs, filter will propagate them",
            FilteringNaNWarning,
            stacklevel=2,
        )

    if sosfiltfilt and irtype == "iir":
        sos = scipy.signal.iirfilter(order, f_crit_norm, output="sos", **kwargs)
        coefficients = (sos,)
        filter_func = sosfiltfilt if filtfilt else scipy.signal.sosfilt
    else:
        b, a = _BA_FUNCS[irtype](order, f_crit_norm, **kwargs)
        coefficients = (b, a)
        filter_func = scipy.signal.filtfilt if filtfilt else scipy.signal.lfilter

    # coefficient arguments have no core dims; the data's core dim is `dim`
    return xr.apply_ufunc(
        filter_func,
        *coefficients,
        darray,
        input_core_dims=[[] for _ in coefficients] + [[dim]],
        output_core_dims=[[dim]],
        kwargs=apply_kwargs,
    )
381def _update_ftype_kwargs(kwargs, iirvalue, firvalue):
382 if kwargs.get("irtype", "iir") == "iir":
383 kwargs.setdefault("btype", iirvalue)
384 else: # fir
385 kwargs.setdefault("pass_zero", firvalue)
386 return kwargs
def lowpass(
    darray: xr.DataArray | xr.Dataset, f_cutoff: float, *args: Any, **kwargs: Any
) -> xr.DataArray | xr.Dataset:
    """
    Low-pass filter the data, keeping content below `f_cutoff`.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    f_cutoff : float
        Cutoff frequency in Hz.
    *args
        Forwarded positionally to `frequency_filter`
        (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
    **kwargs
        Forwarded to `frequency_filter`; ``btype='lowpass'`` (IIR) or
        ``pass_zero=True`` (FIR) is added unless already given.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Lowpass-filtered data.

    Examples
    --------
    Remove components above 50 Hz::

        >>> filtered = lowpass(data, 50)
    """
    design_kwargs = _update_ftype_kwargs(kwargs, iirvalue="lowpass", firvalue=True)
    return frequency_filter(darray, f_cutoff, *args, **design_kwargs)
def highpass(
    darray: xr.DataArray | xr.Dataset, f_cutoff: float, *args: Any, **kwargs: Any
) -> xr.DataArray | xr.Dataset:
    """
    High-pass filter the data, keeping content above `f_cutoff`.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    f_cutoff : float
        Cutoff frequency in Hz.
    *args
        Forwarded positionally to `frequency_filter`
        (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
    **kwargs
        Forwarded to `frequency_filter`; ``btype='highpass'`` (IIR) or
        ``pass_zero=False`` (FIR) is added unless already given.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Highpass-filtered data.

    Examples
    --------
    Remove components below 1 Hz::

        >>> filtered = highpass(data, 1.0)
    """
    design_kwargs = _update_ftype_kwargs(kwargs, iirvalue="highpass", firvalue=False)
    return frequency_filter(darray, f_cutoff, *args, **design_kwargs)
def bandpass(
    darray: xr.DataArray | xr.Dataset,
    f_low: float,
    f_high: float,
    *args: Any,
    **kwargs: Any,
) -> xr.DataArray | xr.Dataset:
    """
    Band-pass filter the data, keeping content between `f_low` and `f_high`.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    f_low : float
        Lower cutoff frequency in Hz.
    f_high : float
        Upper cutoff frequency in Hz.
    *args
        Forwarded positionally to `frequency_filter`
        (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
    **kwargs
        Forwarded to `frequency_filter`; ``btype='bandpass'`` (IIR) or
        ``pass_zero=False`` (FIR) is added unless already given.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Bandpass-filtered data.

    Examples
    --------
    Keep components between 10 and 50 Hz::

        >>> filtered = bandpass(data, 10, 50)
    """
    band_edges = [f_low, f_high]
    design_kwargs = _update_ftype_kwargs(kwargs, iirvalue="bandpass", firvalue=False)
    return frequency_filter(darray, band_edges, *args, **design_kwargs)
def bandstop(
    darray: xr.DataArray | xr.Dataset,
    f_low: float,
    f_high: float,
    *args: Any,
    **kwargs: Any,
) -> xr.DataArray | xr.Dataset:
    """
    Band-stop (notch) filter the data, rejecting content between the edges.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    f_low : float
        Lower cutoff frequency in Hz.
    f_high : float
        Upper cutoff frequency in Hz.
    *args
        Forwarded positionally to `frequency_filter`
        (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
    **kwargs
        Forwarded to `frequency_filter`; ``btype='bandstop'`` (IIR) or
        ``pass_zero=True`` (FIR) is added unless already given.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Data with the band between `f_low` and `f_high` removed.

    Examples
    --------
    Remove 50-60 Hz powerline noise::

        >>> filtered = bandstop(data, 50, 60)
    """
    band_edges = [f_low, f_high]
    design_kwargs = _update_ftype_kwargs(kwargs, iirvalue="bandstop", firvalue=True)
    return frequency_filter(darray, band_edges, *args, **design_kwargs)
529# def notch(
530# darray,
531# notch_freq,
532# notch_radius=0.5,
533# frequency_radius=0.9,
534# ripple=0.1,
535# db_stop_limit=5.0,
536# ):
537# ford, wn = signal.cheb1ord(wp, ws, 1, dbstop)
538# b, a = signal.cheby1(1, 0.5, wn, btype="bandstop")
539# bx = signal.filtfilt(b, a, bx)
def decimate(
    darray: xr.Dataset | xr.DataArray,
    target_sample_rate: float,
    n_order: int = 8,
    dim: str | None = None,
) -> xr.DataArray | xr.Dataset:
    """
    Decimate data using Chebyshev filter and downsampling.

    Applies a Chebyshev type I anti-alias filter of order `n_order` (0.05 dB
    ripple, cutoff at 0.8 of the new Nyquist frequency) with zero-phase
    filtering (sosfiltfilt), then keeps every q-th sample, where q is the
    current-to-target rate ratio rounded to the nearest integer.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to decimate.
    target_sample_rate : float
        Target sample rate in samples per second (or per space unit).
        Must be positive and not exceed the current sample rate.
    n_order : int, default 8
        Order of the Chebyshev type I filter.
    dim : str | None, default None
        Dimension to decimate along. Auto-detected for 1-D arrays.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Decimated data with adjusted coordinates.

    Raises
    ------
    ValueError
        If `dim` is None and array is not 1-D, or if `target_sample_rate`
        is not positive or gives a decimation factor below 1.
    ImportError
        If ``scipy.signal.sosfiltfilt`` is unavailable.

    Warnings
    --------
    UserWarning
        If decimation factor > 13, suggest calling decimate multiple times.

    Notes
    -----
    If sample_rate / target_sample_rate > 13, call decimate multiple times
    to avoid aliasing artifacts.

    Examples
    --------
    Decimate from 100 Hz to 10 Hz::

        >>> decimated = decimate(data, target_sample_rate=10.0)
    """
    if sosfiltfilt is None:
        raise ImportError("sosfiltfilt not available in scipy.signal")
    if target_sample_rate <= 0:
        raise ValueError(
            f"target_sample_rate must be positive, got {target_sample_rate}"
        )

    dim = get_maybe_only_dim(darray, dim)

    dt = get_sampling_step(darray, dim)
    q = int(np.rint(1 / (dt * target_sample_rate)))
    if q < 1:
        # a target rate well above the current one rounds q to 0, which
        # would otherwise fail with ZeroDivisionError in the filter design
        raise ValueError(
            f"target_sample_rate ({target_sample_rate}) is higher than the "
            f"current sample rate ({1 / dt}); decimation factor would be {q}."
        )

    if q > 13:
        warnings.warn(
            f"Decimation factor is larger than 13 ({q}), the resulting "
            "decimated array maybe incorrect. Suggest calling decimate "
            "multiple times.",
            UserWarning,
            stacklevel=2,
        )
    sos = scipy.signal.cheby1(n_order, 0.05, 0.8 / q, output="sos")

    ret = xr.apply_ufunc(
        sosfiltfilt,
        sos,
        darray,
        input_core_dims=[[], [dim]],
        output_core_dims=[[dim]],
        kwargs={},
    )

    return ret.isel(**{dim: slice(None, None, q)})
def resample_poly(
    darray: xr.DataArray | xr.Dataset,
    new_sample_rate: float,
    dim: str | None = None,
    pad_type: str = "mean",
) -> xr.DataArray | xr.Dataset:
    """
    Resample using polyphase filtering.

    Computes rational resampling ratio (up/down) and applies
    scipy.signal.resample_poly, then rebuilds the coordinate along `dim`.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to resample.
    new_sample_rate : float
        Target sample rate.
    dim : str | None, default None
        Dimension to resample along. Auto-detected for 1-D arrays.
    pad_type : str, default 'mean'
        Padding type passed to scipy.signal.resample_poly (``padtype``).
        Options: 'constant', 'line', 'mean', 'median', etc.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Resampled data with updated coordinates.

    Raises
    ------
    ValueError
        If `dim` is None and array is not 1-D.

    Notes
    -----
    - Data is cast to float before resampling, so the result has float dtype.
    - When old_rate / new_rate is an integer q, the new coordinate is every
      q-th original coordinate value. Otherwise a warning is logged (via
      loguru) and the coordinate is regenerated: with ``pandas.date_range``
      for datetime coordinates, ``numpy.linspace`` for numeric ones.
    - If the rebuilt coordinate is longer than the data, it is trimmed.

    Examples
    --------
    Resample to 50 Hz::

        >>> resampled = resample_poly(data, new_sample_rate=50.0)
    """
    dim = get_maybe_only_dim(darray, dim)
    # Compute the step once; a second call would repeat any uneven-sampling
    # warning.
    dt = get_sampling_step(darray, dim)
    old_sample_rate = 1.0 / dt

    fraction = Fraction(new_sample_rate / old_sample_rate).limit_denominator()

    # need to resample the time coordinate because it will change and that
    # is illegal in apply_ufunc.
    ret = xr.apply_ufunc(
        scipy.signal.resample_poly,
        darray.astype(float),
        fraction.numerator,
        fraction.denominator,
        input_core_dims=[[dim], [], []],
        output_core_dims=[[dim]],
        exclude_dims=set([dim]),
        kwargs={"padtype": pad_type},
    )

    coord_values = darray[dim].values
    n_new = ret[dim].size
    # ratio of old to new sample rate
    new_step = 1 / (dt * new_sample_rate)
    if new_step % 1 == 0:
        q = int(np.rint(new_step))
        # directly downsample without AAF on dimension
        # this only works if q is an integer, otherwise to
        # the index gets messed up from fractional spacing
        new_dim = coord_values[slice(None, None, q)]
    else:
        logger.warning(
            "New sample rate is not an even number of original sample rate. "
            f"The ratio is {new_step}. Use the new dimensions with caution."
        )
        if np.issubdtype(coord_values.dtype, np.datetime64):
            # need to reset the end time; only valid for datetime coordinates
            # (adding a timedelta64 to a float coordinate raises)
            end_time = coord_values[0] + np.timedelta64(
                int(np.rint(((n_new - 1) / new_sample_rate) * 1e9)), "ns"
            )
            new_dim = pd.date_range(
                coord_values[0],
                end_time,
                periods=n_new,
            )
        else:
            end_index = int(np.rint((n_new - (darray[dim].size / new_step)))) - 1
            new_dim = np.linspace(coord_values[0], coord_values[end_index], n_new)

    # check to make sure the dimension size is the same as the new array
    n_samples_data = len(ret[dim])
    n_samples_axis = len(new_dim)
    if n_samples_data != n_samples_axis:
        logger.warning(
            f"conflicting axes sizes {n_samples_data} data and {n_samples_axis}"
            " axes after resampling"
        )
        logger.info(f"trimming {dim} axis from {n_samples_axis} to {n_samples_data}")
        new_dim = new_dim[:n_samples_data]

    ret[dim] = new_dim

    return ret
def savgol_filter(
    darray: xr.DataArray | xr.Dataset,
    window_length: float,
    polyorder: int,
    deriv: int = 0,
    delta: float | None = None,
    dim: str | None = None,
    mode: str = "interp",
    cval: float = 0.0,
) -> xr.DataArray | xr.Dataset:
    """
    Apply a Savitzky-Golay filter.

    Smooths data using least-squares polynomial fit over a sliding window.
    Can also compute derivatives.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to filter.
    window_length : float
        Length of the filter window in the same units as `delta` (seconds
        for datetime coordinates when `delta` is None). It is converted to
        a sample count as ``round(window_length / delta)``, and incremented
        by one if that count is even so the window is odd.
    polyorder : int
        Order of polynomial for fitting. Must be less than the resulting
        window length in samples.
    deriv : int, default 0
        Order of derivative to compute (0 = no differentiation).
    delta : float | None, default None
        Sample spacing. If None, it is taken from the coordinate of `dim`
        via `get_sampling_step`. It is used both to convert `window_length`
        to samples and by scipy for derivative scaling.
    dim : str | None, default None
        Dimension along which to filter. Auto-detected for 1-D arrays.
    mode : str, default 'interp'
        Extension mode: 'mirror', 'constant', 'nearest', 'wrap', or 'interp'.
    cval : float, default 0.0
        Fill value when mode='constant'.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Filtered data.

    Raises
    ------
    ValueError
        From scipy if `polyorder` is not less than the window length in
        samples, or if `dim` is None for multi-dimensional arrays.

    Examples
    --------
    Smooth 100 Hz data with an 11-sample (0.11 s) window, 2nd-order polynomial::

        >>> smoothed = savgol_filter(data, window_length=0.11, polyorder=2)

    Compute first derivative with an explicit spacing of 0.01 s::

        >>> deriv1 = savgol_filter(data, 0.11, 2, deriv=1, delta=0.01)
    """
    dim = get_maybe_only_dim(darray, dim)
    if delta is None:
        delta = get_sampling_step(darray, dim)
    # convert the window from coordinate units to a number of samples
    window_length = int(np.rint(window_length / delta))
    if window_length % 2 == 0:  # must be odd
        window_length += 1
    return xr.apply_ufunc(
        scipy.signal.savgol_filter,
        darray,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        kwargs=dict(
            window_length=window_length,
            polyorder=polyorder,
            deriv=deriv,
            delta=delta,
            mode=mode,
            cval=cval,
        ),
    )
def detrend(
    darray: xr.DataArray | xr.Dataset,
    dim: str | None = None,
    trend_type: str = "linear",
) -> xr.DataArray | xr.Dataset:
    """
    Subtract a fitted trend from the data along one dimension.

    Parameters
    ----------
    darray : xarray.DataArray | xarray.Dataset
        Data to detrend.
    dim : str | None, default None
        Dimension to work along; inferred for 1-D objects.
    trend_type : {'linear', 'constant'}, default 'linear'
        Passed to ``scipy.signal.detrend`` as ``type``: 'linear' removes a
        least-squares line, 'constant' removes the mean.

    Returns
    -------
    xarray.DataArray | xarray.Dataset
        Detrended data.

    Raises
    ------
    ValueError
        If `dim` is None and the object is not 1-D.

    Examples
    --------
    Remove linear trend::

        >>> detrended = detrend(data, trend_type='linear')

    Remove mean (DC component)::

        >>> demeaned = detrend(data, trend_type='constant')
    """
    core_dim = get_maybe_only_dim(darray, dim)
    detrend_options = {"type": trend_type}
    return xr.apply_ufunc(
        scipy.signal.detrend,
        darray,
        kwargs=detrend_options,
        input_core_dims=[[core_dim]],
        output_core_dims=[[core_dim]],
    )
@xr.register_dataarray_accessor("sps_filters")
@xr.register_dataset_accessor("sps_filters")
class FilterAccessor:
    """
    Accessor exposing common frequency and filtering methods.

    Registered as xarray accessor under `.sps_filters` for both DataArray
    and Dataset objects. Every method forwards to the module-level function
    of the same purpose with the wrapped object as first argument.

    Attributes
    ----------
    darray : xarray.DataArray | xarray.Dataset
        The wrapped xarray object.

    Examples
    --------
    Apply filters via accessor::

        >>> data.sps_filters.low(10)  # lowpass at 10 Hz
        >>> data.sps_filters.bandpass(5, 15)  # bandpass 5-15 Hz
    """

    def __init__(self, darray: xr.DataArray | xr.Dataset) -> None:
        self.darray: xr.DataArray | xr.Dataset = darray

    @property
    def dt(self) -> float:
        """Sampling step of the only dimension.

        Calls `get_sampling_step` without a dimension, so the wrapped object
        must be 1-D; a ValueError is raised otherwise.
        """
        return get_sampling_step(self.darray)

    @property
    def fs(self) -> float:
        """Sampling frequency in inverse units of self.dt (1-D objects only)."""
        return 1.0 / self.dt

    @property
    def dx(self) -> np.ndarray:
        """Sampling step of every dimension, in ``self.darray.dims`` order."""
        return np.array(
            [get_sampling_step(self.darray, str(dim)) for dim in self.darray.dims]
        )

    # NOTE: the arguments are coded explicitly for tab-completion to work,
    # using a decorator wrapper with *args would not expose them
    def low(
        self, f_cutoff: float, *args: Any, **kwargs: Any
    ) -> xr.DataArray | xr.Dataset:
        """
        Apply lowpass filter.

        Parameters
        ----------
        f_cutoff : float
            Cutoff frequency in Hz.
        *args
            Passed to `lowpass`: (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
        **kwargs
            Passed to filter design.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Lowpass-filtered data.
        """
        return lowpass(self.darray, f_cutoff, *args, **kwargs)

    def high(
        self, f_cutoff: float, *args: Any, **kwargs: Any
    ) -> xr.DataArray | xr.Dataset:
        """
        Apply highpass filter.

        Parameters
        ----------
        f_cutoff : float
            Cutoff frequency in Hz.
        *args
            Passed to `highpass`: (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
        **kwargs
            Passed to filter design.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Highpass-filtered data.
        """
        return highpass(self.darray, f_cutoff, *args, **kwargs)

    def bandpass(
        self, f_low: float, f_high: float, *args: Any, **kwargs: Any
    ) -> xr.DataArray | xr.Dataset:
        """
        Apply bandpass filter.

        Parameters
        ----------
        f_low : float
            Lower cutoff frequency in Hz.
        f_high : float
            Upper cutoff frequency in Hz.
        *args
            Passed to `bandpass`: (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
        **kwargs
            Passed to filter design.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Bandpass-filtered data.
        """
        return bandpass(self.darray, f_low, f_high, *args, **kwargs)

    def bandstop(
        self, f_low: float, f_high: float, *args: Any, **kwargs: Any
    ) -> xr.DataArray | xr.Dataset:
        """
        Apply bandstop filter.

        Parameters
        ----------
        f_low : float
            Lower cutoff frequency in Hz.
        f_high : float
            Upper cutoff frequency in Hz.
        *args
            Passed to `bandstop`: (order, irtype, filtfilt, apply_kwargs, in_nyq, dim).
        **kwargs
            Passed to filter design.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Bandstop-filtered data.
        """
        return bandstop(self.darray, f_low, f_high, *args, **kwargs)

    def freq(
        self,
        f_crit: float | list[float] | tuple[float, float],
        order: int | None = None,
        irtype: str = "iir",
        filtfilt: bool = True,
        apply_kwargs: dict[str, Any] | None = None,
        in_nyq: bool = False,
        dim: str | None = None,
        **kwargs: Any,
    ) -> xr.DataArray | xr.Dataset:
        """
        Apply general frequency filter (see `frequency_filter`).

        Parameters
        ----------
        f_crit : float | list[float] | tuple[float, float]
            Critical frequency or frequencies in Hz.
        order : int | None, default None
            Filter order.
        irtype : {'iir', 'fir'}, default 'iir'
            Impulse response type.
        filtfilt : bool, default True
            Apply forward-backward filtering.
        apply_kwargs : dict | None, default None
            Additional filter function kwargs.
        in_nyq : bool, default False
            If True, f_crit is already normalized by the Nyquist frequency.
        dim : str | None, default None
            Dimension along which to filter.
        **kwargs
            Passed to filter design.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Filtered data.
        """
        return frequency_filter(
            self.darray,
            f_crit,
            order,
            irtype,
            filtfilt,
            apply_kwargs,
            in_nyq,
            dim,
            **kwargs,
        )

    __call__ = freq

    def savgol(
        self,
        window_length: float,
        polyorder: int,
        deriv: int = 0,
        delta: float | None = None,
        dim: str | None = None,
        mode: str = "interp",
        cval: float = 0.0,
    ) -> xr.DataArray | xr.Dataset:
        """
        Apply Savitzky-Golay filter (see `savgol_filter`).

        Parameters
        ----------
        window_length : float
            Window length in the units of `delta` (coordinate units, seconds
            for datetime coordinates); converted to an odd sample count.
        polyorder : int
            Polynomial order (< window length in samples).
        deriv : int, default 0
            Derivative order.
        delta : float | None, default None
            Sample spacing; taken from the coordinate when None.
        dim : str | None, default None
            Dimension to filter along.
        mode : str, default 'interp'
            Extension mode.
        cval : float, default 0.0
            Constant fill value.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Filtered data.
        """
        return savgol_filter(
            self.darray,
            window_length,
            polyorder,
            deriv,
            delta,
            dim,
            mode,
            cval,
        )

    def decimate(
        self, target_sample_rate: float, n_order: int = 8, dim: str | None = None
    ) -> xr.DataArray | xr.Dataset:
        """
        Decimate signal (see module-level `decimate`).

        Parameters
        ----------
        target_sample_rate : float
            Target sample rate.
        n_order : int, default 8
            Chebyshev filter order.
        dim : str | None, default None
            Dimension to decimate along.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Decimated data.
        """
        return decimate(self.darray, target_sample_rate, n_order, dim)

    def detrend(
        self, trend_type: str = "linear", dim: str | None = None
    ) -> xr.DataArray | xr.Dataset:
        """
        Remove trend from data (see module-level `detrend`).

        Parameters
        ----------
        trend_type : {'linear', 'constant'}, default 'linear'
            Type of trend to remove.
        dim : str | None, default None
            Dimension to detrend along.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Detrended data.
        """
        # module-level detrend takes (darray, dim, trend_type)
        return detrend(self.darray, dim, trend_type)

    def resample_poly(
        self, target_sample_rate: float, pad_type: str = "mean", dim: str | None = None
    ) -> xr.DataArray | xr.Dataset:
        """
        Resample using polyphase filtering (see module-level `resample_poly`).

        Parameters
        ----------
        target_sample_rate : float
            Target sample rate (passed as ``new_sample_rate``).
        pad_type : str, default 'mean'
            Padding type for resampling.
        dim : str | None, default None
            Dimension to resample along.

        Returns
        -------
        xarray.DataArray | xarray.Dataset
            Resampled data.
        """
        return resample_poly(
            self.darray, target_sample_rate, dim=dim, pad_type=pad_type
        )
1164 )