Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\processing\aurora\decimation_level.py: 79%
224 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1"""
2This module contains the DecimationLevel class.
3TODO: Factor or rename. The decimation level class here has information about the entire processing.
4"""
6# =====================================================
7# Imports
8# =====================================================
9from typing import Annotated, get_args, List, Union
11import numpy as np
12import pandas as pd
13from loguru import logger
14from pydantic import computed_field, Field, field_validator, ValidationInfo
16from mt_metadata.base import MetadataBase
17from mt_metadata.common.band import Band
18from mt_metadata.common.enumerations import StrEnumerationBase
19from mt_metadata.helper_functions import cast_to_class_if_dict, validate_setter_input
20from mt_metadata.processing import ShortTimeFourierTransform as STFT
21from mt_metadata.processing import TimeSeriesDecimation as Decimation
24from mt_metadata.features.weights import ChannelWeightSpec
26from mt_metadata.processing.aurora.estimator import Estimator
27from mt_metadata.processing.aurora.frequency_bands import FrequencyBands
28from mt_metadata.processing.aurora.regression import Regression
31from mt_metadata.processing.fourier_coefficients.decimation import (
32 Decimation as FCDecimation,
33)
36# =====================================================
# Storage formats that can be selected through ``DecimationLevel.save_fcs_type``.
# Documented with comments rather than a class docstring so that any schema
# generated from this enumeration is left exactly as it was.
class SaveFcsTypeEnum(StrEnumerationBase):
    # Fourier coefficients stored as "h5"
    h5 = "h5"
    # Fourier coefficients stored as "csv"
    csv = "csv"
