Source code for scitex_dsp.filt

#!/usr/bin/env python3
# Time-stamp: "2024-11-04 02:05:47 (ywatanabe)"
# File: ./scitex_repo/src/scitex/dsp/filt.py

import numpy as np

from scitex_decorators import signal_fn

# No top-level imports from nn module to avoid circular dependency
# Filters will be imported inside functions when needed


[docs] @signal_fn def gauss(x, sigma, t=None): from scitex_nn._Filters import GaussianFilter return GaussianFilter(sigma)(x, t=t)
[docs] @signal_fn def bandpass(x, fs, bands, t=None): import torch from scitex_nn._Filters import BandPassFilter # Convert bands to tensor if it's not already if not isinstance(bands, torch.Tensor): bands = torch.tensor(bands, dtype=torch.float32) return BandPassFilter(bands, fs, x.shape[-1])(x, t=t)
[docs] @signal_fn def bandstop(x, fs, bands, t=None): import torch from scitex_nn._Filters import BandStopFilter # Convert bands to tensor if it's not already if not isinstance(bands, torch.Tensor): bands = torch.tensor(bands, dtype=torch.float32) return BandStopFilter(bands, fs, x.shape[-1])(x, t=t)
[docs] @signal_fn def lowpass(x, fs, cutoffs_hz, t=None): from scitex_nn._Filters import LowPassFilter return LowPassFilter(cutoffs_hz, fs, x.shape[-1])(x, t=t)
[docs] @signal_fn def highpass(x, fs, cutoffs_hz, t=None): from scitex_nn._Filters import HighPassFilter return HighPassFilter(cutoffs_hz, fs, x.shape[-1])(x, t=t)
def _custom_print(x): print(type(x), x.shape) if __name__ == "__main__": import sys import scitex # noqa: E402 — script-only import matplotlib.pyplot as plt # Start CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) # Parametes T_SEC = 1 SRC_FS = 1024 FREQS_HZ = list(np.linspace(0, 500, 10, endpoint=False).astype(int)) SIG_TYPE = "periodic" BANDS = np.vstack([[80, 310]]) SIGMA = 3 # Demo Signal xx, tt, fs = scitex.dsp.demo_sig( t_sec=T_SEC, fs=SRC_FS, freqs_hz=FREQS_HZ, sig_type=SIG_TYPE, ) # Filtering x_bp, t_bp = scitex.dsp.filt.bandpass(xx, fs, BANDS, t=tt) x_bs, t_bs = scitex.dsp.filt.bandstop(xx, fs, BANDS, t=tt) x_lp, t_lp = scitex.dsp.filt.lowpass(xx, fs, BANDS[:, 0], t=tt) x_hp, t_hp = scitex.dsp.filt.highpass(xx, fs, BANDS[:, 1], t=tt) x_g, t_g = scitex.dsp.filt.gauss(xx, sigma=SIGMA, t=tt) filted = { f"Original (Sum of {FREQS_HZ}-Hz signals)": (xx, tt, fs), f"Bandpass-filtered ({BANDS[0][0]} - {BANDS[0][1]} Hz)": ( x_bp, t_bp, fs, ), f"Bandstop-filtered ({BANDS[0][0]} - {BANDS[0][1]} Hz)": ( x_bs, t_bs, fs, ), f"Lowpass-filtered ({BANDS[0][0]} Hz)": (x_lp, t_lp, fs), f"Highpass-filtered ({BANDS[0][1]} Hz)": (x_hp, t_hp, fs), f"Gaussian-filtered (sigma = {SIGMA} SD [point])": (x_g, t_g, fs), } # Plots traces fig, axes = plt.subplots(nrows=len(filted), ncols=1, sharex=True, sharey=True) i_batch = 0 i_ch = 0 i_filt = 0 for ax, (k, v) in zip(axes, filted.items()): _xx, _tt, _fs = v if _xx.ndim == 3: _xx = _xx[i_batch, i_ch] elif _xx.ndim == 4: _xx = _xx[i_batch, i_ch, i_filt] ax.plot(_tt, _xx, label=k) ax.legend(loc="upper left") fig.suptitle("Filtered") fig.supxlabel("Time [s]") fig.supylabel("Amplitude") scitex.io.save(fig, "traces.png") # Calculates and Plots PSD fig, axes = plt.subplots(nrows=len(filted), ncols=1, sharex=True, sharey=True) i_batch = 0 i_ch = 0 i_filt = 0 for ax, (k, v) in zip(axes, filted.items()): _xx, _tt, _fs = v _psd, ff = scitex.dsp.psd(_xx, _fs) if _psd.ndim == 3: _psd = _psd[i_batch, i_ch] elif _psd.ndim == 4: _psd = _psd[i_batch, i_ch, i_filt] ax.plot(ff, _psd, label=k) ax.legend(loc="upper left") for bb in np.hstack(BANDS): ax.axvline(x=bb, color=CC["grey"], linestyle="--") fig.suptitle("PSD (power spectrum density) of filtered signals") fig.supxlabel("Frequency [Hz]") fig.supylabel("log(Power [uV^2 / Hz]) [a.u.]") scitex.io.save(fig, "psd.png") # Close scitex.session.close(CONFIG) # EOF """ /home/ywatanabe/proj/scitex/src/scitex/dsp/filt.py """ # EOF