#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-03 07:24:43 (ywatanabe)"
# File: ./scitex_repo/src/scitex/dsp/utils/filter.py
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import firwin, freqz
from scitex_decorators import numpy_fn
from scitex_gen._to_even import to_even
class FilterParameterError(ValueError):
    """Raised when ``design_filter`` receives an invalid parameter combination.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    (e.g. from ``scipy.signal.firwin``) keep working.
    """


def _to_scalar(value, cast):
    """Convert ``value`` to a Python scalar using ``cast`` (e.g. ``int``/``float``).

    ``fs`` and the band edges may arrive as 0-d / single-element arrays or
    tensors (anything exposing ``.item()``), e.g. from decorated callers.
    NumPy 2+ rejects ``float(np.array([x]))``, so such values are funnelled
    through ``.item()`` first. ``None`` is passed through unchanged.
    """
    if value is None:
        return None
    if hasattr(value, "item"):
        try:
            return cast(value.item())
        except (ValueError, TypeError):
            pass
    return cast(value)


def _estimate_filter_type(low_hz=None, high_hz=None, is_bandstop=False):
    """Infer the filter response type from which band edges are provided.

    ``low_hz`` is the lower band edge and ``high_hz`` the upper one, so:
    only ``low_hz`` -> "highpass", only ``high_hz`` -> "lowpass",
    both -> "bandpass" (or "bandstop" when ``is_bandstop`` is True).

    Raises:
        FilterParameterError: for negative edges, ``low_hz >= high_hz``,
            or when neither edge is given.
    """
    if low_hz is not None and low_hz < 0:
        raise FilterParameterError("low_hz must be non-negative.")
    if high_hz is not None and high_hz < 0:
        raise FilterParameterError("high_hz must be non-negative.")
    if low_hz is not None and high_hz is not None and low_hz >= high_hz:
        raise FilterParameterError(
            "low_hz must be less than high_hz for valid configurations."
        )
    if low_hz is not None and high_hz is not None:
        return "bandstop" if is_bandstop else "bandpass"
    if low_hz is not None:
        return "highpass"
    if high_hz is not None:
        return "lowpass"
    raise FilterParameterError("At least one of low_hz or high_hz must be provided.")


def _cutoff_frequencies(filter_mode, low_hz, high_hz):
    """Return the ``cutoff`` argument for ``firwin``: one edge or ``[low, high]``."""
    if filter_mode == "highpass":
        return low_hz
    if filter_mode == "lowpass":
        return high_hz
    return [low_hz, high_hz]  # "bandpass" or "bandstop"


