Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\filters\filter_base.py: 79%

201 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from typing import Annotated 

5 

6import numpy as np 

7import pandas as pd 

8from loguru import logger 

9from pydantic import computed_field, Field, field_validator, PrivateAttr, ValidationInfo 

10 

11from mt_metadata.base import MetadataBase 

12from mt_metadata.base.helpers import filter_descriptions, requires 

13from mt_metadata.common import Comment 

14from mt_metadata.common.mttime import MTime 

15from mt_metadata.common.units import get_unit_object, Unit 

16from mt_metadata.timeseries.filters.plotting_helpers import plot_response 

17 

18 

19try: 

20 from obspy.core.inventory.response import ResponseListResponseStage, ResponseStage 

21 

22 obspy_import = True 

23except ImportError: 

24 ResponseListResponseStage = None 

25 ResponseStage = None 

26 obspy_import = False 

27 

28 

29# ===================================================== 

30 

31 

def get_base_obspy_mapping():
    """
    Build the obspy-to-mt_metadata attribute mapping shared by every filter.

    Individual filter types extend this mapping with their own entries; only
    the attributes common to all of them are listed here, hence "base".
    Note: supporting inverse forms of these filters could be done by adding
    an argument that specifies the filter direction.

    :return: new dictionary on every call, keyed by the obspy attribute
     name with the mt_metadata attribute name as value, i.e.
     mapping['obspy_label'] = 'mt_metadata_label'
    :rtype: dict
    """
    return {
        "description": "comments",
        "name": "name",
        "stage_gain": "gain",
        "input_units": "units_in",
        "output_units": "units_out",
        "stage_sequence_number": "sequence_number",
    }

49 

50 

