Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\experiment.py: 75%
305 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# -*- coding: utf-8 -*-
2"""
3Containers for the full metadata tree
5Experiment --> Survey --> Station --> Run --> Channel
7Each level has a list attribute
9Created on Mon Feb 8 21:25:40 2021
11:copyright:
12 Jared Peacock (jpeacock@usgs.gov)
14:license: MIT
16"""
17import json
19# =============================================================================
20# Imports
21# =============================================================================
22from collections import OrderedDict
23from pathlib import Path
24from typing import Annotated
25from xml.etree import cElementTree as et
27from loguru import logger
28from pydantic import computed_field, Field, field_validator
30from mt_metadata.base import helpers, MetadataBase
31from mt_metadata.common.list_dict import ListDict
33from . import Auxiliary, Electric, Magnetic, Run, Station, Survey
34from .filters import (
35 CoefficientFilter,
36 FIRFilter,
37 FrequencyResponseTableFilter,
38 PoleZeroFilter,
39 TimeDelayFilter,
40)
43# =============================================================================
class Experiment(MetadataBase):
    """
    Top level of the metadata tree.

    An experiment holds a list of surveys; each survey holds stations and
    filters, each station holds runs, and each run holds channels
    (Experiment --> Survey --> Station --> Run --> Channel).
    """

    surveys: Annotated[
        ListDict | list | dict | OrderedDict,
        Field(
            default_factory=ListDict,
            description="List of surveys in the experiment",
            title="List of Surveys",
            json_schema_extra={
                "required": False,
                "units": None,
                "examples": [{"id": "survey_1"}, {"id": "survey_2"}],
            },
        ),
    ]

    def __str__(self) -> str:
        """Human-readable summary of surveys, filters, stations and runs."""
        lines = ["Experiment Contents", "-" * 20]
        if len(self.surveys) > 0:
            lines.append(f"Number of Surveys: {len(self.surveys)}")
            for survey in self.surveys:
                lines.append(f" Survey ID: {survey.id}")
                lines.append(f" Number of Stations: {survey.n_stations}")
                lines.append(f" Number of Filters: {len(survey.filters.keys())}")
                lines.append(f" {'-' * 20}")
                for f_key, f_object in survey.filters.items():
                    lines.append(f" Filter Name: {f_key}")
                    lines.append(f" Filter Type: {f_object.type}")
                    lines.append(f" {'-' * 20}")
                for station in survey.stations:
                    lines.append(f" Station ID: {station.id}")
                    lines.append(f" Number of Runs: {station.n_runs}")
                    lines.append(f" {'-' * 20}")
                    for run in station.runs:
                        lines.append(f" Run ID: {run.id}")
                        lines.append(f" Number of Channels: {run.n_channels}")
                        lines.append(
                            " Recorded Channels: "
                            + ", ".join(run.channels_recorded_all)
                        )
                        lines.append(f" Start: {run.time_period.start}")
                        lines.append(f" End: {run.time_period.end}")

                        lines.append(f" {'-' * 20}")

        return "\n".join(lines)

    def __repr__(self) -> str:
        return self.__str__()

    def __eq__(self, other) -> bool:
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other) -> bool:
        return not self.__eq__(other)

    def merge(self, other: "Experiment") -> "Experiment":
        """
        Append the surveys of another Experiment to this one.

        :param other: experiment whose surveys are appended
        :type other: :class:`Experiment`
        :return: this experiment (modified in place)
        :rtype: :class:`Experiment`
        :raises TypeError: if ``other`` is not an Experiment
        """
        if not isinstance(other, Experiment):
            msg = f"Can only merge Experiment objects, not {type(other)}"
            logger.error(msg)
            raise TypeError(msg)

        self.surveys.extend(other.surveys)
        return self

    @computed_field
    @property
    def n_surveys(self) -> int:
        """Number of surveys in the experiment."""
        return len(self.surveys)

    @field_validator("surveys", mode="before")
    @classmethod
    def validate_surveys(cls, value) -> ListDict:
        """
        Coerce the input into a ListDict of Survey objects.

        Accepts a list/tuple of items or a mapping (whose values are used).
        Each item may be a Survey or a dict that is loaded with
        ``Survey.from_dict``.

        :raises TypeError: if the container type is wrong or any item is
            neither a Survey nor a dict
        """
        if not isinstance(value, (list, tuple, dict, ListDict, OrderedDict)):
            msg = (
                "input station_list must be an iterable, should be a list or dict "
                f"not {type(value)}"
            )
            logger.error(msg)
            raise TypeError(msg)

        if isinstance(value, (dict, ListDict, OrderedDict)):
            value_list = value.values()
        else:
            value_list = value

        fails = []
        surveys = ListDict()
        for ii, survey in enumerate(value_list):
            if isinstance(survey, (dict, OrderedDict)):
                s = Survey()
                s.from_dict(survey)
                surveys.append(s)
            elif isinstance(survey, Survey):
                surveys.append(survey)
            else:
                msg = f"Item {ii} is not type(Survey); type={type(survey)}"
                fails.append(msg)
                logger.error(msg)

        # report every bad item at once rather than stopping at the first
        if fails:
            raise TypeError("\n".join(fails))
        return surveys

    @property
    def survey_names(self) -> list[str]:
        """Return the ids of the surveys in the experiment as a list."""
        return list(self.surveys.keys())

    def has_survey(self, survey_id: str) -> bool:
        """
        Check whether a survey with the given id is present.

        :param survey_id: survey id to look for
        :type survey_id: str
        :return: True if the survey exists
        :rtype: bool
        """
        return survey_id in self.survey_names

    def survey_index(self, survey_id: str) -> int | None:
        """
        Get the position of a survey in the survey list.

        :param survey_id: survey id to look for
        :type survey_id: str
        :return: index of the survey, or None if it is not present
        :rtype: int | None
        """
        if self.has_survey(survey_id):
            return self.survey_names.index(survey_id)
        return None

    def add_survey(self, survey_obj: "Survey") -> None:
        """
        Add a survey; if a survey with the same id exists, update it instead.

        :param survey_obj: survey to add
        :type survey_obj: :class:`mt_metadata.timeseries.Survey`
        :raises TypeError: if ``survey_obj`` is not a Survey
        """
        if not isinstance(survey_obj, Survey):
            raise TypeError(
                f"Input must be a mt_metadata.timeseries.Survey object not {type(survey_obj)}"
            )

        if self.has_survey(survey_obj.id):
            self.surveys[survey_obj.id].update(survey_obj)
            logger.debug(f"survey {survey_obj.id} already exists, updating metadata")
        else:
            self.surveys.append(survey_obj)

    def get_survey(self, survey_id: str) -> "Survey | None":
        """
        Get a survey by id.

        :param survey_id: survey id to look for
        :type survey_id: str
        :return: the survey, or None (with a warning) if it is not present
        :rtype: :class:`mt_metadata.timeseries.Survey` | None
        """
        if self.has_survey(survey_id):
            return self.surveys[survey_id]
        logger.warning(f"Could not find survey {survey_id}")
        return None

    def remove_survey(self, survey_id: str, update: bool = True) -> None:
        """
        Remove a survey from the experiment.

        :param survey_id: survey id to remove
        :type survey_id: str
        :param update: currently unused; kept for API compatibility
        :type update: bool
        """
        if self.has_survey(survey_id):
            self.surveys.remove(survey_id)
            logger.debug(f"Removed survey {survey_id} from experiment")
        else:
            logger.warning(f"Could not find survey {survey_id} to remove")

    def to_dict(self, nested: bool = False, required: bool = True) -> dict:
        """
        Create a dictionary for the experiment object.

        Each survey dictionary gets ``stations`` and ``filters`` lists, each
        station dictionary gets a ``runs`` list, and each run dictionary
        gets a ``channels`` list.

        :param nested: passed through to each object's ``to_dict``
        :type nested: bool, optional
        :param required: passed through to each object's ``to_dict``
        :type required: bool, optional
        :return: ``{"experiment": {"surveys": [...]}}``
        :rtype: dict
        """
        kwargs = {"nested": nested, "single": True, "required": required}

        surveys = []
        for survey in self.surveys:
            survey_dict = survey.to_dict(**kwargs)
            survey_dict["stations"] = []
            survey_dict["filters"] = []
            for station in survey.stations:
                station_dict = station.to_dict(**kwargs)
                station_dict["runs"] = []
                for run in station.runs:
                    run_dict = run.to_dict(**kwargs)
                    run_dict["channels"] = [
                        channel.to_dict(**kwargs) for channel in run.channels
                    ]
                    station_dict["runs"].append(run_dict)
                survey_dict["stations"].append(station_dict)
            survey_dict["filters"].extend(
                f_object.to_dict(**kwargs) for _, f_object in survey.filters.items()
            )
            surveys.append(survey_dict)

        return {"experiment": {"surveys": surveys}}

    def from_dict(self, ex_dict: dict | OrderedDict, skip_none: bool = True) -> None:
        """
        Fill the experiment from a dictionary produced by :meth:`to_dict`.

        Surveys are added with :meth:`add_survey`, so a survey whose id
        already exists is updated rather than duplicated. A dictionary
        without an ``experiment`` key is ignored.

        :param ex_dict: ``{"experiment": {"surveys": [...]}}``
        :type ex_dict: dict
        :param skip_none: passed through to ``Survey.from_dict``
        :type skip_none: bool
        :raises TypeError: if ``ex_dict`` is not a dictionary
        """
        if not isinstance(ex_dict, dict):
            msg = f"experiemnt input must be a dictionary not {type(ex_dict)}"
            logger.error(msg)
            raise TypeError(msg)
        if "experiment" not in ex_dict:
            return

        surveys = ex_dict["experiment"].get("surveys") or []
        # a lone survey may arrive as a dict rather than a one-element list
        if isinstance(surveys, dict):
            surveys = [surveys]

        for survey_dict in surveys:
            survey_object = Survey()
            survey_object.from_dict(survey_dict, skip_none=skip_none)
            self.add_survey(survey_object)

    def to_json(
        self,
        fn: str | Path = None,
        nested: bool = False,
        indent: str = " " * 4,
        required: bool = True,
    ) -> str | None:
        """
        Write the experiment as JSON, including all contained objects.

        :param fn: file to write to; if None the JSON string is returned
        :type fn: str | :class:`pathlib.Path`, optional
        :param nested: make the returned json nested
        :type nested: [ True | False ] , default is False
        :param indent: indentation passed to :mod:`json`
        :param required: passed through to :meth:`to_dict`
        :return: JSON string when ``fn`` is None, otherwise None
        """
        ex_dict = self.to_dict(nested=nested, required=required)
        if fn is None:
            return json.dumps(ex_dict, cls=helpers.NumpyEncoder, indent=indent)

        with open(fn, "w", encoding="utf-8") as fid:
            json.dump(ex_dict, fid, cls=helpers.NumpyEncoder, indent=indent)
        return None

    def from_json(self, json_str: str | Path, skip_none: bool = True) -> None:
        """
        Read a JSON string or JSON file and update this experiment.

        A ``str`` that names an existing file is read as a file; any other
        ``str`` is parsed as JSON text.

        :param json_str: json string or file path
        :type json_str: string or :class:`pathlib.Path`
        :param skip_none: passed through to :meth:`from_dict`
        :raises FileNotFoundError: if a ``Path`` is given that is not a file
        :raises TypeError: if the input is neither ``str`` nor ``Path``
        """
        if isinstance(json_str, Path):
            if not json_str.is_file():
                msg = f"Could not find JSON file {json_str}"
                logger.error(msg)
                raise FileNotFoundError(msg)
            with open(json_str, "r", encoding="utf-8") as fid:
                json_dict = json.load(fid)
        elif isinstance(json_str, str):
            try:
                is_file = Path(json_str).is_file()
            except (OSError, ValueError):
                # long JSON text or characters illegal in paths cannot be a file
                is_file = False
            if is_file:
                with open(json_str, "r", encoding="utf-8") as fid:
                    json_dict = json.load(fid)
            else:
                json_dict = json.loads(json_str)
        else:
            msg = f"Input must be valid JSON string not {type(json_str)}"
            logger.error(msg)
            raise TypeError(msg)

        self.from_dict(json_dict, skip_none=skip_none)

    def to_xml(
        self, fn: str | Path = None, required: bool = True, sort: bool = True
    ) -> et.Element:
        """
        Write an XML version of the experiment.

        Before serializing, this updates each survey's bounding box and time
        period, and the time periods of stations and runs. Channels whose
        location is all zeros are given their station's location. These
        updates modify the contained objects in place.

        :param fn: file to write to; nothing is written if falsy
        :type fn: str | :class:`pathlib.Path`, optional
        :param required: passed through to each object's ``to_xml``
        :type required: bool
        :param sort: sort surveys, stations, runs and channels in place first
        :type sort: bool
        :return: the root ``Experiment`` element
        :rtype: :class:`xml.etree.ElementTree.Element`
        """
        experiment_element = et.Element(self.__class__.__name__)
        if sort:
            self.surveys.sort()
        for survey in self.surveys:
            experiment_element.append(
                self._survey_to_element(survey, required=required, sort=sort)
            )

        if fn:
            # write UTF-8 explicitly: et.parse assumes UTF-8 when reading back
            # and the platform default encoding may differ (e.g. on Windows)
            with open(fn, "w", encoding="utf-8") as fid:
                fid.write(helpers.element_to_string(experiment_element))
        return experiment_element

    def _survey_to_element(
        self, survey: "Survey", required: bool = True, sort: bool = True
    ) -> et.Element:
        """Build the XML element for one survey with its filters and stations."""
        survey.update_bounding_box()
        survey.update_time_period()
        survey_element = survey.to_xml(required=required)
        filter_element = et.SubElement(survey_element, "filters")
        for _, mt_filter in survey.filters.items():
            filter_element.append(mt_filter.to_xml(required=required))

        if sort:
            survey.stations.sort()
        for station in survey.stations:
            station.update_time_period()
            station_element = station.to_xml(required=required)
            if sort:
                station.runs.sort()
            for run in station.runs:
                run.update_time_period()
                run_element = run.to_xml(required=required)
                if sort:
                    run.channels.sort()
                for channel in run.channels:
                    self._fill_channel_location(channel, station)
                    run_element.append(channel.to_xml(required=required))
                station_element.append(run_element)
            survey_element.append(station_element)
        return survey_element

    @staticmethod
    def _fill_channel_location(channel, station) -> None:
        """Copy the station location onto a channel whose location is all zero."""
        # electric channels store their position on ``positive``; all other
        # channel types use ``location``
        target = channel.positive if channel.type in ["electric"] else channel.location
        if target.latitude == 0 and target.longitude == 0 and target.elevation == 0:
            target.latitude = station.location.latitude
            target.longitude = station.location.longitude
            target.elevation = station.location.elevation

    def from_xml(
        self,
        fn: str | Path = None,
        element: et.Element | None = None,
        sort: bool = True,
        skip_none: bool = True,
    ) -> None:
        """
        Fill the experiment from an XML file or an already-parsed element.

        :param fn: XML file to parse; ignored when ``element`` is given
        :type fn: str | :class:`pathlib.Path`, optional
        :param element: root element to read instead of a file
        :type element: :class:`xml.etree.ElementTree.Element`, optional
        :param sort: sort the experiment after reading
        :type sort: bool
        :param skip_none: passed through to each object's ``from_dict``
        :type skip_none: bool
        :raises ValueError: if neither ``fn`` nor ``element`` is given
        """
        if element is not None:
            experiment_element = element
        elif fn:
            experiment_element = et.parse(fn).getroot()
        else:
            msg = "Must provide either fn or element to read XML from"
            logger.error(msg)
            raise ValueError(msg)

        for survey_element in list(experiment_element):
            self.add_survey(self._survey_from_element(survey_element, skip_none))

        if sort:
            self.sort()

    def _survey_from_element(
        self, survey_element: et.Element, skip_none: bool = True
    ) -> "Survey":
        """Build a Survey (filters, stations, runs, channels) from one XML element."""
        survey_dict = helpers.element_to_dict(survey_element)
        # need to set the lists for each layer, otherwise you get duplicates:
        # child lists are popped off and rebuilt as objects below
        stations = self._pop_dictionary(survey_dict["survey"], "station")
        survey_obj = Survey()
        survey_obj.from_dict(survey_dict, skip_none=skip_none)
        filters = survey_dict["survey"].pop("filters", None)
        survey_obj.filters.update(self._read_filter_dict(filters))

        for station_dict in stations:
            station_obj = Station()
            runs = self._pop_dictionary(station_dict, "run")
            station_obj.from_dict(station_dict, skip_none=skip_none)
            for run_dict in runs:
                station_obj.add_run(self._run_from_dict(run_dict, skip_none))
            survey_obj.add_station(station_obj)
        return survey_obj

    def _run_from_dict(self, run_dict: dict, skip_none: bool = True) -> "Run":
        """Build a Run and its channels from an XML-derived run dictionary."""
        channel_classes = {
            "electric": Electric,
            "magnetic": Magnetic,
            "auxiliary": Auxiliary,
        }
        run_obj = Run()
        for ch_type, ch_class in channel_classes.items():
            if ch_type not in run_dict:
                logger.debug(f"Could not find channel {ch_type}")
                continue
            for ch_dict in self._pop_dictionary(run_dict, ch_type):
                channel = ch_class()
                channel.from_dict(ch_dict, skip_none=skip_none)
                run_obj.add_channel(channel)
        run_obj.from_dict(run_dict, skip_none=skip_none)
        return run_obj

    def _pop_dictionary(self, in_dict: dict, element: str) -> list:
        """
        Pop a key from a dictionary and always return its value as a list.

        :param in_dict: dictionary to pop from (modified in place)
        :type in_dict: dict
        :param element: key to pop
        :type element: str
        :return: the popped value wrapped in a list if it was not one already,
            or an empty list if ``element`` is not present
        :rtype: list
        """
        if element not in in_dict:
            return []
        elements = in_dict.pop(element)
        if not isinstance(elements, list):
            elements = [elements]
        return elements

    def to_pickle(self, fn: str | Path = None) -> None:
        """
        Write a pickle version of the experiment.

        Not implemented: nothing is written and a warning is logged.

        :param fn: target file (unused)
        :type fn: str | :class:`pathlib.Path`
        """
        logger.warning(f"to_pickle is not implemented; nothing written to {fn}")

    def from_pickle(self, fn: str | Path = None) -> None:
        """
        Read a pickle version of the experiment.

        Not implemented: nothing is read and a warning is logged.

        :param fn: source file (unused)
        :type fn: str | :class:`pathlib.Path`
        """
        logger.warning(f"from_pickle is not implemented; nothing read from {fn}")

    def _read_filter_dict(self, filters_dict: dict | None) -> ListDict:
        """
        Convert an XML-derived filters dictionary into filter objects.

        Keys name the filter type (e.g. ``pole_zero_filter``). Each value is
        either a single filter dictionary or a list of them. Filters are
        stored under their lower-cased name. Unknown keys are skipped.

        :param filters_dict: filters dictionary, or None/empty for no filters
        :type filters_dict: dict | None
        :return: filters keyed by lower-cased name
        :rtype: :class:`ListDict`
        """
        filter_classes = {
            "pole_zero_filter": PoleZeroFilter,
            "coefficient_filter": CoefficientFilter,
            "time_delay_filter": TimeDelayFilter,
            "frequency_response_table_filter": FrequencyResponseTableFilter,
            "fir_filter": FIRFilter,
        }
        return_dict = ListDict()
        if not filters_dict:
            return return_dict

        for key, value in filters_dict.items():
            filter_class = filter_classes.get(key)
            if filter_class is None:
                logger.debug(f"Skipping unknown filter type {key}")
                continue
            entries = value if isinstance(value, list) else [value]
            for entry in entries:
                # unpack as keywords for both the single and list cases;
                # the filter models take field values as keyword arguments
                mt_filter = filter_class(**entry)
                return_dict[mt_filter.name.lower()] = mt_filter
        return return_dict

    def sort(self, inplace: bool = True) -> "Experiment":
        """
        Sort surveys, stations, runs and channels alphabetically/numerically.

        :param inplace: sort this object; otherwise sort a copy rebuilt via
            :meth:`to_dict` / :meth:`from_dict`
        :type inplace: bool, optional
        :return: the sorted experiment (``self`` when ``inplace`` is True)
        :rtype: :class:`Experiment`
        """
        if not inplace:
            ex = Experiment()
            ex.from_dict(self.to_dict())
            ex.sort()
            return ex

        self.surveys.sort()
        for survey in self.surveys:
            survey.stations.sort()
            for station in survey.stations:
                station.runs.sort()
                for run in station.runs:
                    run.channels.sort()
        return self