Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\processing\fourier_coefficients\fc.py: 92%
134 statements
coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from collections import OrderedDict
5from typing import Annotated
7import numpy as np
8from loguru import logger
9from pydantic import Field, field_validator, model_validator, ValidationInfo
11from mt_metadata import NULL_VALUES
12from mt_metadata.base import MetadataBase
13from mt_metadata.common import ListDict, TimePeriod
14from mt_metadata.common.enumerations import StrEnumerationBase
15from mt_metadata.processing.fourier_coefficients.decimation import Decimation
18# =====================================================
class MethodEnum(StrEnumerationBase):
    """Allowed values for ``FC.method``, the Fourier transform method."""

    fft = "fft"
    wavelet = "wavelet"
    other = "other"
class FC(MetadataBase):
    """
    Metadata describing a group of Fourier coefficients (FCs).

    The group keeps two views of its decimation levels:

    * ``decimation_levels`` -- a list of level ids (strings);
    * ``levels`` -- a ``ListDict`` of ``Decimation`` objects keyed by id.

    ``synchronize_levels`` makes the two agree when the model is validated.
    ``add_decimation_level`` and ``remove_decimation_level`` keep them in
    step afterwards.
    """

    decimation_levels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="List of decimation levels",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[1, 2, 3]"],
            },
        ),
    ]

    id: Annotated[
        str,
        Field(
            default="",
            description="ID given to the FC group",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["aurora_01"],
            },
        ),
    ]

    channels_estimated: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of channels estimated",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [["ex", "hy"]],
            },
        ),
    ]

    starting_sample_rate: Annotated[
        float,
        Field(
            default=1.0,
            description="Starting sample rate of the time series used to estimate FCs.",
            alias=None,
            json_schema_extra={
                "units": "samples per second",
                "required": True,
                "examples": [60],
            },
        ),
    ]

    method: Annotated[
        MethodEnum,
        Field(
            default=MethodEnum.fft,
            description="Fourier transform method",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["fft"],
            },
        ),
    ]

    time_period: Annotated[
        TimePeriod,
        Field(
            default_factory=TimePeriod,  # type: ignore
            description="Time period of the FCs",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [TimePeriod(start="2020-01-01", end="2020-01-02")],
            },
        ),
    ]

    levels: Annotated[
        ListDict,
        Field(
            default_factory=ListDict,  # type: ignore
            description="ListDict of decimation levels and their parameters",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["ListDict containing Decimation objects"],
            },
        ),
    ]

    @field_validator("channels_estimated", "decimation_levels", mode="before")
    @classmethod
    def validate_channels_estimated(
        cls, value: list[str] | np.ndarray | str, info: ValidationInfo
    ) -> list[str]:
        """
        Normalize ``channels_estimated`` / ``decimation_levels`` input to a list.

        :param value: a list or tuple, a numpy array, a comma-separated
            string, or any value in ``NULL_VALUES``
        :type value: list | tuple | numpy.ndarray | str
        :param info: pydantic validation context; supplies the field name
            used in the error message
        :type info: ValidationInfo
        :return: list of entries. Null values become ``[]``. Comma-separated
            strings are split with surrounding whitespace stripped and empty
            pieces dropped (``"ex, hy"`` -> ``["ex", "hy"]``).
        :rtype: list
        :raises TypeError: if ``value`` is of any other type
        """
        # Convert numpy arrays to plain Python values so the checks below
        # apply uniformly (a 0-d array becomes a scalar, e.g. a string).
        if isinstance(value, np.ndarray):
            value = value.tolist()

        if value in NULL_VALUES:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        if isinstance(value, str):
            # Strip so "ex, hy" gives "hy" rather than " hy", and skip empty
            # pieces produced by stray or trailing commas.
            return [item.strip() for item in value.split(",") if item.strip()]

        # Report the field actually being validated; this validator serves
        # two fields.
        field_name = info.field_name or "value"
        raise TypeError(f"'{field_name}' must be set with a list not {type(value)}.")

    @field_validator("levels", mode="before")
    @classmethod
    def validate_levels(cls, value, info: ValidationInfo):
        """
        Build a ``ListDict`` of ``Decimation`` objects from raw input.

        :param value: ``None``, a string, or a list/tuple/dict/ListDict/
            OrderedDict whose items are ``Decimation`` objects, dictionaries,
            or objects exposing ``to_dict()``
        :param info: pydantic validation context (unused)
        :type info: ValidationInfo
        :return: the decimation levels
        :rtype: ListDict
        :raises TypeError: if ``value`` is not an accepted container, or if
            any item cannot be converted (all failures are reported together)
        """
        if value is None:
            return ListDict()

        # String representations (which might come from HDF5 storage) cannot
        # be rebuilt into Decimation objects here, so they map to an empty
        # ListDict. Anything beyond the obvious "empty" spellings is warned
        # about, because its content is discarded.
        if isinstance(value, str):
            if value not in ["", "none", "None", "ListDict()", "{}"]:
                logger.warning(f"Converting string representation of levels: {value}")
            return ListDict()

        if not isinstance(value, (list, tuple, dict, ListDict, OrderedDict)):
            msg = (
                "input dl_list must be an iterable, should be a list or dict "
                f"not {type(value)}"
            )
            logger.error(msg)
            raise TypeError(msg)

        # Mappings contribute their values; sequences are used as-is.
        if isinstance(value, (dict, ListDict, OrderedDict)):
            items = value.values()
        else:
            items = value

        fails = []
        levels = ListDict()
        for decimation_level in items:
            try:
                if isinstance(decimation_level, Decimation):
                    dl = decimation_level
                else:
                    if hasattr(decimation_level, "to_dict"):
                        decimation_level = decimation_level.to_dict()
                    dl = Decimation()  # type: ignore
                    dl.from_dict(decimation_level)
                levels.append(dl)
            except Exception as error:
                # loguru formats messages with ``{}`` placeholders, not
                # ``%s``, so build the text eagerly to be sure the error
                # detail is actually logged.
                msg = f"Could not create decimation_level from dictionary: {error}"
                fails.append(msg)
                logger.error(msg)

        if fails:
            raise TypeError("\n".join(fails))
        return levels

    @model_validator(mode="after")
    def synchronize_levels(self) -> "FC":
        """
        Ensure that ``decimation_levels`` and ``levels`` are synchronized.

        - Creates an empty ``Decimation`` (with that id) for every id in
          ``decimation_levels`` that has no entry in ``levels``.
        - Appends to ``decimation_levels`` every key of ``levels`` that is
          missing from it.

        :return: this instance (required by pydantic "after" validators)
        :rtype: FC
        """
        for level_name in self.decimation_levels:
            level_id = str(level_name)
            if level_id not in self.levels.keys():
                self.levels.append(Decimation(id=level_id))  # type: ignore

        for level_name in self.levels.keys():
            if level_name not in self.decimation_levels:
                self.decimation_levels.append(level_name)

        return self

    def has_decimation_level(self, level):
        """
        Check whether a decimation level id is listed in ``decimation_levels``.

        :param level: decimation level id to look for; it is compared as a
            string, so ``1`` and ``"1"`` are equivalent
        :type level: str | int
        :return: True if found, False if not
        :rtype: bool
        """
        # Compare as str to match get_decimation_level/decimation_level_index,
        # which both look up str(level).
        return str(level) in self.decimation_levels

    def decimation_level_index(self, level):
        """
        Get the position of a decimation level in ``levels``.

        :param level: decimation level id to look for
        :type level: str | int
        :return: index of the level among ``levels`` keys, or None if absent
        :rtype: int | None
        """
        if self.has_decimation_level(level):
            return self.levels.keys().index(str(level))
        return None

    def get_decimation_level(self, level):
        """
        Get a decimation level.

        :param level: decimation level id to look for
        :type level: str | int
        :return: the matching Decimation object, or None if absent
        :rtype: :class:`mt_metadata.processing.fourier_coefficients.decimation.Decimation`
        """
        if self.has_decimation_level(level):
            return self.levels[str(level)]
        return None

    def add_decimation_level(self, fc_decimation):
        """
        Add a decimation level, updating it in place if one with the same id exists.

        The level id is also recorded in ``decimation_levels``, and the
        group's ``time_period`` is widened to cover the level.

        :param fc_decimation: decimation level object to add
        :type fc_decimation: :class:`mt_metadata.processing.fourier_coefficients.decimation.Decimation`
        :raises ValueError: if ``fc_decimation`` is not a Decimation
        """
        if not isinstance(fc_decimation, Decimation):
            msg = f"Input must be metadata.decimation_level not {type(fc_decimation)}"
            logger.error(msg)
            raise ValueError(msg)

        level_id = fc_decimation.id
        # Check ``levels`` directly (rather than ``decimation_levels``) so
        # that indexing below cannot fail if the two lists have drifted apart.
        if level_id in self.levels.keys():
            self.levels[level_id].update(fc_decimation)
            logger.debug(f"level {level_id} already exists, updating metadata")
        else:
            self.levels.append(fc_decimation)

        if level_id not in self.decimation_levels:
            self.decimation_levels.append(level_id)

        self.update_time_period()

    def remove_decimation_level(self, decimation_level_id):
        """
        Remove a decimation level from this FC group.

        The id is removed from both ``levels`` and ``decimation_levels``
        (each checked independently), then ``update_time_period`` is called.
        A warning is logged if the id is in neither.

        :param decimation_level_id: id of the decimation level to remove
            (compared as a string)
        :type decimation_level_id: str | int
        """
        level_id = str(decimation_level_id)
        found = False
        if level_id in self.levels.keys():
            self.levels.remove(level_id)
            found = True
        if level_id in self.decimation_levels:
            self.decimation_levels.remove(level_id)
            found = True

        if not found:
            logger.warning(f"Could not find {decimation_level_id} to remove.")

        self.update_time_period()

    @property
    def n_decimation_levels(self):
        """Number of Decimation objects held in ``levels``."""
        return len(self.levels)

    def update_time_period(self):
        """
        Widen ``time_period`` to span the time periods of all decimation levels.

        Level start/end values equal to the null time
        ``"1980-01-01T00:00:00+00:00"`` are ignored. The group's start is
        set when it is null or later than the earliest level start. The
        group's end is set when it is null or earlier than the latest level
        end. The period is only ever widened, never narrowed (for example,
        it does not shrink after a level is removed).
        """
        null_time = "1980-01-01T00:00:00+00:00"

        # Each bound is filtered on its own value. Previously the end time
        # was gated on the level's start.
        starts = [
            dl.time_period.start
            for dl in self.levels
            if dl.time_period.start != null_time
        ]
        ends = [
            dl.time_period.end
            for dl in self.levels
            if dl.time_period.end != null_time
        ]

        if starts:
            earliest = min(starts)
            if self.time_period.start == null_time or self.time_period.start > earliest:
                self.time_period.start = earliest
        if ends:
            latest = max(ends)
            if self.time_period.end == null_time or self.time_period.end < latest:
                self.time_period.end = latest