Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\processing\fourier_coefficients\decimation.py: 83%
201 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from collections import OrderedDict
5from typing import Annotated, List, Optional
7import numpy as np
8from loguru import logger
9from pydantic import Field, field_validator, model_validator, ValidationInfo
11from mt_metadata.base import MetadataBase
12from mt_metadata.common import ListDict, TimePeriod
13from mt_metadata.processing import ShortTimeFourierTransform, TimeSeriesDecimation
14from mt_metadata.processing.fourier_coefficients.fc_channel import FCChannel
17# =====================================================
class Decimation(MetadataBase):
    """
    Metadata for the Fourier coefficients (FCs) estimated at one decimation level.

    Bundles the time-series decimation settings, the short-time Fourier
    transform (STFT) settings, the time period covered, and per-channel FC
    metadata (``channels``). The component names of ``channels`` are kept in
    sync with ``channels_estimated``.
    """

    id: Annotated[
        str,
        Field(
            default="",
            description="Decimation level ID",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["1"],
            },
        ),
    ]

    channels_estimated: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of channels",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[ex, hy]"],
            },
        ),
    ]

    time_period: Annotated[
        TimePeriod,
        Field(
            default_factory=TimePeriod,  # type: ignore
            description="Time period over which these FCs were estimated",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["TimePeriod()"],
            },
        ),
    ]

    channels: Annotated[
        ListDict,
        Field(
            default_factory=ListDict,
            description="List of channels",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[ex, hy]"],
            },
        ),
    ]

    time_series_decimation: Annotated[
        TimeSeriesDecimation,
        Field(
            default_factory=TimeSeriesDecimation,  # type: ignore
            description="Time series decimation settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["TimeSeriesDecimation()"],
            },
        ),
    ]

    short_time_fourier_transform: Annotated[
        ShortTimeFourierTransform,
        Field(
            default_factory=ShortTimeFourierTransform,  # type: ignore
            description="Short time Fourier transform settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["ShortTimeFourierTransform()"],
            },
        ),
    ]

    @field_validator("short_time_fourier_transform", mode="before")
    @classmethod
    def validate_short_time_fourier_transform(
        cls, value: ShortTimeFourierTransform, info: ValidationInfo
    ) -> ShortTimeFourierTransform:
        """
        Validate STFT settings and clear any per-window detrend type.

        :param value: STFT settings to validate.
        :type value: ShortTimeFourierTransform
        :raises TypeError: if ``value`` is not a ShortTimeFourierTransform.
        :return: ``value`` with ``per_window_detrend_type`` reset to "".
        :rtype: ShortTimeFourierTransform
        """
        if not isinstance(value, ShortTimeFourierTransform):
            msg = f"Input must be metadata ShortTimeFourierTransform not {type(value)}"
            raise TypeError(msg)
        if value.per_window_detrend_type:
            # Per-window detrending is not supported here, so it is cleared.
            msg = (
                f"per_window_detrend_type was set to {value.per_window_detrend_type} "
                "however, this is not supported -- setting to empty string"
            )
            logger.debug(msg)
            value.per_window_detrend_type = ""
        return value

    @field_validator("channels_estimated", mode="before")
    @classmethod
    def validate_channels_estimated(
        cls, value: list[str], info: ValidationInfo
    ) -> list[str]:
        """
        Ensure ``channels_estimated`` is a list of strings.

        :raises TypeError: if ``value`` is not a list or holds a non-string.
        :return: ``value`` unchanged.
        """
        if not isinstance(value, list):
            msg = f"Input must be a list of strings not {type(value)}"
            raise TypeError(msg)
        for item in value:
            if not isinstance(item, str):
                msg = f"All items in the list must be strings not {type(item)}"
                raise TypeError(msg)
        return value

    @field_validator("channels", mode="before")
    @classmethod
    def validate_channels(cls, value: ListDict, info: ValidationInfo) -> ListDict:
        """
        Build a ListDict of FCChannel objects from the supplied channels.

        Accepts None, a string placeholder (as may come back from HDF5
        storage), a list/tuple of channels, or a mapping whose values are
        channels. Each channel may be a dict or an object with ``to_dict``.

        :raises TypeError: if ``value`` is not an accepted container, or if
            any channel could not be converted (all failures are reported).
        :return: a new ListDict of FCChannel objects.
        """
        if value is None:
            return ListDict()

        # String representations may come from HDF5 storage.
        if isinstance(value, str):
            if value in ["", "none", "None", "ListDict()", "{}"]:
                return ListDict()
            # Other strings cannot be parsed; kept as empty for backward compatibility.
            logger.warning(f"Converting string representation of channels: {value}")
            return ListDict()

        if not isinstance(value, (list, tuple, dict, ListDict, OrderedDict)):
            msg = (
                "input ch_list must be an iterable, should be a list or dict "
                f"not {type(value)}"
            )
            logger.error(msg)
            raise TypeError(msg)

        value_list = value if isinstance(value, (list, tuple)) else value.values()

        fails = []
        channels = ListDict()
        for channel in value_list:
            try:
                ch = FCChannel()
                if hasattr(channel, "to_dict"):
                    channel = channel.to_dict()
                ch.from_dict(channel)
                channels.append(ch)
            except Exception as error:
                # loguru formats with str.format, so build the message eagerly.
                msg = f"Could not create channel from dictionary: {error}"
                fails.append(msg)
                logger.error(msg)

        if fails:
            raise TypeError("\n".join(fails))

        return channels

    def _channel_names(self) -> list[str]:
        """Return the component names currently stored in ``channels``, in order."""
        return list(self.channels.keys() or [])

    @model_validator(mode="after")
    def validate_channels_consistency(self):
        """
        Ensure that channels_estimated and channels are synchronized.

        - A name in ``channels_estimated`` that is missing from ``channels``
          gets a new FCChannel with that component name.
        - A channel in ``channels`` whose name is missing from
          ``channels_estimated`` has its name appended there.

        Both passes iterate in list order so results are deterministic.
        """
        existing_names = self._channel_names()
        existing_set = set(existing_names)
        estimated = self.channels_estimated or []
        estimated_set = set(estimated)

        # dict.fromkeys de-duplicates while preserving order.
        for channel_name in dict.fromkeys(estimated):
            if channel_name not in existing_set:
                logger.info(f"Creating FCChannel for estimated channel: {channel_name}")
                self.channels.append(FCChannel(component=channel_name))

        extra_channels = [
            name for name in dict.fromkeys(existing_names) if name not in estimated_set
        ]
        if extra_channels:
            logger.info(f"Adding channels to channels_estimated: {extra_channels}")
            self.channels_estimated.extend(extra_channels)

        return self

    def add(self, other: "Decimation") -> "Decimation":
        """
        Merge the channels of another Decimation into this one.

        :param other: decimation whose channels are appended.
        :type other: Decimation
        :raises TypeError: if ``other`` is not a Decimation.
        :return: self, to allow chaining.
        :rtype: Decimation
        """
        if not isinstance(other, Decimation):
            msg = f"Can only merge ch objects, not {type(other)}"
            logger.error(msg)
            raise TypeError(msg)

        self.channels.extend(other.channels)
        # Keep channels_estimated in sync so has_channel/get_channel see merged channels.
        for component in other._channel_names():
            if component not in self.channels_estimated:
                self.channels_estimated.append(component)
        return self

    # ----- Begin (Possibly Temporary) methods for integrating TimeSeriesDecimation, STFT Classes -----#

    @property
    def decimation(self) -> TimeSeriesDecimation:
        """
        Passthrough method to access self.time_series_decimation
        """
        return self.time_series_decimation

    @property
    def stft(self) -> ShortTimeFourierTransform:
        """Passthrough method to access self.short_time_fourier_transform"""
        return self.short_time_fourier_transform

    # ----- End (Possibly Temporary) methods for integrating TimeSeriesDecimation, STFT Classes -----#

    def update(self, other: "Decimation", match: Optional[List[str]] = None) -> None:
        """
        Update attribute values from another like element, skipping empty values.

        Values that are None, 0.0, [], "", the default time
        "1980-01-01T00:00:00+00:00", or zero-size arrays are not copied.
        Channels from ``other`` are added or merged via ``add_channel``.

        :param other: element to copy attribute values from.
        :type other: Decimation
        :param match: attribute names that must already be equal on both objects.
        :type match: list of str, optional
        :raises ValueError: if any attribute named in ``match`` differs.
        """
        if match is None:
            match = []
        if not isinstance(other, type(self)):
            logger.warning(f"Cannot update {type(self)} with {type(other)}")
        for k in match:
            self_value = self.get_attr_from_name(k)
            other_value = other.get_attr_from_name(k)
            if self_value != other_value:
                msg = f"{k} is not equal {self_value} != {other_value}"
                logger.error(msg)
                raise ValueError(msg)
        for k, v in other.to_dict(single=True).items():
            if hasattr(v, "size"):
                if v.size > 0:
                    self.update_attribute(k, v)
            elif v not in [None, 0.0, [], "", "1980-01-01T00:00:00+00:00"]:
                self.update_attribute(k, v)

        ## Need this because channels are set when setting channels_recorded
        ## and it initiates an empty channel, but we need to fill it with
        ## the appropriate metadata.
        for ch in other.channels:
            self.add_channel(ch)

    def has_channel(self, component: str) -> bool:
        """
        Check to see if the channel already exists

        :param component: channel component to look for
        :type component: string
        :return: True if ``component`` is listed in channels_estimated
        :rtype: boolean
        """
        return component in self.channels_estimated

    def channel_index(self, component: str) -> Optional[int]:
        """
        Get the index of the channel in channels_estimated, or None if absent.
        """
        if self.has_channel(component):
            return self.channels_estimated.index(component)
        return None

    def get_channel(self, component: str) -> FCChannel | None:
        """
        Get a channel

        :param component: channel component to look for
        :type component: string
        :return: the FCChannel for ``component``, or None if not estimated
        :rtype: :class:`FCChannel` or None
        """
        if self.has_channel(component):
            return self.channels[component]
        return None

    def add_channel(self, channel_obj: FCChannel) -> None:
        """
        Add a channel; if one with the same component exists, update it instead.

        The component name is also recorded in ``channels_estimated`` so that
        subsequent ``has_channel``/``get_channel`` calls find it.

        :param channel_obj: channel object to add
        :type channel_obj: :class:`FCChannel`
        :raises ValueError: if ``channel_obj`` is not an FCChannel.
        """
        if not isinstance(channel_obj, FCChannel):
            msg = f"Input must be metadata FCChannel not {type(channel_obj)}"
            logger.error(msg)
            raise ValueError(msg)

        component = channel_obj.component
        if component in self._channel_names():
            self.channels[component].update(channel_obj)
            logger.debug(f"ch {component} already exists, updating metadata")
        else:
            self.channels.append(channel_obj)

        if component not in self.channels_estimated:
            self.channels_estimated.append(component)

        self.update_time_period()

    def remove_channel(self, channel_id: str) -> None:
        """
        Remove a channel from this decimation level.

        The name is removed from both ``channels`` and ``channels_estimated``
        (whichever contain it); a warning is logged if neither does.

        :param channel_id: channel component to remove
        :type channel_id: string
        """
        in_channels = channel_id in self._channel_names()
        in_estimated = channel_id in self.channels_estimated
        if not (in_channels or in_estimated):
            logger.warning(f"Could not find {channel_id} to remove.")
        if in_channels:
            self.channels.remove(channel_id)
        if in_estimated:
            self.channels_estimated.remove(channel_id)

        self.update_time_period()

    @property
    def n_channels(self) -> int:
        """Number of channels stored in ``channels``."""
        return len(self.channels)

    def update_time_period(self) -> None:
        """
        Widen self.time_period to cover the time periods of all channels.

        Channel start/end values equal to the default time
        "1980-01-01T00:00:00+00:00" are ignored. The start only moves earlier
        and the end only moves later, unless it is still at that default.
        """
        null_time = "1980-01-01T00:00:00+00:00"
        start = [
            ch.time_period.start
            for ch in self.channels
            if ch.time_period.start != null_time
        ]
        # Filter on the *end* time (previously this was incorrectly gated on start).
        end = [
            ch.time_period.end
            for ch in self.channels
            if ch.time_period.end != null_time
        ]
        if start:
            earliest = min(start)
            if self.time_period.start == null_time or self.time_period.start > earliest:
                self.time_period.start = earliest
        if end:
            latest = max(end)
            if self.time_period.end == null_time or self.time_period.end < latest:
                self.time_period.end = latest

    def is_valid_for_time_series_length(self, n_samples_ts: int) -> bool:
        """
        Given a time series of len n_samples_ts, checks if there are sufficient samples to STFT.

        The requirement is one full window plus
        ``(min_num_stft_windows - 1)`` window advances.

        :param n_samples_ts: number of samples in the time series.
        :return: True if enough samples, otherwise False (and a warning is logged).
        """
        required_num_samples = (
            self.stft.window.num_samples
            + (self.stft.min_num_stft_windows - 1)
            * self.stft.window.num_samples_advance
        )
        if n_samples_ts < required_num_samples:
            msg = (
                f"{n_samples_ts} not enough samples for minimum of "
                f"{self.stft.min_num_stft_windows} stft windows of length "
                f"{self.stft.window.num_samples} and overlap {self.stft.window.overlap}"
            )
            logger.warning(msg)
            return False
        return True

    @property
    def fft_frequencies(self) -> np.ndarray:
        """Returns the one-sided fft frequencies (without Nyquist)"""
        return self.stft.window.fft_harmonics(self.decimation.sample_rate)
