Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\filters\filter_base.py: 79%
201 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from typing import Annotated
6import numpy as np
7import pandas as pd
8from loguru import logger
9from pydantic import computed_field, Field, field_validator, PrivateAttr, ValidationInfo
11from mt_metadata.base import MetadataBase
12from mt_metadata.base.helpers import filter_descriptions, requires
13from mt_metadata.common import Comment
14from mt_metadata.common.mttime import MTime
15from mt_metadata.common.units import get_unit_object, Unit
16from mt_metadata.timeseries.filters.plotting_helpers import plot_response
19try:
20 from obspy.core.inventory.response import ResponseListResponseStage, ResponseStage
22 obspy_import = True
23except ImportError:
24 ResponseListResponseStage = None
25 ResponseStage = None
26 obspy_import = False
29# =====================================================
def get_base_obspy_mapping():
    """
    Return the obspy-to-mt_metadata attribute mapping shared by all filters.

    Each filter type has its own obspy mapping, but the entries returned here
    are common to every one of them, hence the "base" mapping. Supporting the
    inverse form of these filters could be done by adding an argument that
    specifies the filter direction.

    :return: a new dictionary keyed by obspy stage attribute name whose
        values are the matching mt_metadata filter attribute names, i.e.
        ``mapping["obspy_label"] = "mt_metadata_label"``
    :rtype: dict
    """
    return {
        "description": "comments",
        "name": "name",
        "stage_gain": "gain",
        "input_units": "units_in",
        "output_units": "units_out",
        "stage_sequence_number": "sequence_number",
    }
