Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\processing\aurora\processing.py: 43%
142 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from typing import Annotated
6from loguru import logger
7from pydantic import computed_field, Field, field_validator
9from mt_metadata.base import MetadataBase
10from mt_metadata.common.band import Band
11from mt_metadata.common.enumerations import StrEnumerationBase
12from mt_metadata.processing.aurora.channel_nomenclature import ChannelNomenclature
13from mt_metadata.processing.aurora.decimation_level import DecimationLevel
14from mt_metadata.processing.aurora.stations import Stations
17# =====================================================
class BandSpecificationStyleEnum(StrEnumerationBase):
    """
    Allowed values for ``Processing.band_specification_style``.

    The member value is the same string as the member name, so the
    serialized form is simply ``"EMTF"`` or ``"band_edges"``.
    """

    # Value string "EMTF".
    EMTF = "EMTF"
    # Value string "band_edges".
    band_edges = "band_edges"
class Processing(MetadataBase):
    """
    Aurora processing configuration.

    Holds the list of decimation levels together with the band
    specification style, an optional band setup file, a configuration id,
    the channel nomenclature and the station information.  It also
    provides helpers to build, look up and edit the decimation levels.
    """

    decimations: Annotated[
        list[DecimationLevel],
        Field(
            default_factory=list,
            description="decimation levels",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["0"],
            },
        ),
    ]

    band_specification_style: Annotated[
        BandSpecificationStyleEnum | None,
        Field(
            default=None,
            description="describes how bands were sourced",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["EMTF"],
            },
        ),
    ]

    band_setup_file: Annotated[
        str | None,
        Field(
            default=None,
            description="the band setup file used to define bands",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["/home/user/bs_test.cfg"],
            },
        ),
    ]

    id: Annotated[
        str,
        Field(
            default="",
            description="Configuration ID",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["0"],
            },
        ),
    ]

    channel_nomenclature: Annotated[
        ChannelNomenclature,
        Field(
            default_factory=ChannelNomenclature,  # type: ignore
            description="Channel nomenclature",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["EMTF"],
            },
        ),
    ]

    stations: Annotated[
        Stations,
        Field(
            default_factory=Stations,  # type: ignore
            description="Station information",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["Station1", "Station2"],
            },
        ),
    ]

    @field_validator("decimations", mode="before")
    @classmethod
    def validate_decimations(cls, value, info) -> list[DecimationLevel]:
        """
        Coerce the raw ``decimations`` input into a list of DecimationLevel.

        Accepted inputs
        ---------------
        DecimationLevel
            Wrapped in a single-element list.
        dict
            Every value must be a DecimationLevel. The values are returned
            in dict order and the keys are ignored.
        list
            Each entry must be a DecimationLevel (kept as is) or a dict
            (converted with ``DecimationLevel.from_dict``). Input order is
            preserved and each entry appears exactly once.
        str
            Strings of at most 4 characters yield an empty list. Longer
            strings are rejected.

        Returns
        -------
        list[DecimationLevel]

        Raises
        ------
        TypeError
            For any entry or value type not listed above.
        """
        if isinstance(value, DecimationLevel):
            return [value]

        if isinstance(value, dict):
            decimation_levels = []
            for obj in value.values():
                if not isinstance(obj, DecimationLevel):
                    raise TypeError(
                        f"List entry must be a DecimationLevel object not {type(obj)}"
                    )
                decimation_levels.append(obj)
            return decimation_levels

        if isinstance(value, list):
            # Single pass: the previous implementation iterated the list twice
            # and appended every DecimationLevel entry a second time.
            decimation_levels = []
            for obj in value:
                if isinstance(obj, DecimationLevel):
                    decimation_levels.append(obj)
                elif isinstance(obj, dict):
                    level = DecimationLevel()  # type: ignore
                    level.from_dict(obj)
                    decimation_levels.append(level)
                else:
                    raise TypeError(
                        f"List entry must be a DecimationLevel or dict object not {type(obj)}"
                    )
            return decimation_levels

        # TODO: Add some doc describing the role of this weird check for a long string
        # NOTE(review): presumably tolerates short placeholder strings (e.g. ""
        # or "None") coming from serialized metadata -- confirm with callers.
        if isinstance(value, str):
            if len(value) > 4:
                raise TypeError(f"Not sure what to do with {type(value)}")
            return []

        raise TypeError(f"Not sure what to do with {type(value)}")

    @computed_field
    @property
    def decimations_dict(self) -> dict[int, DecimationLevel]:
        """
        Map each ``decimation.level`` to its DecimationLevel object.

        This is computed on demand rather than stored as an attribute,
        because a stored attribute gets confusing when reading in a json file.
        If two entries share a level, the later one in ``self.decimations`` wins.

        Returns
        -------
        dict[int, DecimationLevel]
            A dictionary mapping decimation levels to their corresponding
            DecimationLevel objects.
        """
        return {d.decimation.level: d for d in self.decimations}

    def get_decimation_level(self, level: int) -> DecimationLevel:
        """
        Get a decimation level for easy access.

        Parameters
        ----------
        level: int
            The decimation level to retrieve (matched against
            ``decimation.level``).

        Returns
        -------
        DecimationLevel
            The DecimationLevel object corresponding to the specified level.

        Raises
        ------
        KeyError
            If no decimation has the requested level.
        """
        try:
            decimation = self.decimations_dict[level]
        except KeyError:
            raise KeyError(f"Could not find {level} in decimations.") from None

        # Defensive: the validator should already have converted dict entries.
        if isinstance(decimation, dict):
            decimation_level = DecimationLevel()  # type: ignore
            decimation_level.from_dict(decimation)
            return decimation_level

        return decimation

    def add_decimation_level(self, decimation_level: DecimationLevel | dict) -> None:
        """
        Append a decimation level to ``self.decimations``.

        Parameters
        ----------
        decimation_level: DecimationLevel | dict
            The decimation level to add. A dict is converted with
            ``DecimationLevel.from_dict``.

        Returns
        -------
        None

        Raises
        ------
        TypeError
            If ``decimation_level`` is neither a DecimationLevel nor a dict.
        """
        if not isinstance(decimation_level, (DecimationLevel, dict)):
            raise TypeError(
                f"List entry must be a DecimationLevel object not {type(decimation_level)}"
            )
        if isinstance(decimation_level, dict):
            obj = DecimationLevel()  # type: ignore
            obj.from_dict(decimation_level)
        else:
            obj = decimation_level

        self.decimations.append(obj)

    @computed_field
    @property
    def band_edges_dict(self) -> dict[int, list[tuple[float, float]]]:
        """
        Band edges for each decimation, keyed by position.

        The keys are positions in ``self.decimations`` (0, 1, 2, ...), not
        the ``decimation.level`` values.
        """
        return {
            i_dec: decimation.band_edges
            for i_dec, decimation in enumerate(self.decimations)
        }

    def assign_decimation_level_data_emtf(self, sample_rate: float) -> None:
        """
        Set the factor and sample rate of each existing decimation level using the EMTF convention.

        Level 0 gets factor 1 and the input ``sample_rate``. Every other level
        ``k`` gets factor 4 and sample rate ``sample_rate / 4**k``.

        Warning: This does not actually tell us how many samples we are
        decimating down at each level. That is assumed to be 4, but we need a
        way to bookkeep this in general.

        Parameters
        ----------
        sample_rate: float
            The initial sampling rate of the data before any decimation.
        """
        # Build the computed mapping once instead of on every iteration.
        decimations = self.decimations_dict
        for key in sorted(decimations):
            if key in [0, "0"]:
                d = 1
                sr = sample_rate
            else:
                # careful with this hardcoded assumption of decimation by 4
                d = 4
                sr = sample_rate / (d ** int(key))
            decimation_obj = decimations[key]
            decimation_obj.decimation.factor = d
            decimation_obj.decimation.sample_rate = sr

    def assign_bands(
        self,
        band_edges_dict: dict[int, list[tuple[float, float]]],
        sample_rate: float,
        decimation_factors: dict[int, int],
        num_samples_window: dict[int, int] | int = 256,
    ) -> None:
        """
        Build one DecimationLevel per entry of ``band_edges_dict`` and append it.

        Each decimation factor is treated as relative to the previous level,
        matching ``assign_decimation_level_data_emtf``. The sample rate of
        level ``k > 0`` is therefore ``sample_rate`` divided by the product of
        the factors of levels 1..k. For a uniform factor ``d`` this equals
        ``sample_rate / d**k``. Level 0 always uses ``sample_rate`` unchanged.
        Levels are assumed to be contiguous.

        Parameters
        ----------
        band_edges_dict: dict[int, list[tuple[float, float]]]
            A dictionary mapping decimation levels to lists of frequency band
            edges. Keys are integers (or integer-like strings) starting with
            0; values are arrays of (low, high) edges.
        sample_rate: float
            The initial sampling rate of the data before any decimation.
        decimation_factors: dict[int, int]
            A dictionary mapping decimation levels to their decimation
            factors. It must contain every key of ``band_edges_dict``.
        num_samples_window: dict[int, int] | int, optional (default=256)
            The number of samples in the STFT window for each decimation
            level. If an integer is provided, it is applied to all levels. If
            a dictionary is provided, it should map decimation levels to their
            corresponding number of samples.

        Returns
        -------
        None
        """
        # Order levels numerically so integer-like string keys sort correctly.
        levels = sorted(band_edges_dict, key=int)
        if isinstance(num_samples_window, int):
            # Key by the same keys as band_edges_dict. A positional list breaks
            # for string keys or keys that do not start at 0.
            num_samples_window = {level: num_samples_window for level in levels}

        cumulative_factor = 1
        for i_level in levels:
            band_edges = band_edges_dict[i_level]
            d = decimation_factors[i_level]
            if i_level in [0, "0"]:
                # level 0 is the undecimated input
                sr = sample_rate
            else:
                # Factors are relative to the previous level, so accumulate them.
                cumulative_factor *= d
                sr = 1.0 * sample_rate / cumulative_factor
            decimation_obj = DecimationLevel()
            decimation_obj.decimation.level = int(i_level)
            decimation_obj.decimation.factor = d
            decimation_obj.decimation.sample_rate = sr
            decimation_obj.stft.window.num_samples = num_samples_window[i_level]
            frequencies = decimation_obj.fft_frequencies

            for low, high in band_edges:
                band = Band(  # type: ignore
                    decimation_level=i_level,
                    frequency_min=low,
                    frequency_max=high,
                )
                band.set_indices_from_frequencies(frequencies)
                decimation_obj.add_band(band)
            self.add_decimation_level(decimation_obj)

    def json_fn(self) -> str:
        """Return the default json filename: ``<id>_processing_config.json``."""
        json_fn = self.id + "_processing_config.json"
        return json_fn

    @property
    def num_decimation_levels(self) -> int:
        """Number of entries in ``self.decimations``."""
        return len(self.decimations)

    def drop_reference_channels(self) -> None:
        """Set ``reference_channels`` to an empty list on every decimation level."""
        for decimation in self.decimations:
            decimation.reference_channels = []

    def set_input_channels(self, channels) -> None:
        """Assign ``channels`` as ``input_channels`` on every decimation level."""
        for decimation in self.decimations:
            decimation.input_channels = channels

    def set_output_channels(self, channels) -> None:
        """Assign ``channels`` as ``output_channels`` on every decimation level."""
        for decimation in self.decimations:
            decimation.output_channels = channels

    def set_reference_channels(self, channels) -> None:
        """Assign ``channels`` as ``reference_channels`` on every decimation level."""
        for decimation in self.decimations:
            decimation.reference_channels = channels

    def set_default_input_output_channels(self) -> None:
        """Apply the channel nomenclature's default input and output channels to all levels."""
        self.set_input_channels(self.channel_nomenclature.default_input_channels)
        self.set_output_channels(self.channel_nomenclature.default_output_channels)

    def set_default_reference_channels(self) -> None:
        """Apply the channel nomenclature's default reference channels to all levels."""
        self.set_reference_channels(
            self.channel_nomenclature.default_reference_channels
        )

    def validate_processing(self, kernel_dataset) -> None:
        """
        Placeholder. Some of the checks and methods here may be better placed
        in TFKernel, which would validate the dataset against the processing
        config.

        Things that are validated:
        1. The default estimation engine from the json file is "RME_RR". That
           is fine, since we generally expect to do more RR processing than SS.
           However, if there is only one station (no remote), "RME_RR" is
           replaced by "RME".
        2. The local station id is defined. If it is empty, the stations are
           filled from ``kernel_dataset.df``.

        Parameters
        ----------
        kernel_dataset
            NOTE(review): only its ``df`` attribute is used here, and that only
            when the local station id is empty.
        """
        # Make sure a RR method is not being called for a SS config
        if not self.stations.remote:
            for decimation in self.decimations:
                if decimation.estimator.engine == "RME_RR":
                    logger.info("No RR station specified, switching RME_RR to RME")
                    decimation.estimator.engine = "RME"

        # Make sure that a local station is defined
        if not self.stations.local.id:
            logger.warning(
                "Local station not specified, should be set from Kernel Dataset"
            )
            self.stations.from_dataset_dataframe(kernel_dataset.df)