# Processing configuration for a single decimation level: the frequency bands,
# the channel roles (input / output / reference), the decimation, STFT,
# estimator and regression settings, and whether/how Fourier coefficients are
# saved.  Documented with comments rather than a class docstring so that any
# schema generated from the model is left exactly as it was.
class DecimationLevel(MetadataBase):
    # Frequency bands processed at this level; dict entries are cast to Band
    # by ``validate_bands``.
    bands: Annotated[
        list[Band],
        Field(
            default_factory=list,
            description="List of bands",
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[]"],
            },
        ),
    ]

    # Weighting schemes; a single spec or dict is wrapped into a list by
    # ``validate_channel_weight_specs``.
    channel_weight_specs: Annotated[
        List[ChannelWeightSpec],
        Field(
            default_factory=list,
            description="List of weighting schemes to use for TF processing for each output channel",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[]"],
            },
        ),
    ]

    input_channels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of input channels (sources)",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["hx, hy"],
            },
        ),
    ]

    output_channels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of output channels (responses)",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["ex, ey, hz"],
            },
        ),
    ]

    reference_channels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of reference channels (remote sources)",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["hx, hy"],
            },
        ),
    ]

    save_fcs: Annotated[
        bool,
        Field(
            default=False,
            description="Whether the Fourier coefficients are saved [True] or not [False].",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [True],
            },
        ),
    ]

    save_fcs_type: Annotated[
        SaveFcsTypeEnum | None,
        Field(
            default=None,
            description="Format to use for fc storage",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["h5"],
            },
        ),
    ]

    decimation: Annotated[
        Decimation,
        Field(
            default_factory=Decimation,  # type: ignore
            description="Decimation settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["Decimation()"],
            },
        ),
    ]

    estimator: Annotated[
        Estimator,
        Field(
            default_factory=Estimator,  # type: ignore
            description="Estimator settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["Estimator()"],
            },
        ),
    ]

    regression: Annotated[
        Regression,
        Field(
            default_factory=Regression,  # type: ignore
            description="Regression settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["Regression()"],
            },
        ),
    ]

    stft: Annotated[
        STFT,
        Field(
            default_factory=STFT,  # type: ignore
            description="Short-time Fourier transform settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["STFT()"],
            },
        ),
    ]

    @field_validator("channel_weight_specs", mode="before")
    @classmethod
    def validate_channel_weight_specs(cls, value, info: ValidationInfo):
        """
        Normalize the input for ``channel_weight_specs`` to a list of
        ChannelWeightSpec objects.

        Parameters
        ----------
        value: ChannelWeightSpec, dict, or list of those
            A single spec / dict is wrapped into a one-element list.
        info: ValidationInfo
            Supplied by pydantic; not used here.

        Returns
        -------
        validated_specs: list of ChannelWeightSpec
            Dict entries are converted with ``ChannelWeightSpec(**entry)``.

        Raises
        ------
        TypeError
            If ``value`` is not a list after singleton handling, or if a list
            entry is neither a ChannelWeightSpec nor a dict.
        """
        # Handle singleton cases
        if isinstance(value, (ChannelWeightSpec, dict)):
            value = [value]

        if not isinstance(value, list):
            raise TypeError(f"Not sure what to do with {type(value)}")

        # Convert dicts to ChannelWeightSpec objects, pass existing objects through
        validated_specs = []
        for item in value:
            if isinstance(item, dict):
                validated_specs.append(ChannelWeightSpec(**item))
            elif isinstance(item, ChannelWeightSpec):
                validated_specs.append(item)
            else:
                raise TypeError(
                    f"List entry must be a ChannelWeightSpec object or dict, not {type(item)}"
                )

        return validated_specs

    @field_validator("bands", mode="before")
    @classmethod
    def validate_bands(cls, value, info: ValidationInfo):
        """
        Normalize the input for ``bands`` to a list of the field's element class.

        The element class is read from the field annotation rather than being
        hard-coded, then the input is passed through ``validate_setter_input``
        and every entry through ``cast_to_class_if_dict``.

        Raises
        ------
        ValueError
            If pydantic does not provide the field name.
        """
        field_name = info.field_name
        if field_name is None:
            raise ValueError("Field name is required for validation")

        # First type argument of the list[...] annotation is the element class
        field_info = cls.model_fields[field_name]
        target_class = get_args(field_info.annotation)[0]

        values = validate_setter_input(value, target_class)
        return [cast_to_class_if_dict(obj, target_class) for obj in values]

    def add_band(self, band: Union[Band, dict]) -> None:
        """
        Append a band to ``self.bands``.

        Parameters
        ----------
        band: Band or dict
            A dict is loaded into a new Band via ``Band.from_dict``.

        Raises
        ------
        TypeError
            If ``band`` is neither a Band nor a dict.
        """
        if not isinstance(band, (Band, dict)):
            raise TypeError(f"List entry must be a Band object not {type(band)}")

        if isinstance(band, dict):
            obj = Band()
            obj.from_dict(band)
        else:
            obj = band

        self.bands.append(obj)

    @computed_field
    @property
    def lower_bounds(self) -> np.ndarray:
        """
        Sorted array of ``index_min`` over all bands.
        """
        return np.array(sorted(band.index_min for band in self.bands))

    @computed_field
    @property
    def upper_bounds(self) -> np.ndarray:
        """
        Sorted array of ``index_max`` over all bands.
        """
        return np.array(sorted(band.index_max for band in self.bands))

    @computed_field
    @property
    def bands_dataframe(self) -> pd.DataFrame:
        """
        Utility function that transforms the list of bands into a dataframe.

        See notes in `_df_from_bands`.

        Returns
        -------
        bands_df: pd.DataFrame
            Same format as that generated by EMTFBandSetupFile.get_decimation_level()
        """
        return _df_from_bands(self.bands)

    @computed_field
    @property
    def frequency_sample_interval(self) -> float:
        """
        Returns the delta_f in frequency domain df = 1 / (N * dt)
        Here dt is the sample interval after decimation.

        Returns
        -------
        frequency_sample_interval: float
            ``decimation.sample_rate`` divided by ``stft.window.num_samples``.
        """
        return self.decimation.sample_rate / self.stft.window.num_samples

    @computed_field
    @property
    def band_edges(self) -> np.ndarray:
        """
        Returns the band edges as a numpy array.

        Returns
        -------
        band_edges: 2D numpy array, one row per frequency band and two columns
            (frequency_min, frequency_max), in the row order of ``bands_dataframe``.
        """
        bands_df = self.bands_dataframe
        band_edges = np.vstack(
            (bands_df.frequency_min.values, bands_df.frequency_max.values)
        ).T
        return band_edges

    def frequency_bands_obj(self) -> FrequencyBands:
        """
        Gets a FrequencyBands object that is used as input to processing.

        Used by Aurora.

        TODO: consider adding .to_frequency_bands() method directly to self.bands

        Returns
        -------
        frequency_bands: FrequencyBands
            Built from ``self.band_edges``.
        """
        return FrequencyBands(band_edges=self.band_edges)

    @property
    def fft_frequencies(self) -> np.ndarray:
        """
        Gets the harmonics of the STFT.

        Returns
        -------
        freqs: np.ndarray
            Result of ``stft.window.fft_harmonics`` at the decimated sample rate.
        """
        return self.stft.window.fft_harmonics(self.decimation.sample_rate)

    @property
    def harmonic_indices(self) -> List[int]:
        """
        Loops over all bands and returns a list of the harmonic indices.
        TODO: Distinguish the bands which are a processing construction vs harmonic indices which are FFT info.

        Returns
        -------
        return_list: list of integers
            The harmonic indices of every band, concatenated and sorted.
            An index that occurs in more than one band appears more than once.
        """
        return_list = []
        for band in self.bands:
            return_list.extend(band.harmonic_indices.tolist())
        return_list.sort()
        return return_list

    @property
    def local_channels(self) -> list[str]:
        """
        Input channels followed by output channels.
        """
        return self.input_channels + self.output_channels

    def is_consistent_with_archived_fc_parameters(
        self, fc_decimation: FCDecimation, remote: bool
    ) -> bool:
        """
        Usage: For an already existing spectrogram stored in an MTH5 archive, this compares the metadata
        within the archive (fc_decimation) with an aurora decimation level (self), and tells whether the
        parameters are in agreement.  If True, this allows aurora to skip the calculation of FCs and instead
        read them from the archive.

        TODO: Merge all checks of TimeSeriesDecimation parameters into a single check.
         - e.g. Compress all decimation checks to: assert fc_decimation.decimation == self.decimation

        Parameters
        ----------
        fc_decimation: FCDecimation
            metadata describing the parameters used to compute an archived spectrogram
        remote: bool
            If True, we are looking for reference channels, not local channels in the FCGroup.

        Checks, in order (the first disagreement ends the comparison):
            "channels_estimated": all required channels must be in the group
            "time_series_decimation.anti_alias_filter"
            "time_series_decimation.sample_rate"
            "short_time_fourier_transform.method"
            "stft.prewhitening_type"
            "stft.recoloring"
            "stft.pre_fft_detrend_type"
            "stft.min_num_stft_windows"
            "stft.window"
            "stft.harmonic_indices": skipped when the archive value is None or contains -1

        Returns
        -------
        bool
            True if no check failed, False otherwise (the reason is logged at INFO level).

        Raises
        ------
        NotImplementedError
            If the anti-alias filters differ, unless the processing config is "default"
            and the archived value is None.
        """
        # NOTE: comparisons are explicit ``if`` tests rather than ``assert`` statements.
        # Asserts are removed when Python runs with -O, which would make every check pass.

        # channels_estimated: Checks that the archived spectrogram has the required channels
        required_channels = self.reference_channels if remote else self.local_channels
        if not set(required_channels).issubset(fc_decimation.channels_estimated):
            msg = (
                f"required_channels for processing {required_channels} not available"
                f"-- fc channels estimated are {fc_decimation.channels_estimated}"
            )
            logger.info(msg)
            return False

        # anti_alias_filter: Check that the data were filtered the same way
        archived_aaf = fc_decimation.time_series_decimation.anti_alias_filter
        requested_aaf = self.decimation.anti_alias_filter
        if not (archived_aaf == requested_aaf):
            # The only tolerated mismatch: config asks for "default", archive has nothing recorded
            config_is_default = requested_aaf == "default"
            archive_is_unset = archived_aaf is None
            if not (config_is_default and archive_is_unset):
                msg = (
                    "Antialias Filters Not Compatible -- need to add handling for "
                    f"FCdec {archived_aaf} and "
                    f"processing config:{requested_aaf}"
                )
                raise NotImplementedError(msg)

        # sample_rate
        archived_rate = fc_decimation.time_series_decimation.sample_rate
        if not (archived_rate == self.decimation.sample_rate):
            msg = (
                f"Sample rates do not agree: fc {archived_rate} differs from "
                f"processing config {self.decimation.sample_rate}"
            )
            logger.info(msg)
            return False

        # transform method (fft, wavelet, etc.)
        archived_method = fc_decimation.short_time_fourier_transform.method
        if not (archived_method == self.stft.method):
            msg = (
                "Transform methods do not agree: "
                f"fc {archived_method} != processing config {self.stft.method}"
            )
            logger.info(msg)
            return False

        # Plain STFT attributes that must match exactly.  The second entry is only
        # the verb used in the log message ("does not agree" / "do not agree").
        stft_attribute_checks = (
            ("prewhitening_type", "does"),
            ("recoloring", "does"),
            ("pre_fft_detrend_type", "does"),
            ("min_num_stft_windows", "do"),
        )
        for attribute, verb in stft_attribute_checks:
            archived_value = getattr(fc_decimation.stft, attribute)
            requested_value = getattr(self.stft, attribute)
            if not (archived_value == requested_value):
                msg = (
                    f"{attribute} {verb} not agree "
                    f"fc {archived_value} != processing config {requested_value}"
                )
                logger.info(msg)
                return False

        # window
        if not (fc_decimation.stft.window == self.stft.window):
            msg = (
                "window does not agree: "
                f" FC Group: {fc_decimation.stft.window} "
                f" Processing Config {self.stft.window}"
            )
            logger.info(msg)
            return False

        # harmonic_indices
        archived_indices = fc_decimation.stft.harmonic_indices
        if archived_indices is None or -1 in archived_indices:
            # Not set, or -1 meaning the archive kept all harmonics: nothing to compare.
            return True

        msg = "WIP: harmonic indices in AuroraDecimationlevel are derived from processing bands -- Not robustly tested to compare with FCDecimation"
        logger.debug(msg)
        processing_set = set(self.harmonic_indices)
        fcdec_group_set = set(archived_indices)
        if not processing_set.issubset(fcdec_group_set):
            msg = (
                f"Processing FC indices {processing_set} is not contained "
                f"in FC indices {fcdec_group_set}"
            )
            logger.info(msg)
            return False

        # Getting here means no checks were failed. The FCDecimation supports the processing config
        return True

    def to_fc_decimation(
        self,
        remote: bool = False,
        ignore_harmonic_indices: bool = True,
    ) -> FCDecimation:
        """
        Generates a FC Decimation() object for use with FC Layer in mth5.

        TODO: this is being tested only in aurora -- move a test to mt_metadata or move the method.
        Ignoring for now these properties
        "time_period.end": "1980-01-01T00:00:00+00:00",
        "time_period.start": "1980-01-01T00:00:00+00:00",

        TODO: FIXME: Assignment of TSDecimation can be done in one shot once #235 is addressed.

        Parameters
        ----------
        remote: bool
            If True, use reference channels, if False, use local_channels.  We may wish to not pass remote=True when
            _building_ FCs however, because then not all channels will get built.
        ignore_harmonic_indices: bool
            If True, the harmonic indices of the new object are left untouched (described upstream as the
            default [-1,], meaning all indices).  If False, only the specific harmonic indices needed for
            processing will be stored.  Thus, when building FCs, it may be best to leave this as True, that way
            all FCs will be stored, so if the band setup is changed, the FCs will still be there.

        Returns
        -------
        fc_dec_obj: FCDecimation
            (mt_metadata.processing.fourier_coefficients.decimation.Decimation)
            A decimation object configured for STFT processing
        """
        fc_dec_obj = FCDecimation()  # type: ignore
        fc_dec_obj.time_series_decimation.anti_alias_filter = (
            self.decimation.anti_alias_filter
        )
        fc_dec_obj.channels_estimated = (
            self.reference_channels if remote else self.local_channels
        )
        fc_dec_obj.time_series_decimation.factor = self.decimation.factor
        fc_dec_obj.time_series_decimation.level = self.decimation.level
        if not ignore_harmonic_indices:
            # Restrict the stored harmonics to those the processing bands need
            fc_dec_obj.stft.harmonic_indices = self.harmonic_indices
        fc_dec_obj.id = f"{self.decimation.level}"
        fc_dec_obj.stft.method = self.stft.method
        fc_dec_obj.stft.pre_fft_detrend_type = self.stft.pre_fft_detrend_type
        fc_dec_obj.stft.prewhitening_type = self.stft.prewhitening_type
        fc_dec_obj.stft.recoloring = self.stft.recoloring
        fc_dec_obj.time_series_decimation.sample_rate = self.decimation.sample_rate
        fc_dec_obj.stft.window = self.stft.window

        return fc_dec_obj