def fc_decimations_creator(
    initial_sample_rate: float,
    decimation_factors: Optional[list] = None,
    max_levels: Optional[int] = 6,
    time_period: Optional[TimePeriod] = None,
) -> List[Decimation]:
    """
    Creates mt_metadata FCDecimation objects that parameterize Fourier coefficient decimation levels.

    Note 1: This does not yet work through the assignment of which bands to keep. Refer to
    mt_metadata.transfer_functions.processing.Processing.assign_bands() to see how this was done in the past

    Parameters
    ----------
    initial_sample_rate: float
        Sample rate of the "level0" data -- usually the sample rate during field acquisition.
    decimation_factors: Optional[list]
        The decimation factors that will be applied at each FC decimation level.
        If not given, defaults to [1, 4, 4, ...] with ``max_levels`` entries.
        The sample rate of level 0 is ``initial_sample_rate``; each later level
        divides the previous level's rate by its factor.
    max_levels: Optional[int]
        The number of decimation levels to create when ``decimation_factors``
        is not given (None means the default of 6). Ignored otherwise.
    time_period: Optional[TimePeriod]
        Provides the start and end times. Each level receives its own copy,
        so later changes to one level's time period do not affect the others.

    Returns
    -------
    fc_decimations: list
        Each element of the list is an object of type
        mt_metadata.processing.fourier_coefficients.decimation.Decimation,
        (a.k.a. FCDecimation).

        The order of the list corresponds the order of the cascading decimation
         - No decimation levels are omitted.
         - This could be changed in future by using a dict instead of a list,
         - e.g. decimation_factors = dict(zip(np.arange(max_levels), decimation_factors))

    Raises
    ------
    NotImplementedError
        If ``time_period`` is given but is not a TimePeriod.
    ValueError
        If default factors are requested with ``max_levels`` < 1.
    """
    import copy

    # Validate once up front rather than on every loop iteration.
    if time_period and not isinstance(time_period, TimePeriod):
        msg = f"Not sure how to assign time_period with type {type(time_period)}"
        logger.error(msg)
        raise NotImplementedError(msg)

    if not decimation_factors:
        # Default to EMTF-style factors [1, 4, 4, 4, ..., 4].
        n_levels = 6 if max_levels is None else max_levels
        if n_levels < 1:
            msg = f"max_levels must be >= 1 to build default decimation factors, got {n_levels}"
            logger.error(msg)
            raise ValueError(msg)
        default_decimation_factor = 4
        decimation_factors = n_levels * [default_decimation_factor]
        decimation_factors[0] = 1

    # See Note 1
    fc_decimations = []
    current_sample_rate = 1.0 * initial_sample_rate
    for i_dec_level, decimation_factor in enumerate(decimation_factors):
        if i_dec_level > 0:
            current_sample_rate /= decimation_factor

        fc_dec = Decimation()
        fc_dec.time_series_decimation.level = i_dec_level
        fc_dec.id = f"{i_dec_level}"
        fc_dec.time_series_decimation.factor = decimation_factor
        fc_dec.time_series_decimation.sample_rate = current_sample_rate

        if time_period:
            # Copy so that the levels do not share (and co-mutate) one TimePeriod.
            fc_dec.time_period = copy.deepcopy(time_period)

        fc_decimations.append(fc_dec)

    return fc_decimations
def get_degenerate_fc_decimation(sample_rate: float) -> list:
    """
    WIP

    Build a single-level ("degenerate") FC decimation configuration.

    Only decimation level 0 (factor 1, no decimation) is produced. This is
    handy for testing, and could also serve time series that are already
    stored in an MTH5 as separate runs per decimation level.

    Parameters
    ----------
    sample_rate: float
        The sample rate associated with the time-series to convert to spectrogram

    Returns
    -------
    list
        A one-element list holding the level-0 FCDecimation
        (mt_metadata.processing.fourier_coefficients.decimation.Decimation).
    """
    single_level_factors = [1]
    return fc_decimations_creator(
        sample_rate,
        decimation_factors=single_level_factors,
        max_levels=1,
    )