Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\filters\channel_response.py: 89%
228 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from copy import deepcopy
5from typing import Annotated
7import numpy as np
8from loguru import logger
9from pydantic import (
10 computed_field,
11 Field,
12 field_validator,
13 model_validator,
14 PrivateAttr,
15 ValidationInfo,
16)
18from mt_metadata.base.helpers import object_to_array, requires
19from mt_metadata.common.units import get_unit_object
20from mt_metadata.timeseries.filters import (
21 CoefficientFilter,
22 FilterBase,
23 FIRFilter,
24 FrequencyResponseTableFilter,
25 PoleZeroFilter,
26 TimeDelayFilter,
27)
28from mt_metadata.timeseries.filters.plotting_helpers import plot_response
31try:
32 from obspy.core import inventory
33except ImportError:
34 inventory = None
37# =====================================================
class ChannelResponse(FilterBase):
    # Cascade of filters describing the full response of one channel.
    # Filters in ``filters_list`` are applied in list order; ``units_in`` is
    # taken from the first filter and ``units_out`` from the last (see
    # ``update_units_and_normalization_frequency_from_filters_list``).

    # Filter classes accepted in ``filters_list``; read by
    # ``_validate_filters_list`` via ``cls._supported_filters.default``.
    _supported_filters: list = PrivateAttr(
        [
            PoleZeroFilter,
            CoefficientFilter,
            TimeDelayFilter,
            FrequencyResponseTableFilter,
            FIRFilter,
        ]
    )

    # NOTE: declared before ``filters_list`` and ``frequencies``; pydantic
    # validates fields in declaration order, which limits what
    # ``validate_normalization_frequency`` can see in ``info.data``.
    normalization_frequency: Annotated[
        float,
        Field(
            default=0.0,
            description="Pass band frequency",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "100",
            },
        ),
    ]

    filters_list: Annotated[
        list[
            PoleZeroFilter
            | CoefficientFilter
            | TimeDelayFilter
            | FrequencyResponseTableFilter
            | FIRFilter
        ],
        Field(
            default_factory=list,
            description="List of filters applied to the channel.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "[PoleZeroFilter, CoefficientFilter]",
            },
        ),
    ]

    frequencies: Annotated[
        np.ndarray | list[float],
        Field(
            default_factory=lambda: np.empty(0, dtype=float),
            description="The frequencies at which a calibration of the filter were performed.",
            alias=None,
            json_schema_extra={
                "units": "hertz",
                "required": True,
                "items": {"type": "number"},
                "examples": '"[-0.0001., 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.001, ... 1, 2, 5, 10]"',
            },
        ),
    ]

    def __str__(self):
        """Return a text listing of every filter, separated by dashed rules."""
        parts = ["Filters Included:\n", "=" * 25, "\n"]
        for mt_filter in self.filters_list:
            parts.append(str(mt_filter))
            parts.append(f"\n{'-' * 20}\n")
        return "".join(parts)

    def __repr__(self):
        """Return the same text as ``__str__``."""
        return self.__str__()

    @field_validator("normalization_frequency", mode="after")
    @classmethod
    def validate_normalization_frequency(
        cls, value: float, info: ValidationInfo
    ) -> float:
        """
        Derive the normalization frequency from the pass band when it is unset.

        If ``value`` is 0.0 (or None), a temporary instance is built from the
        already-validated data and, when it yields a pass band, the geometric
        mean of that band (rounded to 3 decimals) is returned. Otherwise
        ``value`` is returned unchanged; no positivity check is made.

        ``info.data`` only holds fields validated before this one, so
        ``filters_list``/``frequencies`` (declared later) are usually absent
        here and the pass band is unavailable; the model-level validator
        repeats this derivation once all fields are set.

        :param value: validated normalization frequency
        :param info: pydantic validation context
        :return: the (possibly derived) normalization frequency
        """
        if value in [0.0, None]:
            # Temporary instance, only used to evaluate the pass_band property.
            instance = cls.model_construct(**info.data)
            pass_band = getattr(instance, "pass_band", None)
            if pass_band is not None:
                # Geometric mean of the pass band; log10(0) would otherwise
                # emit a divide-by-zero RuntimeWarning.
                with np.errstate(divide="ignore"):
                    norm_freq = np.round(10 ** np.mean(np.log10(pass_band)), 3)
                logger.info(
                    f"Setting normalization frequency to {norm_freq} Hz based on pass band"
                )
                return norm_freq

        return value

    @field_validator("frequencies", mode="before")
    @classmethod
    def validate_frequencies(cls, value: np.ndarray | list[float]) -> np.ndarray:
        """
        Coerce the frequencies to a float array via ``object_to_array``.

        :param value: raw frequencies (array or list)
        :return: frequencies as an array of dtype float
        """
        return object_to_array(value, dtype=float)

    @field_validator("filters_list", mode="before")
    @classmethod
    def validate_filters_list(cls, value: list) -> list:
        """
        Check that ``filters_list`` is a list of supported, unit-consistent filters.

        :param value: raw filters list
        :raises ValueError: if ``value`` is not a list, or adjacent filters
            have mismatched units
        :raises TypeError: if an item is not a supported filter type
        :return: the validated list
        """
        if not isinstance(value, list):
            raise ValueError("filters_list must be a list of filter objects.")

        value = cls._validate_filters_list(value)
        value = cls._check_consistency_of_units(value)
        return value

    @model_validator(mode="after")
    def update_units_and_normalization_frequency_from_filters_list(
        self,
    ) -> "ChannelResponse":
        """
        Sync derived attributes after all fields are validated.

        * ``units_in`` / ``units_out`` are copied from the first / last filter
          when ``filters_list`` is not empty.
        * A ``normalization_frequency`` of 0.0 is replaced by the geometric
          mean of ``pass_band`` (rounded to 3 decimals) when a pass band can
          be estimated.

        Writes go through ``object.__setattr__``, bypassing the model's own
        ``__setattr__`` -- presumably so that assignment validation (if
        enabled on the base model) does not re-run validators from here.
        """
        if self.filters_list:
            object.__setattr__(self, "units_in", self.filters_list[0].units_in)
            object.__setattr__(self, "units_out", self.filters_list[-1].units_out)
        if self.normalization_frequency == 0.0:
            pass_band = self.pass_band
            if pass_band is not None:
                # Geometric mean of the pass band edges.
                with np.errstate(divide="ignore"):
                    norm_freq = np.round(10 ** np.mean(np.log10(pass_band)), 3)
                logger.debug(
                    f"Setting normalization frequency to {norm_freq} Hz based on pass band"
                )
                object.__setattr__(self, "normalization_frequency", norm_freq)
        return self

    @classmethod
    def _validate_filters_list(cls, filters_list):
        """
        Make sure every item of the filters list is a supported filter.

        :param filters_list: candidate filters
        :type filters_list: list | None
        :raises TypeError: if ``filters_list`` is not a list, or if any item is
            not an instance of one of ``_supported_filters`` (all offending
            items are reported in one message)
        :return: the filters, in the same order ([] for None or an empty list)
        :rtype: list
        """
        if filters_list is None or (
            isinstance(filters_list, list) and not filters_list
        ):
            return []

        if not isinstance(filters_list, list):
            msg = f"Input filters list must be a list not {type(filters_list)}"
            logger.error(msg)
            raise TypeError(msg)

        # Built once, not per item.
        supported_filter_types = tuple(cls._supported_filters.default)

        fails = []
        return_list = []
        for item in filters_list:
            if isinstance(item, supported_filter_types):
                return_list.append(item)
            else:
                fails.append(f"Item is not a supported filter type, {type(item)}")

        if fails:
            raise TypeError(", ".join(fails))

        return return_list

    @classmethod
    def _check_consistency_of_units(cls, filters_list):
        """
        Confirm that each filter's ``units_in`` equals the previous filter's ``units_out``.

        :param filters_list: filters in application order
        :type filters_list: list
        :raises ValueError: at the first unit mismatch
        :return: ``filters_list`` unchanged
        :rtype: list
        """
        if len(filters_list) > 1:
            previous_units = filters_list[0].units_out
            for mt_filter in filters_list[1:]:
                if mt_filter.units_in != previous_units:
                    msg = (
                        "Unit consistency is incorrect. "
                        f"The input units for {mt_filter.name} should be "
                        f"{previous_units} not {mt_filter.units_in}"
                    )
                    logger.error(msg)
                    raise ValueError(msg)
                previous_units = mt_filter.units_out

        return filters_list

    @computed_field
    @property
    def names(self) -> list[str]:
        """Names of the filters, in list order ([] when there are none)."""
        return [mt_filter.name for mt_filter in self.filters_list or []]

    @computed_field
    @property
    def pass_band(self) -> list[float]:
        """
        Estimate the pass band common to all filters.

        Each filter exposing ``pass_band`` is evaluated on ``self.frequencies``.
        The combined band runs from the largest per-filter lower edge to the
        smallest per-filter upper edge.

        :return: ``np.array([low, high])``, or None when no frequencies are
            set or no filter reports a non-empty pass band
        :rtype: np.ndarray | None
        """
        if self.frequencies is None or len(self.frequencies) == 0:
            logger.debug("No frequencies provided, cannot calculate pass band")
            return None

        edges = []
        for mt_filter in self.filters_list:
            if not hasattr(mt_filter, "pass_band"):
                continue
            filter_band = mt_filter.pass_band(self.frequencies)
            # Skip filters without a usable band; min()/max() of an empty
            # array would raise ValueError.
            if filter_band is None or np.size(filter_band) == 0:
                continue
            edges.append((filter_band.min(), filter_band.max()))

        if not edges:
            return None
        edges = np.array(edges)
        return np.array([edges[:, 0].max(), edges[:, 1].min()])

    @computed_field
    @property
    def non_delay_filters(self) -> list:
        """
        :return: all filters whose ``type`` is not "time delay", in list order
        """
        return [x for x in self.filters_list if x.type != "time delay"]

    @computed_field
    @property
    def delay_filters(self) -> list[TimeDelayFilter]:
        """
        :return: all filters whose ``type`` is "time delay", in list order
        """
        return [x for x in self.filters_list if x.type == "time delay"]

    @computed_field
    @property
    def total_delay(self) -> float:
        """
        :return: sum of ``delay`` over all time-delay filters (0.0 if none)
        """
        return sum((x.delay for x in self.delay_filters), 0.0)

    def get_indices_of_filters_to_remove(
        self, include_decimation=False, include_delay=False
    ):
        """
        Return the list indices of the filters selected for removal.

        :param include_decimation: keep filters with ``decimation_active`` set,
            defaults to False
        :type include_decimation: bool
        :param include_delay: keep "time delay" filters, defaults to False
        :type include_delay: bool
        :return: indices into ``filters_list``, ascending
        :rtype: list[int]
        """
        indices = []
        for index, mt_filter in enumerate(self.filters_list):
            if not include_delay and mt_filter.type == "time delay":
                continue
            if not include_decimation and mt_filter.decimation_active:
                continue
            indices.append(index)
        return indices

    def get_list_of_filters_to_remove(
        self, include_decimation=False, include_delay=False
    ):
        """
        Return the filters selected by :meth:`get_indices_of_filters_to_remove`.

        :param include_decimation: keep filters with ``decimation_active`` set,
            defaults to False
        :type include_decimation: bool
        :param include_delay: keep "time delay" filters, defaults to False
        :type include_delay: bool
        :return: the selected filters, in list order
        :rtype: list
        """
        # NOTE: filters using the opposite ("multiply") correction convention
        # are not inverted here; allowing them into a channel response was
        # considered and rejected.
        indices = self.get_indices_of_filters_to_remove(
            include_decimation=include_decimation, include_delay=include_delay
        )
        return [self.filters_list[i] for i in indices]

    def complex_response(
        self,
        frequencies=None,
        filters_list=None,
        include_decimation=False,
        include_delay=False,
        normalize=False,
        **kwargs,
    ):
        """
        Compute the complex response of the filter cascade.

        The total response is the product of the individual filter responses.
        A subset of filters may be supplied via ``filters_list``.

        :param frequencies: frequencies at which to evaluate the response; when
            given they are stored on ``self.frequencies`` (side effect),
            defaults to None (use ``self.frequencies``)
        :type frequencies: np.ndarray, optional
        :param filters_list: filters to combine; when None the list is built
            with :meth:`get_list_of_filters_to_remove`, defaults to None
        :type filters_list: list, optional
        :param include_decimation: include decimation filters when building
            the list, defaults to False
        :type include_decimation: bool, optional
        :param include_delay: include time-delay filters when building the
            list, defaults to False
        :type include_delay: bool, optional
        :param normalize: scale the response so its maximum amplitude is 1,
            defaults to False; skipped with a warning when the maximum
            amplitude is zero or not finite
        :type normalize: bool, optional
        :param kwargs: accepted for call compatibility; not used here
        :return: complex response along the frequency array (all ones when
            there are no filters)
        :rtype: np.ndarray
        """
        if frequencies is not None:
            self.frequencies = frequencies

        # make filters list if not supplied
        if filters_list is None:
            logger.warning(
                "Filters list not provided, building list assuming all are applied"
            )
            filters_list = self.get_list_of_filters_to_remove(
                include_decimation=include_decimation,
                include_delay=include_delay,
            )

        if len(filters_list) == 0:
            logger.warning(f"No filters associated with {self.__class__}, returning 1")
            return np.ones(len(self.frequencies), dtype=complex)

        # Start from a complex copy so the in-place products below neither
        # modify an array owned by the first filter nor fail with a casting
        # error when that filter returns real values.
        result = np.array(
            filters_list[0].complex_response(self.frequencies), dtype=complex
        )
        for mt_filter in filters_list[1:]:
            result *= mt_filter.complex_response(self.frequencies)

        if normalize and result.size > 0:
            peak = np.max(np.abs(result))
            if np.isfinite(peak) and peak > 0:
                result /= peak
            else:
                logger.warning(
                    f"Cannot normalize response, maximum amplitude is {peak}"
                )
        return result

    @staticmethod
    def _round_to_sig_figs(value, sig_figs):
        """
        Round ``value`` relative to its leading digit.

        Keeps ``sig_figs`` decimal digits beyond the leading digit, i.e.
        ``round(value, sig_figs - floor(log10(|value|)))``. Zero and
        non-finite values are returned unchanged because ``log10`` is
        infinite/undefined for them and ``int()`` of that raises.
        """
        if value == 0 or not np.isfinite(value):
            return value
        return round(value, sig_figs - int(np.floor(np.log10(abs(value)))))

    def compute_instrument_sensitivity(self, normalization_frequency=None, sig_figs=6):
        """
        Compute the StationXML instrument sensitivity at the normalization frequency.

        The sensitivity is the amplitude of the product of all filter
        responses evaluated at ``self.normalization_frequency``.

        :param normalization_frequency: when given, stored on
            ``self.normalization_frequency`` before computing (side effect),
            defaults to None
        :type normalization_frequency: float, optional
        :param sig_figs: rounding precision passed to the sig-fig rounding,
            defaults to 6
        :type sig_figs: int, optional
        :return: the rounded sensitivity; 1.0 when it evaluates to zero or NaN
        :rtype: float
        """
        if normalization_frequency is not None:
            self.normalization_frequency = normalization_frequency
        sensitivity = 1.0
        for mt_filter in self.filters_list:
            response = mt_filter.complex_response(self.normalization_frequency)
            sensitivity *= response.astype(complex)
        # The product may be an array (one entry per frequency) or a scalar.
        try:
            sensitivity = np.abs(sensitivity[0])
        except (IndexError, TypeError):
            sensitivity = np.abs(sensitivity)

        if sensitivity == 0.0:
            logger.warning(
                "Sensitivity is zero, cannot compute instrument sensitivity. "
                "Returning 1.0"
            )
            return 1.0
        if np.isnan(sensitivity):
            logger.warning("Sensitivity is NaN, setting to 1.0")
            sensitivity = 1.0
        return self._round_to_sig_figs(sensitivity, sig_figs)

    def compute_total_gain(self, sig_figs=16):
        """
        Compute the product of the ``gain`` of every filter.

        This differs from the instrument sensitivity, which evaluates the full
        complex response at the normalization frequency. For MT data the total
        sensitivity is of limited use; downstream users typically apply the
        individual filters.

        :param sig_figs: rounding precision passed to the sig-fig rounding,
            defaults to 16
        :type sig_figs: int, optional
        :return: the rounded total gain (1 when there are no filters; zero or
            non-finite products are returned unrounded)
        :rtype: float
        """
        total_gain = 1
        for mt_filter in self.filters_list:
            total_gain *= mt_filter.gain

        return self._round_to_sig_figs(total_gain, sig_figs)

    @requires(obspy=inventory)
    def to_obspy(self, sample_rate=1):
        """
        Build an :class:`obspy.core.inventory.Response` for a StationXML file.

        The response carries an ``InstrumentSensitivity`` whose value is the
        total gain (see :meth:`compute_total_gain`) and one response stage per
        filter, numbered from 1. Coefficient filters whose ``units_out`` is not
        "count"/"digital counts" are written as gain-only pole-zero stages.

        :param sample_rate: sample rate passed to each stage's ``to_obspy``,
            defaults to 1
        :type sample_rate: float, optional
        :return: the assembled response
        :rtype: obspy.core.inventory.Response
        """
        total_sensitivity = self.compute_instrument_sensitivity()
        total_gain = self.compute_total_gain()

        if total_sensitivity != total_gain:
            logger.info(
                f"total sensitivity {total_sensitivity} != total gain {total_gain}. Using total_gain."
            )
            total_sensitivity = total_gain

        units_in_obj = get_unit_object(self.units_in)
        units_out_obj = get_unit_object(self.units_out)

        total_response = inventory.Response()
        total_response.instrument_sensitivity = inventory.InstrumentSensitivity(
            total_sensitivity,
            self.normalization_frequency,
            units_in_obj.symbol,
            units_out_obj.symbol,
            input_units_description=units_in_obj.name,
            output_units_description=units_out_obj.name,
        )

        for stage_number, mt_filter in enumerate(self.filters_list, 1):
            stage_filter = mt_filter
            if mt_filter.type in ["coefficient"] and mt_filter.units_out not in [
                "count",
                "digital counts",
            ]:
                logger.debug(f"converting CoefficientFilter {mt_filter.name} to PZ")
                # Gain-only pole-zero stage carrying the coefficient metadata.
                stage_filter = PoleZeroFilter()
                stage_filter.gain = mt_filter.gain
                stage_filter.units_in = mt_filter.units_in
                stage_filter.units_out = mt_filter.units_out
                stage_filter.comments = mt_filter.comments
                stage_filter.name = mt_filter.name

            total_response.response_stages.append(
                stage_filter.to_obspy(
                    stage_number=stage_number,
                    normalization_frequency=self.normalization_frequency,
                    sample_rate=sample_rate,
                )
            )

        return total_response

    def plot_response(
        self,
        frequencies=None,
        x_units="period",
        unwrap=True,
        pb_tol=1e-1,
        interpolation_method="slinear",
        include_delay=False,
        include_decimation=False,
    ):
        """
        Plot each filter's response together with the total response.

        :param frequencies: frequencies to compute response; when given they
            are stored on ``self.frequencies`` (side effect), defaults to None
        :type frequencies: np.ndarray, optional
        :param x_units: [ period | frequency ], defaults to "period"
        :type x_units: string, optional
        :param unwrap: Unwrap phase, defaults to True
        :type unwrap: bool, optional
        :param pb_tol: pass band tolerance, defaults to 1e-1 (currently not
            used by this method)
        :type pb_tol: float, optional
        :param interpolation_method: Interpolation method see
            scipy.signal.interpolate
            [ slinear | nearest | cubic | quadratic | ], defaults to "slinear"
        :type interpolation_method: string, optional
        :param include_delay: include delays in response, defaults to False
        :type include_delay: bool, optional
        :param include_decimation: Include decimation in response,
            defaults to False
        :type include_decimation: bool, optional
        """
        if frequencies is not None:
            self.frequencies = frequencies

        # Select the filters to plot; copy them so plotting cannot mutate the
        # filters stored on this channel response.
        selected = self.filters_list if include_delay else self.non_delay_filters
        if not include_decimation:
            selected = [x for x in selected if not x.decimation_active]
        filters_list = deepcopy(selected)

        cr_kwargs = {"interpolation_method": interpolation_method}

        # get response of individual filters
        cr_list = [
            f.complex_response(self.frequencies, **cr_kwargs) for f in filters_list
        ]

        # Total response of exactly the plotted filters; passing the list
        # explicitly avoids rebuilding it (and the "not provided" warning).
        complex_response = self.complex_response(
            self.frequencies,
            filters_list=filters_list,
            include_delay=include_delay,
            include_decimation=include_decimation,
            **cr_kwargs,
        )

        cr_list.append(complex_response)
        filter_names = [f.name for f in filters_list]
        labels = filter_names + ["Total Response"]

        # plot with proper attributes.
        kwargs = {
            "title": f"Channel Response: [{', '.join(filter_names)}]",
            "unwrap": unwrap,
            "x_units": x_units,
            "pass_band": self.pass_band,
            "label": labels,
            "normalization_frequency": self.normalization_frequency,
        }

        plot_response(self.frequencies, cr_list, **kwargs)