def _df_from_bands(band_list: List[Union[Band, dict, None]]) -> pd.DataFrame:
    """
    Build a dataframe with one row per band, ordered by lower bound index.

    Note: The decimation_level here is +1 to agree with EMTF convention.
    Not clear this is really necessary
    TODO: Consider making this a method of FrequencyBands() class.
    TODO: Check typehint -- should None be allowed value in the band_list?
     (every entry is read by attribute access below, so in practice entries
     need to expose the Band attributes)
    TODO: Consider adding columns lower_closed, upper_closed to df

    Parameters
    ----------
    band_list: list
        obtained from mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel.bands

    Returns
    -------
    out_df: pd.Dataframe
        Columns: decimation_level, lower_bound_index, upper_bound_index,
        frequency_min, frequency_max.
        Same format as that generated by EMTFBandSetupFile.get_decimation_level()
    """
    # Output column -> how to read its value from one band (insertion order
    # fixes the column order of the dataframe).
    column_readers = {
        "decimation_level": lambda band: band.decimation_level + 1,
        "lower_bound_index": lambda band: band.index_min,
        "upper_bound_index": lambda band: band.index_max,
        "frequency_min": lambda band: band.frequency_min,
        "frequency_max": lambda band: band.frequency_max,
    }
    table = {
        column: [read(band) for band in band_list]
        for column, read in column_readers.items()
    }
    unsorted_df = pd.DataFrame(data=table)
    ordered_df = unsorted_df.sort_values(by="lower_bound_index")
    return ordered_df.reset_index(drop=True)
def get_fft_harmonics(samples_per_window: int, sample_rate: float) -> np.ndarray:
    """
    Return the leading, non-negative FFT frequencies for a window.

    Development notes:
    Could be modified with kwargs to support one_sided, two_sided, ignore_dc
    ignore_nyquist, and etc.  Consider taking FrequencyBands as an argument.

    Parameters
    ----------
    samples_per_window: integer
        Number of samples in a window that will be Fourier transformed.
    sample_rate: float
        Inverse of time step between samples,
        Samples per second

    Returns
    -------
    harmonic_frequencies: numpy array
        The first ``int(samples_per_window / 2)`` entries of
        ``np.fft.fftfreq(samples_per_window, d=1 / sample_rate)``.
        These are one-sided (positive frequencies only)
        Does not return Nyquist
        Does return DC component
    """
    # Number of bins kept: DC up to, but excluding, the Nyquist bin
    n_keep = int(samples_per_window / 2)
    all_frequencies = np.fft.fftfreq(samples_per_window, d=1.0 / sample_rate)
    return all_frequencies[:n_keep]