Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\filters\fir_filter.py: 75%
91 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from typing import Annotated
6import matplotlib.pyplot as plt
7import numpy as np
8from pydantic import computed_field, Field, field_validator, PrivateAttr, ValidationInfo
10from mt_metadata.base.helpers import requires
11from mt_metadata.timeseries.filters import FilterBase, get_base_obspy_mapping
14try:
15 from obspy.core.inventory.response import FIRResponseStage
16except ImportError:
17 FIRResponseStage = None
18import scipy.signal as signal
20from mt_metadata.common import SymmetryEnum
23# =====================================================
class FIRFilter(FilterBase):
    # Finite impulse response (FIR) filter stage.
    #
    # The filter is described by its tap ``coefficients`` plus a ``symmetry``
    # flag. For "EVEN"/"ODD" symmetry only the first half of the taps is
    # stored (for "ODD" the last stored tap is the centre tap); the full tap
    # set is rebuilt by ``symmetry_corrected_coefficients``.

    _filter_type: str = PrivateAttr("fir")
    type: Annotated[
        str,
        Field(
            default="fir",
            description="Type of filter. Must be 'fir'",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "fir",
            },
        ),
    ]
    coefficients: Annotated[
        np.ndarray | list[float],
        Field(
            default_factory=lambda: np.empty(0),
            description="The FIR coefficients associated with the filter stage response.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "items": {"type": "number"},
                "examples": '"[0.25, 0.5, 0.25]"',
            },
        ),
    ]

    decimation_factor: Annotated[
        float,
        Field(
            default=1.0,
            description="Downsample factor.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": "16",
            },
        ),
    ]

    decimation_input_sample_rate: Annotated[
        float,
        Field(
            default=1.0,
            description="Sample rate of FIR taps.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": "2000",
            },
        ),
    ]

    # Sample rate after decimation: input tap rate divided by the decimation
    # factor. (A decimation_factor of 0 would raise ZeroDivisionError.)
    @computed_field
    @property
    def output_sampling_rate(self) -> float:
        return self.decimation_input_sample_rate / self.decimation_factor

    gain_frequency: Annotated[
        float,
        Field(
            default=0.0,
            description="Frequency of the reference gain, usually in passband.",
            alias=None,
            json_schema_extra={
                "units": "hertz",
                "required": True,
                "examples": "0.0",
            },
        ),
    ]

    symmetry: Annotated[
        SymmetryEnum,
        Field(
            default="NONE",
            description="Symmetry of FIR coefficients",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "NONE",
            },
        ),
    ]

    @field_validator("coefficients")
    @classmethod
    def validate_coefficients(
        cls, value: np.ndarray | list[float] | tuple | str, info: ValidationInfo
    ) -> np.ndarray:
        """
        Coerce the coefficients into a float numpy array.

        Accepts a list, tuple or numpy array of numbers, or a comma-separated
        string such as ``"0.25, 0.5, 0.25"``. Surrounding square brackets
        (``"[0.25, 0.5, 0.25]"``, the format used in the field examples),
        whitespace and empty items (e.g. a trailing comma) are tolerated; an
        empty string yields an empty array.

        NOTE(review): this is an "after" validator, so pydantic first checks
        the value against ``np.ndarray | list[float]``; depending on the model
        config a raw string may be rejected before reaching this method —
        confirm whether the string branch is reachable.

        :param value: The value to validate.
        :param info: Validation information (unused).
        :return: The coefficients as a float numpy array.
        :raises ValueError: If the value is not a list, tuple, array or
            string, or if an entry cannot be converted to float.
        """
        if isinstance(value, (list, tuple, np.ndarray)):
            return np.array(value, dtype=float)
        if isinstance(value, str):
            text = value.strip().strip("[]")
            items = [item for item in text.split(",") if item.strip()]
            if not items:
                return np.empty(0)
            return np.array([float(item) for item in items], dtype=float)
        raise ValueError("Coefficients must be a list, tuple, or string.")

    def make_obspy_mapping(self):
        """
        Build the obspy-attribute -> mt_metadata-attribute name mapping.

        Starts from the shared base mapping and adds the FIR-specific entries.

        :return: Mapping of obspy attribute names to attribute names on this
            filter.
        :rtype: dict
        """
        mapping = get_base_obspy_mapping()
        mapping["_symmetry"] = "symmetry"
        mapping["_coefficients"] = "coefficients"
        mapping["decimation_factor"] = "decimation_factor"
        mapping["decimation_input_sample_rate"] = "decimation_input_sample_rate"
        mapping["stage_gain_frequency"] = "gain_frequency"
        return mapping

    def _symmetry_value(self) -> str:
        """
        Return ``symmetry`` as a plain upper-case string.

        Works whether pydantic stored an enum member (uses its ``.value``) or
        a bare string, so comparisons and obspy export do not depend on how
        ``SymmetryEnum`` is implemented.
        """
        return str(getattr(self.symmetry, "value", self.symmetry)).upper()

    @property
    def symmetry_corrected_coefficients(self):
        """
        Full tap set, reconstructed from the stored half when symmetric.

        - "EVEN": stored ``[a, b]`` -> ``[a, b, b, a]``.
        - "ODD": stored ``[a, b, c]`` (``c`` is the centre tap) ->
          ``[a, b, c, b, a]``; the centre tap is not duplicated.
        - anything else: the coefficients unchanged.

        :return: The full set of FIR taps as a float numpy array.
        """
        coefficients = np.asarray(self.coefficients, dtype=float)
        symmetry = self._symmetry_value()
        if symmetry == "EVEN":
            return np.hstack((coefficients, np.flipud(coefficients)))
        if symmetry == "ODD":
            # Mirror everything except the final (centre) tap.
            return np.hstack((coefficients, np.flipud(coefficients[:-1])))
        return coefficients

    @property
    def coefficient_gain(self):
        """
        Gain at ``gain_frequency`` implied by the coefficients alone.

        Sometimes this differs from the gain in the StationXML, and a
        corrective scalar must then be applied (see ``corrective_scalar``).

        When ``gain_frequency`` is 0 this is the sum of the full tap set (the
        DC gain). Otherwise it is ``|H(f)|``, evaluated with
        ``scipy.signal.freqz`` at ``gain_frequency`` for a tap rate of
        ``decimation_input_sample_rate``. Both branches return a scalar.
        """
        coefficients = self.symmetry_corrected_coefficients
        if self.gain_frequency == 0.0:
            return coefficients.sum()
        # freqz only needs worN and fs in the same units; both are in rad/s
        # here. worN is passed as a 1-element list so it is always treated as
        # a frequency, never as a number of frequency points.
        _, response = signal.freqz(
            coefficients,
            worN=[2 * np.pi * self.gain_frequency],
            fs=2 * np.pi * self.decimation_input_sample_rate,
        )
        return np.abs(response[0])

    @property
    def n_coefficients(self):
        """Number of stored coefficients (before any symmetry expansion)."""
        return len(self.coefficients)

    @property
    def corrective_scalar(self):
        """
        Ratio of the coefficient-only gain to the filter's reported gain.

        ``complex_response`` divides by this, so that the response magnitude
        at ``gain_frequency`` matches the reported gain rather than the raw
        coefficient gain. Returns 1.0 when the two gains are already equal.
        """
        coefficient_gain = self.coefficient_gain
        # Compare against and divide by the same reference gain so the
        # shortcut and the ratio are consistent.
        reference_gain = self.total_gain
        if coefficient_gain == reference_gain:
            return 1.0
        return coefficient_gain / reference_gain

    def plot_fir_response(self):
        """
        Plot the amplitude (dB) and unwrapped phase of the full FIR tap set.

        The response is evaluated by ``scipy.signal.freqz`` on its default
        frequency grid, i.e. in normalized units (rad/sample), not Hz. The
        figure is shown with ``plt.show()`` before it is returned.

        :return: The matplotlib figure.
        """
        frequencies, response = signal.freqz(self.symmetry_corrected_coefficients)
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        ax1.set_title("Digital filter frequency response")
        ax1.plot(frequencies, 20 * np.log10(np.abs(response)), "b")
        ax1.set_ylabel("Amplitude [dB]", color="b")
        ax1.set_xlabel("Frequency [rad/sample]")

        # Phase shares the frequency axis but gets its own y-axis on the right.
        ax2 = ax1.twinx()
        ax2.plot(frequencies, np.unwrap(np.angle(response)), "g")
        ax2.set_ylabel("Angle (radians)", color="g")
        ax2.grid()
        ax2.axis("tight")
        plt.show()

        return fig

    @requires(obspy=FIRResponseStage)
    def to_obspy(
        self,
        stage_number=1,
        normalization_frequency=1,
        sample_rate=1,
    ):
        """
        Create an obspy ``FIRResponseStage`` from this filter.

        :param stage_number: Stage sequence number within the response.
        :param normalization_frequency: Passed to obspy as the stage gain
            frequency.
        :param sample_rate: Currently unused.
        :return: The obspy response stage.
        :rtype: obspy.core.inventory.response.FIRResponseStage
        """
        # FIRResponseStage positional arguments: stage_sequence_number,
        # stage_gain, stage_gain_frequency, input_units, output_units.
        return FIRResponseStage(
            stage_number,
            self.gain,
            normalization_frequency,
            self.units_in_object.symbol,
            self.units_out_object.symbol,
            coefficients=np.asarray(self.coefficients, dtype=float).tolist(),
            # obspy's constructor stores symmetry without validating it, so
            # hand it the plain "NONE"/"EVEN"/"ODD" string, not an enum.
            symmetry=self._symmetry_value(),
            name=self.name,
            description=self.get_filter_description(),
            input_units_description=self.units_in_object.name,
            output_units_description=self.units_out_object.name,
            decimation_input_sample_rate=self.decimation_input_sample_rate,
            decimation_factor=self.decimation_factor,
        )

    def unscaled_complex_response(self, frequencies):
        """
        Complex response of the coefficients alone, without gain correction.

        Some FIRs need a scale factor to make their gains match those reported
        in the StationXML: the pure coefficients can give a pass-band gain that
        differs from the cited gain (e.g. filter fs2d5 has a pass-band gain of
        0.5 from its coefficients alone, while the XML cites almost 1). This
        method returns the raw coefficient response; ``complex_response``
        applies the correction. Having a separate uncorrected path avoids
        recursion through the gain-correction properties.

        :param frequencies: Frequencies in Hz (scalar, list or numpy array).
        :return: Complex response at each frequency, as a numpy array.
        """
        angular_frequencies = 2 * np.pi * np.asarray(frequencies, dtype=float)
        _, response = signal.freqz(
            self.symmetry_corrected_coefficients,
            worN=angular_frequencies,
            fs=2 * np.pi * self.decimation_input_sample_rate,
        )
        return response

    def complex_response(self, frequencies, **kwargs):
        """
        Gain-corrected complex frequency response.

        Parameters
        ----------
        frequencies: numpy array (or list/scalar) of frequencies, expected in Hz
        **kwargs: accepted for interface compatibility; unused here

        Returns
        -------
        h : numpy array of complex-valued frequency response at the input
            frequencies, divided by ``corrective_scalar`` so the magnitude at
            ``gain_frequency`` matches the reported gain.
        """
        return self.unscaled_complex_response(frequencies) / self.corrective_scalar