def _filter_order(fs, lowest_hz, sig_len, cycle):
    """Return the FIR order: ``cycle`` periods of ``lowest_hz``, capped for short signals."""
    # Number of samples spanned by `cycle` periods of the lowest band edge.
    order = cycle * int(fs // lowest_hz)
    # Keep the signal at least ~3 filter lengths long: only shrink the order
    # when the signal is too short for it (the previous condition was
    # inverted and inflated the order to ~sig_len/3 for every long signal).
    if sig_len < 3 * order:
        order = (sig_len - 1) // 3
    # NOTE(review): to_even presumably rounds to an even integer so that
    # numtaps = order + 1 is odd, which firwin requires for highpass/bandstop
    # designs - confirm against scitex_gen._to_even.
    return to_even(order)


@numpy_fn
def design_filter(sig_len, fs, low_hz=None, high_hz=None, cycle=3, is_bandstop=False):
    """
    Designs a Hamming-windowed FIR filter with ``scipy.signal.firwin``.

    ``low_hz`` and ``high_hz`` are the lower and upper band edges:

    - only ``low_hz``  -> highpass (passes frequencies above ``low_hz``)
    - only ``high_hz`` -> lowpass  (passes frequencies below ``high_hz``)
    - both             -> bandpass, or bandstop when ``is_bandstop`` is True

    Arguments:
    - sig_len (int): Length of the signal to be filtered. When
      ``sig_len < 3 * order`` the order is reduced to ``(sig_len - 1) // 3``.
    - fs (int): Sampling frequency in Hz. Scalar arrays/tensors are accepted;
      the value is converted with ``int`` (non-integer rates are truncated).
    - low_hz (float, optional): Lower band edge in Hz.
    - high_hz (float, optional): Upper band edge in Hz.
    - cycle (int, optional): Number of periods of the lowest band edge that the
      filter spans; the order is ``cycle * (fs // lowest_edge)`` before capping.
      Defaults to 3.
    - is_bandstop (bool, optional): With both edges given, design a bandstop
      instead of a bandpass filter; ignored otherwise. Defaults to False.

    Returns:
    - The FIR filter coefficients produced by ``firwin`` (``order + 1`` taps).

    Raises:
    - FilterParameterError (a ``ValueError``): if an edge is negative,
      ``low_hz >= high_hz``, neither edge is given, the lowest edge is 0 Hz,
      or ``firwin`` rejects the design (e.g. an edge at or above Nyquist).
    """
    fs = _to_scalar(fs, int)
    low_hz = _to_scalar(low_hz, float)
    high_hz = _to_scalar(high_hz, float)

    filter_mode = _estimate_filter_type(low_hz, high_hz, is_bandstop)
    cutoff = _cutoff_frequencies(filter_mode, low_hz, high_hz)

    # The lowest band edge has the longest period, so it determines the order.
    # (low_hz < high_hz is already validated when both are given.)
    lowest_hz = low_hz if low_hz is not None else high_hz
    if lowest_hz <= 0:
        # Would otherwise fail with ZeroDivisionError in `fs // lowest_hz`.
        raise FilterParameterError("Cutoff frequencies must be greater than 0 Hz.")

    order = _filter_order(fs, lowest_hz, sig_len, cycle)
    numtaps = order + 1

    try:
        h = firwin(
            numtaps=numtaps,
            cutoff=cutoff,
            # pass_zero=True means unit gain at DC: a lowpass for one edge,
            # a bandstop for two edges.
            pass_zero=(filter_mode in ("lowpass", "bandstop")),
            window="hamming",
            fs=fs,
            scale=True,
        )
    except ValueError as e:
        raise FilterParameterError(
            f"Cannot design {filter_mode} filter "
            f"(numtaps={numtaps}, cutoff={cutoff}, fs={fs}): {e}"
        ) from e
    return h
@numpy_fn
def plot_filter_responses(filter, fs, worN=8000, title=None):
    """
    Plots the impulse and frequency response of an FIR filter.

    Parameters:
    - filter (numpy.ndarray): The FIR filter coefficients (taps). The name
      shadows the builtin ``filter`` but is kept for keyword-argument callers.
    - fs (int): The sampling frequency in Hz.
    - worN (int, optional): Number of frequency points evaluated by
      ``scipy.signal.freqz``. Defaults to 8000.
    - title (str, optional): The figure title. Defaults to None.

    Returns:
    - The figure object returned by ``scitex.plt.subplots``: the impulse
      response on the left axis, the gain in dB on the right axis.

    Raises:
    - ImportError: If ``scitex`` (with its plotting extras) is not installed.
    """
    try:
        import scitex  # type: ignore
    except ImportError as exc:
        raise ImportError(
            "scitex.plt requires additional dependencies. "
            "Install with: pip install scitex[plt]"
        ) from exc

    ww, hh = freqz(filter, worN=worN, fs=fs)

    fig, axes = scitex.plt.subplots(ncols=2)
    fig.suptitle(title)

    # Impulse response: for an FIR filter this is just the tap sequence.
    ax = axes[0]
    ax.plot(filter)
    ax.set_title("Impulse Responses of FIR Filter")
    ax.set_xlabel("Tap Number")
    ax.set_ylabel("Amplitude")

    # Frequency response; the 1e-5 floor (-100 dB) avoids log10(0) = -inf
    # at exact stop-band nulls.
    ax = axes[1]
    ax.plot(ww, 20 * np.log10(np.abs(hh) + 1e-5))
    ax.set_title("Frequency Response of FIR Filter")
    ax.set_xlabel("Frequency [Hz]")
    ax.set_ylabel("Gain [dB]")

    return fig
if __name__ == "__main__":
    import scitex

    # Example usage: design one filter of each type for the demo signal.
    xx, _, fs = scitex.dsp.demo_sig()
    seq_len = xx.shape[-1]

    # design_filter treats low_hz / high_hz as the lower / upper band edges:
    # passing only high_hz yields a lowpass, passing only low_hz a highpass
    # (previously these two were labelled the other way round).
    filters = {
        "Lowpass Filter": design_filter(seq_len, fs, low_hz=None, high_hz=70),
        "Highpass Filter": design_filter(seq_len, fs, low_hz=30, high_hz=None),
        "Bandpass Filter": design_filter(seq_len, fs, low_hz=30, high_hz=70),
        "Bandstop Filter": design_filter(
            seq_len, fs, low_hz=30, high_hz=70, is_bandstop=True
        ),
    }

    # One impulse/frequency-response figure per filter.
    for label, taps in filters.items():
        plot_filter_responses(taps, fs, title=label)

    # Summary figure: impulse responses (left column), gain in dB (right column).
    fig, axes = plt.subplots(nrows=len(filters), ncols=2)
    for (label, taps), (ax_time, ax_freq) in zip(filters.items(), axes):
        ax_time.plot(taps, label=label)
        ax_time.set_xlabel("Tap Number")
        ax_time.set_ylabel("Amplitude")
        ax_time.legend(loc="best")

        ww, hh = freqz(taps, worN=8000, fs=fs)
        # Small floor avoids log10(0) = -inf at exact stop-band nulls.
        ax_freq.plot(ww, 20 * np.log10(np.abs(hh) + 1e-5), label=label)
        ax_freq.set_xlabel("Frequency [Hz]")
        ax_freq.set_ylabel("Gain [dB]")
        ax_freq.legend(loc="best")

    fig.tight_layout()
    plt.show()
# @torch_fn
# def bandpass(x, filt):
# assert x.ndim == 3
# xf = F.conv1d(
# x.reshape(-1, x.shape[-1]).unsqueeze(1),
# filt.unsqueeze(0).unsqueeze(0),
# padding="same",
# ).reshape(*x.shape)
# assert x.shape == xf.shape
# return xf
# def define_bandpass_filters(seq_len, fs, freq_bands, cycle=3):
# """
# Defines Finite Impulse Response (FIR) filters.
# b: The filter coefficients (or taps) of the FIR filters
# a: The denominator coefficients of the filter's transfer function. However, FIR filters have a transfer function with a denominator equal to 1 (since they are all-zero filters with no poles).
# """
# # Parameters
# n_freqs = len(freq_bands)
# nyq = fs / 2.0
# bs = []
# for ll, hh in freq_bands:
# wn = np.array([ll, hh]) / nyq
# order = define_fir_order(fs, seq_len, ll, cycle=cycle)
# bs.append(fir1(order, wn)[0])
# return bs
# def define_fir_order(fs, sizevec, flow, cycle=3):
# """
# Calculate filter order.
# """
# if cycle is None:
# filtorder = 3 * np.fix(fs / flow)
# else:
# filtorder = cycle * (fs // flow)
# if sizevec < 3 * filtorder:
# filtorder = (sizevec - 1) // 3
# return int(filtorder)
# def n_odd_fcn(f, o, w, l):
# """Odd case."""
# # Variables :
# b0 = 0
# m = np.array(range(int(l + 1)))
# k = m[1 : len(m)]
# b = np.zeros(k.shape)
# # Run Loop :
# for s in range(0, len(f), 2):
# m = (o[s + 1] - o[s]) / (f[s + 1] - f[s])
# b1 = o[s] - m * f[s]
# b0 = b0 + (
# b1 * (f[s + 1] - f[s])
# + m / 2 * (f[s + 1] * f[s + 1] - f[s] * f[s])
# ) * abs(np.square(w[round((s + 1) / 2)]))
# b = b + (
# m
# / (4 * np.pi * np.pi)
# * (
# np.cos(2 * np.pi * k * f[s + 1])
# - np.cos(2 * np.pi * k * f[s])
# )
# / (k * k)
# ) * abs(np.square(w[round((s + 1) / 2)]))
# b = b + (
# f[s + 1] * (m * f[s + 1] + b1) * np.sinc(2 * k * f[s + 1])
# - f[s] * (m * f[s] + b1) * np.sinc(2 * k * f[s])
# ) * abs(np.square(w[round((s + 1) / 2)]))
# b = np.insert(b, 0, b0)
# a = (np.square(w[0])) * 4 * b
# a[0] = a[0] / 2
# aud = np.flipud(a[1 : len(a)]) / 2
# a2 = np.insert(aud, len(aud), a[0])
# h = np.concatenate((a2, a[1:] / 2))
# return h
# def n_even_fcn(f, o, w, l):
# """Even case."""
# # Variables :
# k = np.array(range(0, int(l) + 1, 1)) + 0.5
# b = np.zeros(k.shape)
# # # Run Loop :
# for s in range(0, len(f), 2):
# m = (o[s + 1] - o[s]) / (f[s + 1] - f[s])
# b1 = o[s] - m * f[s]
# b = b + (
# m
# / (4 * np.pi * np.pi)
# * (
# np.cos(2 * np.pi * k * f[s + 1])
# - np.cos(2 * np.pi * k * f[s])
# )
# / (k * k)
# ) * abs(np.square(w[round((s + 1) / 2)]))
# b = b + (
# f[s + 1] * (m * f[s + 1] + b1) * np.sinc(2 * k * f[s + 1])
# - f[s] * (m * f[s] + b1) * np.sinc(2 * k * f[s])
# ) * abs(np.square(w[round((s + 1) / 2)]))
# a = (np.square(w[0])) * 4 * b
# h = 0.5 * np.concatenate((np.flipud(a), a))
# return h
# def firls(n, f, o):
# # Variables definition :
# w = np.ones(round(len(f) / 2))
# n += 1
# f /= 2
# lo = (n - 1) / 2
# nodd = bool(n % 2)
# if nodd: # Odd case
# h = n_odd_fcn(f, o, w, lo)
# else: # Even case
# h = n_even_fcn(f, o, w, lo)
# return h
# def fir1(n, wn):
# # Variables definition :
# nbands = len(wn) + 1
# ff = np.array((0, wn[0], wn[0], wn[1], wn[1], 1))
# f0 = np.mean(ff[2:4])
# lo = n + 1
# mags = np.array(range(nbands)).reshape(1, -1) % 2
# aa = np.ravel(np.tile(mags, (2, 1)), order="F")
# # Get filter coefficients :
# h = firls(lo - 1, ff, aa)
# # Apply a window to coefficients :
# wind = np.hamming(lo)
# b = h * wind
# c = np.exp(-1j * 2 * np.pi * (f0 / 2) * np.array(range(lo)))
# b /= abs(c @ b)
# return b, 1
# def apply_filters(x, filts):
# """
# x: (batch_size, n_chs, seq_len)
# filts: (n_filts, seq_len_filt)
# """
# assert x.ndims == 3
# assert filts.ndims == 2
# batch_size, n_chs, n_time = x.shape
# x = x.reshape(-1, x.shape[-1]).unsqueeze(1)
# filts = filts.unsqueeze(1)
# n_filts = len(filts)
# return F.conv1d(x, filts, padding="same").reshape(
# batch_size, n_chs, n_filts, n_time
# )
# if __name__ == "__main__":
# import torch
# import torch.nn.functional as F
# plt, CC = scitex.plt.configure_mpl(plt)
# # Demo Signal
# freqs_hz = [10, 30, 100]
# xx, tt, fs = scitex.dsp.demo_sig(freqs_hz=freqs_hz, sig_type="periodic")
# x = xx
# seq_len = x.shape[-1]
# freq_bands = np.array([[20, 70], [3.0, 4.0]])
# # Plots the figure
# fig, ax = scitex.plt.subplots()
# # ax.plot(b, label="bandpass filter")
# # Bandpass Filtering
# filters = define_bandpass_filters(seq_len, fs, freq_bands, cycle=3)
# i_filt = 0
# # xf = bandpass(xx, filters[i_filt])
# # Plots the signals
# fig, axes = scitex.plt.subplots(nrows=2, sharex=True, sharey=True)
# axes[0].plot(tt, xx[0, 0], label="orig")
# axes[1].plot(tt, xf[0, 0], label="orig")
# [ax.legend(loc="upper left") for ax in axes]
# # Plots PSDs
# psd_xx, ff_xx = scitex.dsp.psd(xx.numpy(), fs)
# psd_xf, ff_xf = scitex.dsp.psd(xf.numpy(), fs)
# fig, axes = scitex.plt.subplots(nrows=2, sharex=True, sharey=True)
# axes[0].plot(ff_xx, psd_xx[0, 0], label="orig")
# axes[1].plot(ff_xf, psd_xf[0, 0], label="filted")
# [ax.legend(loc="upper left") for ax in axes]
# plt.show()
# # Multiple Filters in a parallel computation
# x = torch.randn(33, 32, 30)
# filters = torch.randn(20, 5)
# y = apply_filters(x, filters)
# print(y.shape) # (33, 32, 20, 30)
# EOF