class FilterBase(MetadataBase):
    # Metadata model holding the attributes common to the filter classes:
    # name, comments, type, units in/out, calibration date, scalar gain and
    # position in the processing chain. It also provides the obspy stage
    # mapping / construction helpers and a pass band estimator.
    #
    # Derived classes are expected to define their own `_filter_type` and to
    # override `complex_response`; the implementation here only logs that no
    # response is defined and returns None.

    # Lazily filled by the `obspy_mapping` property on first access.
    _obspy_mapping: dict = PrivateAttr({})
    # Value that `validate_type` forces the `type` field to.
    _filter_type: str = PrivateAttr("base")

    name: Annotated[
        str,
        Field(
            default="",
            description="Name of filter applied or to be applied. If more than one filter input as a comma separated list.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": '"lowpass_magnetic"',
            },
        ),
    ]

    comments: Annotated[
        Comment,
        Field(
            default_factory=lambda: Comment(),
            description="Any comments about the filter.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": "ambient air temperature",
            },
        ),
    ]

    type: Annotated[
        str,
        Field(
            default="base",
            description="Type of filter, must be one of the available filters.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "fap_table",
            },
        ),
    ]

    units_in: Annotated[
        str,
        Field(
            default="",
            description="Name of the input units to the filter. Should be all lowercase and separated with an underscore, use 'per' if units are divided and '-' if units are multiplied.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "count",
            },
        ),
    ]

    units_out: Annotated[
        str,
        Field(
            default="",
            description="Name of the output units. Should be all lowercase and separated with an underscore, use 'per' if units are divided and '-' if units are multiplied.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "millivolt",
            },
        ),
    ]

    calibration_date: Annotated[
        MTime | str | float | int | np.datetime64 | pd.Timestamp | None,
        Field(
            default_factory=lambda: MTime(time_stamp=None),
            description="Most recent date of filter calibration in ISO format of YYY-MM-DD.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": "2020-01-01",
            },
        ),
    ]

    gain: Annotated[
        float,
        Field(
            default=1.0,
            description="scalar gain of the filter across all frequencies, producted with any frequency depenendent terms",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "1.0",
            },
        ),
    ]

    sequence_number: Annotated[
        int,
        Field(
            default=0,
            description="Sequence number of the filter in the processing chain.",
            alias=None,
            ge=0,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": 1,
            },
        ),
    ]

    @field_validator("calibration_date", mode="before")
    @classmethod
    def validate_calibration_date(
        cls,
        field_value: MTime | float | int | np.datetime64 | pd.Timestamp | str | None,
    ) -> MTime:
        """
        Wrap whatever was supplied for the calibration date in an MTime.

        :param field_value: raw value given for `calibration_date`
        :return: MTime built with `time_stamp=field_value`
        :rtype: MTime
        """
        return MTime(time_stamp=field_value)

    @field_validator("comments", mode="before")
    @classmethod
    def validate_comments(cls, value, info: ValidationInfo) -> Comment:
        """
        Allow `comments` to be given as a plain string.

        :param value: raw value given for `comments`
        :param info: validation context (unused)
        :return: a Comment wrapping `value` when it is a string, otherwise
         `value` unchanged for the regular field validation to handle
        """
        if isinstance(value, str):
            return Comment(value=value)
        return value

    @field_validator("type", mode="before")
    @classmethod
    def validate_type(cls, value, info: ValidationInfo) -> str:
        """
        Force the filter type to the one declared by the class.

        A warning is logged when the supplied value differs from the class'
        `_filter_type`; the declared type is returned either way.

        :param value: raw value given for `type`
        :param info: validation context (unused)
        :return: the filter type declared by `cls`
        :rtype: str
        """
        # Get the expected filter type based on the actual class.
        # Make sure derived classes define their own _filter_type.
        declared = getattr(cls, "_filter_type", "base")
        # The original code read `.default` from the class-level attribute,
        # which fails on the plain-string fallback; unwrap only when the
        # attribute actually carries a `.default`.
        expected_type = getattr(declared, "default", declared)

        if value != expected_type:
            logger.warning(
                f"Filter type is set to {value}, but should be "
                f"{expected_type} for {cls.__name__}."
            )
        return expected_type

    @field_validator("units_in", "units_out", mode="before")
    @classmethod
    def validate_units(cls, value: str, info: ValidationInfo) -> str:
        """
        Validate a units string and return the name of the matching unit.

        Parameters
        ----------
        value : units string
            unit string separated by either '/' for division or ' ' for
            multiplication. Or 'per' and ' ', respectively
        info : ValidationInfo
            validation context (unused)

        Returns
        -------
        str
            the `name` attribute of the unit object returned by
            `get_unit_object`.

        Raises
        ------
        KeyError
            if `get_unit_object` raises ValueError or KeyError for `value`.
        """
        try:
            return get_unit_object(value, allow_none=False).name
        except (ValueError, KeyError) as error:
            # Callers have always received a KeyError here, whichever of the
            # two was raised underneath; keep that and chain the cause.
            raise KeyError(error) from error

    @property
    def units_in_object(self) -> Unit:
        """Unit object looked up from `units_in`."""
        return get_unit_object(self.units_in, allow_none=False)

    @property
    def units_out_object(self) -> Unit:
        """Unit object looked up from `units_out`."""
        return get_unit_object(self.units_out, allow_none=False)

    def make_obspy_mapping(self) -> dict:
        """
        :return: the base obspy mapping, mapping['obspy_label'] =
         'mt_metadata_label'
        :rtype: dict
        """
        return get_base_obspy_mapping()

    @property
    def obspy_mapping(self) -> dict:
        """
        Mapping to an obspy filter, built on first access.

        :return: mapping to an obspy filter
        :rtype: dict
        """
        if not self._obspy_mapping:
            self._obspy_mapping = self.make_obspy_mapping()
        return self._obspy_mapping

    @obspy_mapping.setter
    def obspy_mapping(self, obspy_dict: dict) -> None:
        """
        Set the obspy mapping: a dictionary relating attribute labels from
        obspy stage objects to mt_metadata filter objects.

        :raises TypeError: if `obspy_dict` is not a dictionary
        """
        if not isinstance(obspy_dict, dict):
            msg = f"Input must be a dictionary not {type(obspy_dict)}"
            logger.error(msg)
            raise TypeError(msg)

        self._obspy_mapping = obspy_dict

    @computed_field
    @property
    def total_gain(self) -> float:
        """
        :return: Total gain of the filter, which for the base class is `gain`
        :rtype: float
        """
        return self.gain

    def get_filter_description(self):
        """
        Description of the filter.

        :return: the predetermined description for `self.type` when no
         comment value is set, otherwise `self.comments`
        :rtype: string or Comment
        """
        # NOTE(review): the two branches return different kinds of object
        # (an entry of `filter_descriptions` vs. the Comment itself, not its
        # `.value`). Left unchanged because callers are not visible here;
        # confirm which form they expect.
        if self.comments.value is None:
            return filter_descriptions[self.type]

        return self.comments

    @requires(obspy=obspy_import)
    @classmethod
    def from_obspy_stage(
        cls,
        stage,  # : Union[ResponseStage, ResponseListResponseStage],
        mapping: dict | None = None,
    ) -> "FilterBase":
        """
        Build a filter instance from an obspy response stage.

        Each `obspy_label -> mt_metadata_label` pair of `mapping` is read
        from `stage` and passed to the class constructor. For a
        `ResponseListResponseStage` the frequencies, amplitudes and phases
        of its response list elements are passed as arrays as well.

        :param cls: a filter object
        :type cls: filter object
        :param stage: Obspy stage filter
        :type stage: :class:`obspy.inventory.response.ResponseStage`
        :param mapping: dictionary for mapping from an obspy stage,
         defaults to None, in which case `make_obspy_mapping` is used
        :type mapping: dict, optional
        :raises TypeError: If stage is not a
         :class:`obspy.inventory.response.ResponseStage`
        :return: the appropriate mt_metadata.timeseries.filter object
        :rtype: mt_metadata.timeseries.filter object

        """
        # Reject unusable input before doing any other work.
        if not isinstance(stage, (ResponseListResponseStage, ResponseStage)):
            msg = f"Expected a ResponseStage and got a {type(stage)}"
            logger.error(msg)
            raise TypeError(msg)

        if mapping is None:
            mapping = cls().make_obspy_mapping()
        kwargs = {"name": ""}

        if isinstance(stage, ResponseListResponseStage):
            elements = list(stage.response_list_elements)
            kwargs["frequencies"] = np.array([e.frequency for e in elements])
            kwargs["amplitudes"] = np.array([e.amplitude for e in elements])
            kwargs["phases"] = np.array([e.phase for e in elements])

        for obspy_label, mth5_label in mapping.items():
            # Table values were already collected from the list elements.
            if obspy_label in ("amplitudes", "phases", "frequencies"):
                continue
            if mth5_label == "comments" or obspy_label == "description":
                kwargs[mth5_label] = Comment(value=getattr(stage, obspy_label))
                continue
            try:
                kwargs[mth5_label] = getattr(stage, obspy_label)
            except AttributeError:
                logger.warning(
                    f"Attribute {obspy_label} not found in stage object, skipping."
                )

        # A stage without a name must not override the empty-string default.
        if kwargs.get("name") is None:
            kwargs["name"] = ""
        return cls(**kwargs)

    def complex_response(self, frqs, **kwargs):
        """
        Placeholder for the complex response; not defined on the base class.

        :param frqs: frequencies at which a response is requested (unused)
        :param kwargs: accepted and ignored so that `pass_band` and
         `plot_response`, which forward options such as
         `interpolation_method`, can call this implementation
        :return: always None
        """
        msg = f"complex_response not defined for {self.__class__.__name__} class"
        logger.info(msg)
        return None

    def pass_band(
        self, frequencies: np.ndarray, window_len: int = 5, tol: float = 0.5, **kwargs
    ) -> np.ndarray | None:
        """
        Estimate the pass band of the filter from the flattest part of its
        amplitude response.

        Caveat: This should work for most Fluxgate and feedback coil magnetometers, and basically most filters
        having a "low" number of poles and zeros. This method is not 100% robust to filters with a notch in them.

        The frequency array is decimated so that roughly 1000 points are
        analysed. A window of `window_len` points slides over the decimated
        amplitude and each window is scored with
        ``abs(1 - log10(min) / log10(max))``. Windows scoring at most `tol`
        count as flat; the pass band runs from the start of the first flat
        window to the end of the last one, mapped back onto the full array.

        .. note:: This only works for simple filters with one flat pass band.

        :param frequencies: array of frequencies
        :type frequencies: np.ndarray

        :param window_len: length of sliding window in points
        :type window_len: integer

        :param tol: largest window score that still counts as flat
        :type tol: float

        :param kwargs: forwarded to `complex_response`

        :return: pass band frequencies [f_start, f_end]. The full frequency
         range is returned when the response is flat, when there are too few
         points for a window, when no window is flat, or when the window
         analysis fails. The input is returned as is when it has a single
         element, and None when it is empty or no complex response exists.
        :rtype: np.ndarray or None

        """
        f = np.array(frequencies)
        if f.size == 0:
            logger.warning("Frequency array is empty, returning None")
            return None
        if f.size == 1:
            logger.warning(
                "Frequency array has only one element, returning it as the pass band"
            )
            return f

        cr = self.complex_response(f, **kwargs)
        if cr is None:
            logger.warning(
                "complex response is None, cannot estimate pass band. Returning None"
            )
            return None

        amp = np.abs(cr)

        # precision is apparently an important variable here
        # Flat response: every rounded amplitude equals the rounded mean.
        # (Compare element-wise first, then reduce with all().)
        if (np.round(amp, 6) == np.round(amp.mean(), 6)).all():
            return np.array([f.min(), f.max()])

        # Decimate for speed: keep ~1000 points for analysis. A factor of 1
        # leaves the arrays as they are.
        decimate_factor = max(1, f.size // 1000)
        f_dec = f[::decimate_factor]
        amp_dec = amp[::decimate_factor]

        n_windows = f_dec.size - window_len
        if n_windows <= 0:
            return np.array([f.min(), f.max()])

        # Vectorized window analysis on decimated array
        try:
            from numpy.lib.stride_tricks import sliding_window_view

            # Row i is amp_dec[i:i + window_len]. Only the first n_windows
            # rows are used, matching the window count this method has
            # always analysed.
            # assumes amp is 1-D; other shapes end up in the except below
            amp_windows = sliding_window_view(amp_dec, window_len)[:n_windows]

            window_mins = np.min(amp_windows, axis=1)
            window_maxs = np.max(amp_windows, axis=1)

            with np.errstate(divide="ignore", invalid="ignore"):
                ratios = np.log10(window_mins) / np.log10(window_maxs)
                # 0/0 and similar become inf so those windows never pass
                ratios = np.nan_to_num(ratios, nan=np.inf)
                test_values = np.abs(1 - ratios)

            passing_indices = np.where(test_values <= tol)[0]
            if passing_indices.size == 0:
                # If no windows pass, return full frequency range
                return np.array([f.min(), f.max()])

            # First sample of the first passing window, one past the last
            # sample of the last passing window.
            start_idx = passing_indices[0]
            end_idx = passing_indices[-1] + window_len

            # Map back to original frequency array
            start_freq_idx = start_idx * decimate_factor
            end_freq_idx = min(end_idx * decimate_factor, f.size - 1)

            return np.array([f[start_freq_idx], f[end_freq_idx]])

        except Exception as e:
            # Deliberate best effort: any failure falls back to full range.
            logger.debug(f"Decimated passband method failed: {e}, returning full range")
            return np.array([f.min(), f.max()])

    def generate_frequency_axis(self, sampling_rate, n_observations):
        """
        Frequency axis of an FFT, shifted so zero frequency is centred.

        :param sampling_rate: sampling rate; its reciprocal is the sample
         spacing passed to `np.fft.fftfreq`
        :param n_observations: number of samples
        :return: `np.fft.fftshift` of the `np.fft.fftfreq` frequencies
        :rtype: np.ndarray
        """
        dt = 1.0 / sampling_rate
        return np.fft.fftshift(np.fft.fftfreq(n_observations, d=dt))

    def plot_response(
        self,
        frequencies,
        x_units="period",
        unwrap=True,
        pb_tol=1e-1,
        interpolation_method="slinear",
    ):
        """
        Plot the complex response of the filter.

        :param frequencies: frequencies to evaluate the response at; when
         None an axis is generated with `generate_frequency_axis(10.0, 1000)`
         and `x_units` is forced to "frequency"
        :param x_units: units of the x axis, passed to the plotting helper
        :param unwrap: passed to the plotting helper
        :param pb_tol: tolerance passed to `pass_band` as `tol`
        :param interpolation_method: passed to `complex_response` and
         `pass_band`
        :return: None
        """
        if frequencies is None:
            frequencies = self.generate_frequency_axis(10.0, 1000)
            x_units = "frequency"

        response_kwargs = {"interpolation_method": interpolation_method}

        complex_response = self.complex_response(frequencies, **response_kwargs)
        if complex_response is None:
            # e.g. the base class, which defines no response.
            logger.warning(
                f"No complex response available for {self.__class__.__name__}, "
                "nothing to plot"
            )
            return None

        kwargs = {
            "title": self.name,
            "unwrap": unwrap,
            "x_units": x_units,
            "label": self.name,
        }

        if hasattr(self, "poles"):
            kwargs["poles"] = self.poles
            kwargs["zeros"] = self.zeros

        if hasattr(self, "pass_band"):
            kwargs["pass_band"] = self.pass_band(
                frequencies,
                tol=pb_tol,
                **response_kwargs,
            )

        # Resolves to the module-level plotting helper, not this method.
        plot_response(frequencies, complex_response, **kwargs)

    @property
    def decimation_active(self) -> bool:
        """
        :return: if decimation is prescribed, i.e. the object has a
         `decimation_factor` that differs from 1.0
        :rtype: bool
        """
        if not hasattr(self, "decimation_factor"):
            return False
        return bool(self.decimation_factor != 1.0)

511 if self.decimation_factor != 1.0: 

512 return True 

513 return False