Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\filters\pole_zero_filter.py: 78%
64 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from typing import Annotated
6import numpy as np
7from pydantic import Field, field_validator, ValidationInfo
10try:
11 import obspy
12except ImportError:
13 obspy = None
14import scipy.signal as signal
16from mt_metadata.base.helpers import object_to_array, requires
17from mt_metadata.timeseries.filters import FilterBase, get_base_obspy_mapping
20# =====================================================
21class PoleZeroFilter(FilterBase):
22 _filter_type: str = "zpk"
23 type: Annotated[
24 str,
25 Field(
26 default="zpk",
27 description="Type of filter. Must be 'zpk'",
28 alias=None,
29 json_schema_extra={
30 "units": None,
31 "required": True,
32 "examples": "zpk",
33 },
34 ),
35 ]
36 poles: Annotated[
37 np.ndarray | list[complex] | complex,
38 Field(
39 default_factory=lambda: np.empty(0, dtype=complex),
40 description="The complex-valued poles associated with the filter response.",
41 alias=None,
42 json_schema_extra={
43 "units": None,
44 "required": True,
45 "examples": '"[-1/4., -0.1+j*0.3, -0.1-j*0.3]"',
46 },
47 ),
48 ]
50 zeros: Annotated[
51 np.ndarray | list[complex] | complex,
52 Field(
53 default_factory=lambda: np.empty(0, dtype=complex),
54 description="The complex-valued zeros associated with the filter response.",
55 alias=None,
56 json_schema_extra={
57 "units": None,
58 "required": True,
59 "examples": '"[0.0, ]"',
60 },
61 ),
62 ]
64 normalization_factor: Annotated[
65 float,
66 Field(
67 default=1.0,
68 description="The scale factor to apply to the monic response.",
69 alias=None,
70 json_schema_extra={
71 "units": None,
72 "required": True,
73 "examples": '"[-1000.1]"',
74 },
75 ),
76 ]
78 @field_validator("poles", "zeros", mode="before")
79 @classmethod
80 def validate_input_arrays(cls, value, info: ValidationInfo) -> np.ndarray:
81 """
82 Validate that the input is a list, tuple, or np.ndarray and convert to np.ndarray.
83 """
84 return object_to_array(value, dtype=complex)
86 def make_obspy_mapping(self):
87 mapping = get_base_obspy_mapping()
88 mapping["_zeros"] = "zeros"
89 mapping["_poles"] = "poles"
90 mapping["normalization_factor"] = "normalization_factor"
91 return mapping
93 @property
94 def n_poles(self):
95 """
96 :return: number of poles
97 :rtype: integer
99 """
100 return len(self.poles)
102 @property
103 def n_zeros(self):
104 """
106 :return: number of zeros
107 :rtype: integer
109 """
110 return len(self.zeros)
112 def zero_pole_gain_representation(self):
113 """
115 :return: scipy.signal.ZPG object
116 :rtype: :class:`scipy.signal.ZerosPolesGain`
118 """
119 zpg = signal.ZerosPolesGain(self.zeros, self.poles, self.normalization_factor)
120 return zpg
122 @property
123 def total_gain(self):
124 """
126 :return: total gain of the filter
127 :rtype: float
129 """
130 return self.gain * self.normalization_factor
132 @requires(obspy=obspy)
133 def to_obspy(
134 self,
135 stage_number=1,
136 pz_type="LAPLACE (RADIANS/SECOND)",
137 normalization_frequency=1,
138 sample_rate=1,
139 ):
140 """
141 Convert the filter to an obspy filter
143 :param stage_number: sequential stage number, defaults to 1
144 :type stage_number: integer, optional
145 :param pz_type: Pole Zero type, defaults to "LAPLACE (RADIANS/SECOND)"
146 :type pz_type: string, optional
147 :param normalization_frequency: Normalization frequency, defaults to 1
148 :type normalization_frequency: float, optional
149 :param sample_rate: sample rate, defaults to 1
150 :type sample_rate: float, optional
151 :return: Obspy stage filter
152 :rtype: :class:`obspy.core.inventory.PolesZerosResponseStage`
154 """
155 if self.zeros is None:
156 self.zeros = []
157 if self.poles is None:
158 self.poles = []
160 rs = obspy.core.inventory.PolesZerosResponseStage(
161 stage_number,
162 self.gain,
163 normalization_frequency,
164 self.units_in_object.symbol,
165 self.units_out_object.symbol,
166 pz_type,
167 normalization_frequency,
168 self.zeros,
169 self.poles,
170 name=self.name,
171 normalization_factor=self.normalization_factor,
172 description=self.get_filter_description(),
173 input_units_description=self.units_in_object.name,
174 output_units_description=self.units_out_object.name,
175 )
177 return rs
179 def complex_response(self, frequencies, **kwargs):
180 """
181 Computes complex response for given frequency range
182 :param frequencies: array of frequencies to estimate the response
183 :type frequencies: np.ndarray
185 :return: complex response
186 :rtype: np.ndarray
188 """
189 angular_frequencies = 2 * np.pi * np.array(frequencies)
190 w, h = signal.freqs_zpk(
191 self.zeros, self.poles, self.total_gain, worN=angular_frequencies
192 )
194 return h
196 def normalization_frequency(
197 self,
198 frequencies: np.ndarray = np.logspace(-4, 4, 32),
199 estimate: str = "mean",
200 window_len: int = 5,
201 tol: float = 1e-4,
202 ) -> float:
203 """
204 Try to estimate the normalization frequency in the pass band
205 by finding the flattest spot in the amplitude.
207 The flattest spot is determined by calculating a sliding window
208 with length `window_len` and estimating normalized std.
210 ..note:: This only works for simple filters with
211 on flat pass band.
213 :param window_len: length of sliding window in points
214 :type window_len: integer
216 :param tol: the ratio of the mean/std should be around 1
217 tol is the range around 1 to find the flat part of the curve.
218 :type tol: float
220 :return: estimated normalization frequency Hz
221 :rtype: float
223 """
224 pass_band = self.pass_band(frequencies, window_len, tol)
225 if pass_band is None:
226 return np.NAN
227 if pass_band.size == 0:
228 return np.NAN
230 if estimate == "mean":
231 return pass_band.mean()
233 elif estimate == "median":
234 return np.median(pass_band)
236 elif estimate == "min":
237 return pass_band.min()
239 elif estimate == "max":
240 return pass_band.max()