class FilterBase(MetadataBase):
    """
    Base class for mt_metadata time series filters.

    Holds the attributes common to every filter (name, comments, type, input
    and output units, calibration date, gain, sequence number). It also
    provides generic helpers for converting from obspy response stages,
    estimating a pass band and plotting the response. Subclasses are expected
    to override ``_filter_type`` and ``complex_response``.
    """

    # Mapping of obspy stage attribute -> mt_metadata attribute; built lazily
    # by the ``obspy_mapping`` property the first time it is read.
    _obspy_mapping: dict = PrivateAttr(default_factory=dict)
    # Canonical value of the ``type`` field for this class; ``validate_type``
    # forces ``type`` to this value. Subclasses override it.
    _filter_type: str = PrivateAttr("base")

    name: Annotated[
        str,
        Field(
            default="",
            description="Name of filter applied or to be applied. If more than one filter input as a comma separated list.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": '"lowpass_magnetic"',
            },
        ),
    ]

    comments: Annotated[
        Comment,
        Field(
            default_factory=lambda: Comment(),
            description="Any comments about the filter.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": "ambient air temperature",
            },
        ),
    ]

    type: Annotated[
        str,
        Field(
            default="base",
            description="Type of filter, must be one of the available filters.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "fap_table",
            },
        ),
    ]

    units_in: Annotated[
        str,
        Field(
            default="",
            description="Name of the input units to the filter. Should be all lowercase and separated with an underscore, use 'per' if units are divided and '-' if units are multiplied.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "count",
            },
        ),
    ]

    units_out: Annotated[
        str,
        Field(
            default="",
            description="Name of the output units. Should be all lowercase and separated with an underscore, use 'per' if units are divided and '-' if units are multiplied.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "millivolt",
            },
        ),
    ]

    calibration_date: Annotated[
        MTime | str | float | int | np.datetime64 | pd.Timestamp | None,
        Field(
            default_factory=lambda: MTime(time_stamp=None),
            description="Most recent date of filter calibration in ISO format of YYYY-MM-DD.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": "2020-01-01",
            },
        ),
    ]

    gain: Annotated[
        float,
        Field(
            default=1.0,
            description="scalar gain of the filter across all frequencies, producted with any frequency dependent terms",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "1.0",
            },
        ),
    ]

    sequence_number: Annotated[
        int,
        Field(
            default=0,
            description="Sequence number of the filter in the processing chain.",
            alias=None,
            ge=0,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": 1,
            },
        ),
    ]

    @field_validator("calibration_date", mode="before")
    @classmethod
    def validate_calibration_date(
        cls,
        field_value: MTime | float | int | np.datetime64 | pd.Timestamp | str | None,
    ):
        """Coerce any supported date representation into an ``MTime``."""
        return MTime(time_stamp=field_value)

    @field_validator("comments", mode="before")
    @classmethod
    def validate_comments(cls, value, info: ValidationInfo) -> Comment:
        """Wrap a plain string in a ``Comment``; pass anything else through."""
        if isinstance(value, str):
            return Comment(value=value)
        return value

    @field_validator("type", mode="before")
    @classmethod
    def validate_type(cls, value, info: ValidationInfo) -> str:
        """
        Force ``type`` to the canonical filter type of the class being built.

        The expected type is the default of the class's ``_filter_type``
        private attribute, which is ``"base"`` for this class. A mismatching
        input is not rejected: a warning is logged and the expected type is
        returned instead.
        """
        # On a pydantic model class, a private attribute is exposed as a
        # ModelPrivateAttr whose ``.default`` holds the value. Fall back to the
        # raw attribute, and finally to "base", so a missing or plain-valued
        # attribute cannot raise AttributeError here.
        private_attr = getattr(cls, "_filter_type", None)
        expected_type = getattr(private_attr, "default", private_attr) or "base"

        if value != expected_type:
            logger.warning(
                f"Filter type is set to {value}, but should be "
                f"{expected_type} for {cls.__name__}."
            )
        return expected_type

    @field_validator("units_in", "units_out", mode="before")
    @classmethod
    def validate_units(cls, value: str, info: ValidationInfo) -> str:
        """
        Validate a units string and return the unit's long name.

        Parameters
        ----------
        value : str
            Unit string, separated by either '/' for division or ' ' for
            multiplication, or by 'per' and ' ' respectively.
        info : ValidationInfo
            Pydantic validation context (unused).

        Returns
        -------
        str
            The ``name`` attribute of the unit object resolved by
            ``get_unit_object`` (for example 'kilometers').

        Raises
        ------
        KeyError
            If ``get_unit_object`` raises ValueError or KeyError for ``value``.
        """
        try:
            unit_object = get_unit_object(value, allow_none=False)
        except (ValueError, KeyError) as error:
            # NOTE(review): pydantic only wraps ValueError/AssertionError into a
            # ValidationError, so this KeyError presumably propagates unwrapped
            # to the caller — confirm that is intended.
            raise KeyError(error) from error
        return unit_object.name

    @property
    def units_in_object(self) -> Unit:
        """Unit object resolved from ``units_in``."""
        return get_unit_object(self.units_in, allow_none=False)

    @property
    def units_out_object(self) -> Unit:
        """Unit object resolved from ``units_out``."""
        return get_unit_object(self.units_out, allow_none=False)

    def make_obspy_mapping(self):
        """
        Build the obspy-to-mt_metadata attribute mapping for this filter.

        The base class only uses the common entries from
        ``get_base_obspy_mapping``; subclasses may override this method to
        add filter-specific entries.

        :return: mapping of obspy stage attribute -> mt_metadata attribute
        :rtype: dict
        """
        return get_base_obspy_mapping()

    @property
    def obspy_mapping(self):
        """
        Mapping from obspy stage attributes to attributes of this filter.

        Built with ``make_obspy_mapping`` on first access, or whenever the
        stored mapping is empty.

        :return: mapping to an obspy filter
        :rtype: dict
        """
        if not self._obspy_mapping:
            self._obspy_mapping = self.make_obspy_mapping()
        return self._obspy_mapping

    @obspy_mapping.setter
    def obspy_mapping(self, obspy_dict):
        """
        Set the obspy mapping.

        The mapping is a dictionary relating attribute labels on obspy stage
        objects to attribute labels on mt_metadata filter objects.

        :raises TypeError: if ``obspy_dict`` is not a dict
        """
        if not isinstance(obspy_dict, dict):
            msg = f"Input must be a dictionary not {type(obspy_dict)}"
            logger.error(msg)
            raise TypeError(msg)

        self._obspy_mapping = obspy_dict

    @computed_field
    @property
    def total_gain(self) -> float:
        """
        :return: total gain of the filter; for the base class this is ``gain``
        :rtype: float
        """
        return self.gain

    def get_filter_description(self):
        """
        Return a description of the filter.

        :return: the ``comments`` object when its ``value`` is set; otherwise
            the predetermined description for ``self.type`` from
            ``filter_descriptions``
        :rtype: Comment or string

        NOTE(review): the non-empty branch returns the ``Comment`` object, not
        its ``value`` string — confirm callers expect that.
        """
        if self.comments.value is None:
            return filter_descriptions[self.type]

        return self.comments

    @requires(obspy=obspy_import)
    @classmethod
    def from_obspy_stage(
        cls,
        stage,  # : Union[ResponseStage, ResponseListResponseStage],
        mapping: dict | None = None,
    ) -> "FilterBase":
        """
        Build a filter of this class from an obspy response stage.

        :param stage: obspy stage filter
        :type stage: :class:`obspy.inventory.response.ResponseStage`
        :param mapping: dictionary of obspy stage attribute -> mt_metadata
            attribute; defaults to ``cls().make_obspy_mapping()``
        :type mapping: dict, optional
        :raises TypeError: if ``stage`` is not a
            :class:`obspy.inventory.response.ResponseStage`
        :return: the appropriate mt_metadata.timeseries.filter object
        :rtype: mt_metadata.timeseries.filter object
        """
        if mapping is None:
            mapping = cls().make_obspy_mapping()
        kwargs = {"name": ""}

        if not isinstance(stage, (ResponseListResponseStage, ResponseStage)):
            msg = f"Expected a ResponseStage and got a {type(stage)}"
            logger.error(msg)
            raise TypeError(msg)

        if isinstance(stage, ResponseListResponseStage):
            # Frequency/amplitude/phase tables come from the response list
            # elements rather than from the attribute mapping.
            elements = stage.response_list_elements
            kwargs["frequencies"] = np.array([e.frequency for e in elements])
            kwargs["amplitudes"] = np.array([e.amplitude for e in elements])
            kwargs["phases"] = np.array([e.phase for e in elements])

        for obspy_label, mth5_label in mapping.items():
            if obspy_label in ("amplitudes", "phases", "frequencies"):
                continue
            if mth5_label == "comments" or obspy_label == "description":
                kwargs[mth5_label] = Comment(value=getattr(stage, obspy_label))
            else:
                try:
                    kwargs[mth5_label] = getattr(stage, obspy_label)
                except AttributeError:
                    logger.warning(
                        f"Attribute {obspy_label} not found in stage object, skipping."
                    )
        # The stage may carry name=None; the model expects a string.
        if kwargs.get("name") is None:
            kwargs["name"] = ""
        return cls(**kwargs)

    def complex_response(self, frqs, **kwargs):
        """
        Placeholder for the complex response; subclasses override this.

        Extra keyword arguments (e.g. ``interpolation_method`` passed by
        ``plot_response`` and ``pass_band``) are accepted and ignored so the
        base class can be called with the same signature as subclasses.

        :return: None, after logging that the response is not defined
        """
        msg = f"complex_response not defined for {self.__class__.__name__} class"
        logger.info(msg)
        return None

    def pass_band(
        self, frequencies: np.ndarray, window_len: int = 5, tol: float = 0.5, **kwargs
    ) -> np.ndarray | None:
        """
        Fast passband estimation using decimation.

        Caveat: this should work for most fluxgate and feedback coil
        magnetometers, and generally for filters with a "low" number of
        poles and zeros. It is not robust to filters with a notch in them.

        Estimate the pass band from the flattest part of the amplitude
        response. Large frequency arrays are decimated to roughly 1000
        points, and a sliding window of ``window_len`` samples is tested at
        each position. A window passes when
        ``abs(1 - log10(min) / log10(max)) <= tol``. The pass band spans from
        the first passing window to the last one, mapped back onto the
        original frequency array.

        .. note:: This only works for simple filters with one flat pass band.

        :param frequencies: array of frequencies
        :type frequencies: np.ndarray
        :param window_len: length of sliding window in points
        :type window_len: integer
        :param tol: tolerance around 1 for the log-ratio of window min/max
        :type tol: float
        :param kwargs: forwarded to ``complex_response``
        :return: pass band frequencies [f_start, f_end]; None when the
            frequency array is empty or the complex response is None; the
            input array itself when it has a single element
        :rtype: np.ndarray or None
        """
        f = np.array(frequencies)
        if f.size == 0:
            logger.warning("Frequency array is empty, returning None")
            return None
        if f.size == 1:
            logger.warning("Frequency array is too small, returning it unchanged")
            return f

        cr = self.complex_response(f, **kwargs)
        if cr is None:
            logger.warning(
                "complex response is None, cannot estimate pass band. Returning None"
            )
            return None

        amp = np.abs(cr)
        full_range = np.array([f.min(), f.max()])

        # A response that is flat to 6 decimals passes everywhere. This must be
        # an elementwise comparison; ``.all()`` on the rounded array would
        # yield a single bool.
        if np.all(np.round(amp, 6) == np.round(amp.mean(), 6)):
            return full_range

        # Decimate large arrays to keep roughly 1000 points for analysis.
        decimate_factor = max(1, f.size // 1000)
        amp_dec = amp[::decimate_factor]

        # Windows start at 0 .. n_windows - 1. The end of the last passing
        # window is mapped to index start + window_len below.
        n_windows = amp_dec.size - window_len
        if n_windows <= 0:
            return full_range

        try:
            from numpy.lib.stride_tricks import sliding_window_view

            amp_windows = sliding_window_view(amp_dec, window_len)[:n_windows]
            window_mins = amp_windows.min(axis=1)
            window_maxs = amp_windows.max(axis=1)

            with np.errstate(divide="ignore", invalid="ignore"):
                ratios = np.log10(window_mins) / np.log10(window_maxs)
            # NaN (e.g. 0/0 when amplitudes equal 1) marks a window as not flat.
            ratios = np.nan_to_num(ratios, nan=np.inf)
            test_values = np.abs(1 - ratios)

            passing_indices = np.flatnonzero(test_values <= tol)
            if passing_indices.size == 0:
                # If no windows pass, return the full frequency range.
                return full_range

            # Map first/last passing windows back to the original array.
            start_freq_idx = passing_indices[0] * decimate_factor
            end_freq_idx = min(
                (passing_indices[-1] + window_len) * decimate_factor, f.size - 1
            )
            return np.array([f[start_freq_idx], f[end_freq_idx]])

        except Exception as e:
            # Deliberate best-effort: estimation failure must not break callers.
            logger.debug(f"Decimated passband method failed: {e}, returning full range")
            return full_range

    def generate_frequency_axis(self, sampling_rate, n_observations):
        """
        Return the FFT frequency axis for the given sampling parameters.

        :param sampling_rate: samples per second
        :param n_observations: number of samples
        :return: ``np.fft.fftfreq`` frequencies, reordered by
            ``np.fft.fftshift`` so they run from negative to positive
        :rtype: np.ndarray
        """
        dt = 1.0 / sampling_rate
        frequency_axis = np.fft.fftfreq(n_observations, d=dt)
        return np.fft.fftshift(frequency_axis)

    def plot_response(
        self,
        frequencies,
        x_units="period",
        unwrap=True,
        pb_tol=1e-1,
        interpolation_method="slinear",
    ):
        """
        Plot the filter response with the module-level ``plot_response`` helper.

        :param frequencies: frequencies at which to evaluate the response; if
            None, ``generate_frequency_axis(10.0, 1000)`` is used and
            ``x_units`` is forced to "frequency"
        :param x_units: x-axis units passed to the plotting helper
        :param unwrap: passed to the plotting helper
        :param pb_tol: ``tol`` used for the pass band estimate
        :param interpolation_method: forwarded to ``complex_response``
        :return: None. Nothing is plotted (a warning is logged) when the
            complex response is None.
        """
        if frequencies is None:
            frequencies = self.generate_frequency_axis(10.0, 1000)
            x_units = "frequency"

        response_kwargs = {"interpolation_method": interpolation_method}
        complex_response = self.complex_response(frequencies, **response_kwargs)
        if complex_response is None:
            logger.warning(
                f"Cannot plot response for {self.__class__.__name__}: "
                "complex response is None"
            )
            return None

        kwargs = {
            "title": self.name,
            "unwrap": unwrap,
            "x_units": x_units,
            "label": self.name,
        }
        if hasattr(self, "poles"):
            kwargs["poles"] = self.poles
            kwargs["zeros"] = self.zeros

        kwargs["pass_band"] = self.pass_band(
            frequencies,
            tol=pb_tol,
            **response_kwargs,
        )

        # Module-level plotting helper (not this method).
        plot_response(frequencies, complex_response, **kwargs)

    @property
    def decimation_active(self):
        """
        :return: True if the filter has a ``decimation_factor`` attribute
            whose value is not 1.0, otherwise False
        :rtype: bool
        """
        if hasattr(self, "decimation_factor"):
            if self.decimation_factor != 1.0:
                return True
        return False