Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\filters\fir_filter.py: 75%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from typing import Annotated 

5 

6import matplotlib.pyplot as plt 

7import numpy as np 

8from pydantic import computed_field, Field, field_validator, PrivateAttr, ValidationInfo 

9 

10from mt_metadata.base.helpers import requires 

11from mt_metadata.timeseries.filters import FilterBase, get_base_obspy_mapping 

12 

13 

14try: 

15 from obspy.core.inventory.response import FIRResponseStage 

16except ImportError: 

17 FIRResponseStage = None 

18import scipy.signal as signal 

19 

20from mt_metadata.common import SymmetryEnum 

21 

22 

23# ===================================================== 

24 

25 

# FIR (finite impulse response) filter stage.
#
# Holds the metadata of the stage as pydantic fields (coefficients, symmetry,
# decimation information, reference-gain frequency) and provides helpers that
# expand symmetric coefficient sets, evaluate the frequency response with
# ``scipy.signal.freqz``, plot it, and export the stage to obspy.
#
# NOTE: documented with comments rather than a class docstring on purpose;
# pydantic publishes a class docstring as the schema description, which would
# change the generated schema.
class FIRFilter(FilterBase):
    _filter_type: str = PrivateAttr("fir")
    type: Annotated[
        str,
        Field(
            default="fir",
            description="Type of filter.  Must be 'fir'",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "fir",
            },
        ),
    ]
    # Stored coefficients; when ``symmetry`` is EVEN or ODD only part of the
    # full tap set is stored (see ``symmetry_corrected_coefficients``).
    coefficients: Annotated[
        np.ndarray | list[float],
        Field(
            default_factory=lambda: np.empty(0),
            description="The FIR coefficients associated with the filter stage response.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "items": {"type": "number"},
                "examples": '"[0.25, 0.5, 0.25]"',
            },
        ),
    ]

    decimation_factor: Annotated[
        float,
        Field(
            default=1.0,
            description="Downsample factor.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": "16",
            },
        ),
    ]

    decimation_input_sample_rate: Annotated[
        float,
        Field(
            default=1.0,
            description="Sample rate of FIR taps.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": "2000",
            },
        ),
    ]

    @computed_field
    @property
    def output_sampling_rate(self) -> float:
        """
        Sample rate after decimation.

        :return: ``decimation_input_sample_rate / decimation_factor``.
        """
        return self.decimation_input_sample_rate / self.decimation_factor

    gain_frequency: Annotated[
        float,
        Field(
            default=0.0,
            description="Frequency of the reference gain, usually in passband.",
            alias=None,
            json_schema_extra={
                "units": "hertz",
                "required": True,
                "examples": "0.0",
            },
        ),
    ]

    symmetry: Annotated[
        SymmetryEnum,
        Field(
            default="NONE",
            description="Symmetry of FIR coefficients",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "NONE",
            },
        ),
    ]

    @field_validator("coefficients")
    @classmethod
    def validate_coefficients(
        cls, value: np.ndarray | list | tuple | str, info: ValidationInfo
    ) -> np.ndarray:
        """
        Coerce the coefficients into a float ``numpy`` array.

        :param value: A list, tuple or array of numbers, or a comma-separated
            string of numbers.
        :param info: Validation information (not used).
        :return: The coefficients as an array with ``dtype=float``.
        :raises ValueError: If ``value`` is none of the accepted types.
        """
        if isinstance(value, (list, tuple, np.ndarray)):
            return np.array(value, dtype=float)
        elif isinstance(value, str):
            # NOTE(review): brackets are not stripped, so a string such as
            # "[0.25, 0.5, 0.25]" fails the float conversion -- confirm
            # whether bracketed strings are expected to reach this validator.
            return np.array(value.split(","), dtype=float)
        else:
            raise ValueError("Coefficients must be a list, tuple, or string.")

    def make_obspy_mapping(self):
        """
        Build the obspy-attribute -> FIRFilter-attribute name mapping.

        Starts from ``get_base_obspy_mapping()`` and adds the FIR specific
        entries.

        :return: The mapping with the FIR entries added.
        """
        mapping = get_base_obspy_mapping()
        mapping["_symmetry"] = "symmetry"
        mapping["_coefficients"] = "coefficients"
        mapping["decimation_factor"] = "decimation_factor"
        mapping["decimation_input_sample_rate"] = "decimation_input_sample_rate"
        mapping["stage_gain_frequency"] = "gain_frequency"
        return mapping

    @property
    def symmetry_corrected_coefficients(self):
        """
        Full set of FIR taps after expanding the declared symmetry.

        * ``"EVEN"``: the stored taps are the first half of an even-length
          filter; they are followed by all of them in reverse order.
        * ``"ODD"``: the stored taps are the first half of an odd-length
          filter *including* the centre tap (stored last); they are followed
          by the reversed taps without the centre tap, so it appears once.
        * anything else: the stored taps are returned unchanged.

        :return: Array of the expanded coefficients.
        """
        if self.symmetry == "EVEN":
            return np.hstack((self.coefficients, np.flipud(self.coefficients)))
        elif self.symmetry == "ODD":
            # Reverse first, then drop the leading element of the reversed
            # array (the centre tap): [1, 2, 3] -> [1, 2, 3, 2, 1].
            return np.hstack((self.coefficients, np.flipud(self.coefficients)[1:]))
        else:
            return self.coefficients

    @property
    def coefficient_gain(self):
        """
        The gain at the reference frequency due only to the coefficients.

        Sometimes this is different from the gain in the stationxml and a
        corrective scalar must be applied.

        :return: Sum of the expanded coefficients (the zero-frequency
            response) when ``gain_frequency`` is 0.0; otherwise the magnitude
            of the response at ``gain_frequency``, as the one-element array
            produced by ``freqz`` for a single frequency.
        """
        if self.gain_frequency == 0.0:
            coefficient_gain = self.symmetry_corrected_coefficients.sum()
        else:
            # estimate the gain from the coefficients at gain_frequency
            coefficient_gain = np.abs(
                self.unscaled_complex_response(self.gain_frequency)
            )
        return coefficient_gain

    @property
    def n_coefficients(self):
        """
        Number of stored coefficients (before any symmetry expansion).
        """
        return len(self.coefficients)

    @property
    def corrective_scalar(self):
        """
        Scale factor relating the coefficient gain to the declared gain.

        ``complex_response`` divides the raw response by this value.

        :return: ``coefficient_gain / total_gain`` when ``coefficient_gain``
            differs from ``gain``; otherwise 1.0.
        """
        # NOTE(review): ``gain`` and ``total_gain`` are not defined in this
        # class (presumably inherited from FilterBase); the comparison uses
        # ``gain`` while the ratio uses ``total_gain`` -- confirm intended.
        if self.coefficient_gain != self.gain:
            return self.coefficient_gain / self.total_gain
        else:
            return 1.0

    def plot_fir_response(self):
        """
        Plot amplitude (dB) and unwrapped phase of the filter response.

        The response of the symmetry-expanded coefficients is evaluated with
        ``scipy.signal.freqz`` on its default frequency grid and drawn on twin
        axes. The figure is shown with ``plt.show()``.

        :return: The matplotlib figure.
        """
        w, h = signal.freqz(self.symmetry_corrected_coefficients)
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        # Title is set on the axes that holds the data so that no implicit
        # extra axes is created.
        ax1.set_title("Digital filter frequency response")
        ax1.plot(w, 20 * np.log10(abs(h)), "b")
        ax1.set_ylabel("Amplitude [dB]", color="b")
        ax1.set_xlabel("Frequency [rad/sample]")

        ax2 = ax1.twinx()
        angles = np.unwrap(np.angle(h))
        ax2.plot(w, angles, "g")
        ax2.set_ylabel("Angle (radians)", color="g")
        ax2.grid()
        ax2.axis("tight")
        plt.show()

        return fig

    # NOTE(review): ``requires`` is imported from mt_metadata.base.helpers;
    # presumably it guards the call when obspy could not be imported
    # (FIRResponseStage is None) -- confirm against the helper.
    @requires(obspy=FIRResponseStage)
    def to_obspy(
        self,
        stage_number=1,
        normalization_frequency=1,
        sample_rate=1,
    ):
        """
        Create an obspy ``FIRResponseStage`` from this filter.

        :param stage_number: Passed as the first positional argument
            (stage sequence number).
        :param normalization_frequency: Passed as the third positional
            argument (stage gain frequency).
        :param sample_rate: Not used by this method.
        :return: The constructed ``FIRResponseStage``.
        """
        # Positional order of FIRResponseStage:
        #   stage_sequence_number, stage_gain, stage_gain_frequency,
        #   input_units, output_units
        # NOTE(review): make_obspy_mapping maps ``stage_gain_frequency`` to
        # ``gain_frequency``, but here ``normalization_frequency`` is passed
        # in that position -- confirm which one is intended.
        rs = FIRResponseStage(
            stage_number,
            self.gain,
            normalization_frequency,
            self.units_in_object.symbol,
            self.units_out_object.symbol,
            coefficients=self.coefficients.tolist(),
            symmetry=self.symmetry,
            name=self.name,
            description=self.get_filter_description(),
            input_units_description=self.units_in_object.name,
            output_units_description=self.units_out_object.name,
            decimation_input_sample_rate=self.decimation_input_sample_rate,
            decimation_factor=self.decimation_factor,
        )

        return rs

    def unscaled_complex_response(self, frequencies):
        """
        Frequency response of the expanded coefficients without gain scaling.

        Need this to avoid RecursionError.
        The problem is that some FIRs need a scale factor to make their gains
        be the same as those reported in the stationXML.  The pure
        coefficients themselves sometimes result in pass-band gains that
        differ from the gain in the XML.  For example filter fs2d5 has a
        passband gain of 0.5 based on the coefficients alone, but the cited
        gain in the xml is almost 1.

        :param frequencies: Frequency or array-like of frequencies, in the
            same units as ``decimation_input_sample_rate``.
        :return: Complex response at the requested frequencies, as returned
            by ``scipy.signal.freqz``.
        """
        # Coerce to a float array so plain lists work; a float dtype also
        # means freqz never reads a single integer as a number of points.
        angular_frequencies = 2 * np.pi * np.asarray(frequencies, dtype=float)
        w, h = signal.freqz(
            self.symmetry_corrected_coefficients,
            worN=angular_frequencies,
            fs=2 * np.pi * self.decimation_input_sample_rate,
        )
        return h

    def complex_response(self, frequencies, **kwargs):
        """
        Frequency response of the filter, scaled by the corrective scalar.

        Parameters
        ----------
        frequencies: numpy array of frequencies, expected in Hz
        kwargs: accepted but not used

        Returns
        -------
        h : numpy array of (possibly complex-valued) frequency response at
            the input frequencies, i.e. the unscaled response divided by
            ``corrective_scalar``

        """
        h = self.unscaled_complex_response(frequencies)
        return h / self.corrective_scalar

286 h /= self.corrective_scalar 

287 

288 return h