Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\filters\pole_zero_filter.py: 78%

64 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from typing import Annotated 

5 

6import numpy as np 

7from pydantic import Field, field_validator, ValidationInfo 

8 

9 

10try: 

11 import obspy 

12except ImportError: 

13 obspy = None 

14import scipy.signal as signal 

15 

16from mt_metadata.base.helpers import object_to_array, requires 

17from mt_metadata.timeseries.filters import FilterBase, get_base_obspy_mapping 

18 

19 

20# ===================================================== 

class PoleZeroFilter(FilterBase):
    """
    Filter described by complex poles, complex zeros and a scalar
    normalization factor (a "zpk" filter).

    The complex frequency response is evaluated with
    :func:`scipy.signal.freqs_zpk` (see ``complex_response``), and the
    filter can be exported to an obspy ``PolesZerosResponseStage``
    (see ``to_obspy``).
    """

    # Filter-type tag for this class.
    # NOTE(review): presumably consumed by FilterBase / filter factories;
    # confirm against FilterBase.
    _filter_type: str = "zpk"
    # Public filter type string; defaults to "zpk".
    type: Annotated[
        str,
        Field(
            default="zpk",
            description="Type of filter. Must be 'zpk'",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "zpk",
            },
        ),
    ]
    # Complex-valued poles. Raw input is routed through
    # ``validate_input_arrays`` (object_to_array with dtype=complex);
    # the default is an empty complex array.
    poles: Annotated[
        np.ndarray | list[complex] | complex,
        Field(
            default_factory=lambda: np.empty(0, dtype=complex),
            description="The complex-valued poles associated with the filter response.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": '"[-1/4., -0.1+j*0.3, -0.1-j*0.3]"',
            },
        ),
    ]

    # Complex-valued zeros, handled exactly like ``poles``.
    zeros: Annotated[
        np.ndarray | list[complex] | complex,
        Field(
            default_factory=lambda: np.empty(0, dtype=complex),
            description="The complex-valued zeros associated with the filter response.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": '"[0.0, ]"',
            },
        ),
    ]

    # Scale factor applied to the monic (pole/zero-only) response.
    # ``total_gain`` multiplies it by ``gain``.
    # NOTE(review): the schema example below is written like a list,
    # although the field is a float -- confirm intent.
    normalization_factor: Annotated[
        float,
        Field(
            default=1.0,
            description="The scale factor to apply to the monic response.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": '"[-1000.1]"',
            },
        ),
    ]

77 

78 @field_validator("poles", "zeros", mode="before") 

79 @classmethod 

80 def validate_input_arrays(cls, value, info: ValidationInfo) -> np.ndarray: 

81 """ 

82 Validate that the input is a list, tuple, or np.ndarray and convert to np.ndarray. 

83 """ 

84 return object_to_array(value, dtype=complex) 

85 

86 def make_obspy_mapping(self): 

87 mapping = get_base_obspy_mapping() 

88 mapping["_zeros"] = "zeros" 

89 mapping["_poles"] = "poles" 

90 mapping["normalization_factor"] = "normalization_factor" 

91 return mapping 

92 

93 @property 

94 def n_poles(self): 

95 """ 

96 :return: number of poles 

97 :rtype: integer 

98 

99 """ 

100 return len(self.poles) 

101 

102 @property 

103 def n_zeros(self): 

104 """ 

105 

106 :return: number of zeros 

107 :rtype: integer 

108 

109 """ 

110 return len(self.zeros) 

111 

112 def zero_pole_gain_representation(self): 

113 """ 

114 

115 :return: scipy.signal.ZPG object 

116 :rtype: :class:`scipy.signal.ZerosPolesGain` 

117 

118 """ 

119 zpg = signal.ZerosPolesGain(self.zeros, self.poles, self.normalization_factor) 

120 return zpg 

121 

122 @property 

123 def total_gain(self): 

124 """ 

125 

126 :return: total gain of the filter 

127 :rtype: float 

128 

129 """ 

130 return self.gain * self.normalization_factor 

131 

132 @requires(obspy=obspy) 

133 def to_obspy( 

134 self, 

135 stage_number=1, 

136 pz_type="LAPLACE (RADIANS/SECOND)", 

137 normalization_frequency=1, 

138 sample_rate=1, 

139 ): 

140 """ 

141 Convert the filter to an obspy filter 

142 

143 :param stage_number: sequential stage number, defaults to 1 

144 :type stage_number: integer, optional 

145 :param pz_type: Pole Zero type, defaults to "LAPLACE (RADIANS/SECOND)" 

146 :type pz_type: string, optional 

147 :param normalization_frequency: Normalization frequency, defaults to 1 

148 :type normalization_frequency: float, optional 

149 :param sample_rate: sample rate, defaults to 1 

150 :type sample_rate: float, optional 

151 :return: Obspy stage filter 

152 :rtype: :class:`obspy.core.inventory.PolesZerosResponseStage` 

153 

154 """ 

155 if self.zeros is None: 

156 self.zeros = [] 

157 if self.poles is None: 

158 self.poles = [] 

159 

160 rs = obspy.core.inventory.PolesZerosResponseStage( 

161 stage_number, 

162 self.gain, 

163 normalization_frequency, 

164 self.units_in_object.symbol, 

165 self.units_out_object.symbol, 

166 pz_type, 

167 normalization_frequency, 

168 self.zeros, 

169 self.poles, 

170 name=self.name, 

171 normalization_factor=self.normalization_factor, 

172 description=self.get_filter_description(), 

173 input_units_description=self.units_in_object.name, 

174 output_units_description=self.units_out_object.name, 

175 ) 

176 

177 return rs 

178 

179 def complex_response(self, frequencies, **kwargs): 

180 """ 

181 Computes complex response for given frequency range 

182 :param frequencies: array of frequencies to estimate the response 

183 :type frequencies: np.ndarray 

184 

185 :return: complex response 

186 :rtype: np.ndarray 

187 

188 """ 

189 angular_frequencies = 2 * np.pi * np.array(frequencies) 

190 w, h = signal.freqs_zpk( 

191 self.zeros, self.poles, self.total_gain, worN=angular_frequencies 

192 ) 

193 

194 return h 

195 

196 def normalization_frequency( 

197 self, 

198 frequencies: np.ndarray = np.logspace(-4, 4, 32), 

199 estimate: str = "mean", 

200 window_len: int = 5, 

201 tol: float = 1e-4, 

202 ) -> float: 

203 """ 

204 Try to estimate the normalization frequency in the pass band 

205 by finding the flattest spot in the amplitude. 

206 

207 The flattest spot is determined by calculating a sliding window 

208 with length `window_len` and estimating normalized std. 

209 

210 ..note:: This only works for simple filters with 

211 on flat pass band. 

212 

213 :param window_len: length of sliding window in points 

214 :type window_len: integer 

215 

216 :param tol: the ratio of the mean/std should be around 1 

217 tol is the range around 1 to find the flat part of the curve. 

218 :type tol: float 

219 

220 :return: estimated normalization frequency Hz 

221 :rtype: float 

222 

223 """ 

224 pass_band = self.pass_band(frequencies, window_len, tol) 

225 if pass_band is None: 

226 return np.NAN 

227 if pass_band.size == 0: 

228 return np.NAN 

229 

230 if estimate == "mean": 

231 return pass_band.mean() 

232 

233 elif estimate == "median": 

234 return np.median(pass_band) 

235 

236 elif estimate == "min": 

237 return pass_band.min() 

238 

239 elif estimate == "max": 

240 return pass_band.max()