Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\processing\aurora\run.py: 100%
68 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from typing import Annotated, Union
6from pydantic import computed_field, Field, field_validator, ValidationInfo
8from mt_metadata.base import MetadataBase
9from mt_metadata.common import TimePeriod
10from mt_metadata.processing.aurora.channel import Channel
13# =====================================================
# Processing configuration for a single run: the run ID, the input (source)
# and output (response) channels, the time periods to process, and the run
# sample rate.
class Run(MetadataBase):
    id: Annotated[
        str,
        Field(
            default="",
            description="run ID",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["001"],
            },
        ),
    ]

    input_channels: Annotated[
        list[Channel],
        Field(
            default_factory=list,
            description="List of input channels (source)",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["hx, hy"],
            },
        ),
    ]

    output_channels: Annotated[
        list[Channel],
        Field(
            default_factory=list,
            description="List of output channels (response)",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["ex, ey, hz"],
            },
        ),
    ]

    time_periods: Annotated[
        list[TimePeriod],
        Field(
            default_factory=list,
            description="List of time periods to process",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [
                    "[{'start': '2020-01-01T00:00:00', 'end': '2020-01-01T01:00:00'}]"
                ],
            },
        ),
    ]

    sample_rate: Annotated[
        float,
        Field(
            default=1.0,
            description="sample rate of the run",
            alias=None,
            json_schema_extra={
                "units": "samples per second",
                "required": True,
                "examples": ["1"],
            },
        ),
    ]

    @field_validator("input_channels", "output_channels", mode="before")
    @classmethod
    def validate_channel_list(
        cls,
        values: Union[list, tuple, str, Channel, dict],
        info: ValidationInfo,
    ) -> list[Channel]:
        """
        Coerce the raw value of a channel-list field into ``list[Channel]``.

        Parameters
        ----------
        values : Union[list, tuple, str, Channel, dict]
            A single channel specification or a list/tuple of them. Each
            item may be a channel ID string (-> ``Channel(id=item)``), a
            ``Channel`` instance (used as-is), or a dict (loaded with
            ``Channel.from_dict``).
        info : ValidationInfo
            Supplied by pydantic; not used here.

        Returns
        -------
        list[Channel]
            One ``Channel`` per input item, in input order.

        Raises
        ------
        TypeError
            If an item is not a str, ``Channel``, or dict.
        """
        # A bare item is treated as a one-element sequence; tuples are
        # handled the same way as lists.
        if not isinstance(values, (list, tuple)):
            values = [values]

        channels = []
        for item in values:
            if isinstance(item, str):
                ch = Channel(id=item)
            elif isinstance(item, Channel):
                ch = item
            elif isinstance(item, dict):
                ch = Channel()
                ch.from_dict(item)
            else:
                # NOTE: pydantic only converts ValueError/AssertionError into
                # a ValidationError, so this presumably reaches the caller as
                # a plain TypeError.
                raise TypeError(f"not sure what to do with type {type(item)}")
            channels.append(ch)
        return channels

    @field_validator("time_periods", mode="before")
    @classmethod
    def validate_time_periods(
        cls,
        values: Union[list, tuple, dict, TimePeriod],
        info: ValidationInfo,
    ) -> list[TimePeriod]:
        """
        Coerce the raw value of ``time_periods`` into ``list[TimePeriod]``.

        Parameters
        ----------
        values : Union[list, tuple, dict, TimePeriod]
            A single time-period specification or a list/tuple of them. Each
            item may be a ``TimePeriod`` instance (used as-is) or a dict
            (loaded with ``TimePeriod.from_dict``).
        info : ValidationInfo
            Supplied by pydantic; not used here.

        Returns
        -------
        list[TimePeriod]
            One ``TimePeriod`` per input item, in input order.

        Raises
        ------
        TypeError
            If an item is not a ``TimePeriod`` or dict.
        """
        # A bare item is treated as a one-element sequence; tuples are
        # handled the same way as lists.
        if not isinstance(values, (list, tuple)):
            values = [values]

        time_periods = []
        for item in values:
            if isinstance(item, TimePeriod):
                tp = item
            elif isinstance(item, dict):
                tp = TimePeriod()
                tp.from_dict(item)
            else:
                raise TypeError(f"not sure what to do with type {type(item)}")
            time_periods.append(tp)
        return time_periods

    @computed_field
    @property
    def channel_scale_factors(self) -> dict[str, float]:
        # Map channel ID -> scale_factor for every input and output channel
        # whose scale_factor is not None. Output channels are visited last,
        # so if an ID appears in both lists the output value wins.
        scale_factors = {}
        for ch in self.input_channels + self.output_channels:
            if ch.scale_factor is not None:
                scale_factors[ch.id] = ch.scale_factor
        return scale_factors

    def set_channel_scale_factors(self, values: dict[str, float]) -> None:
        """
        Set ``scale_factor`` on the input and output channels named in
        ``values``.

        Parameters
        ----------
        values : dict[str, float]
            Mapping of channel ID to scale factor. Channels whose ID is not a
            key are left unchanged; keys that match no channel are ignored.

        Raises
        ------
        TypeError
            If ``values`` is not a dict (a single float is not accepted).
        """
        if not isinstance(values, dict):
            raise TypeError(f"not sure what to do with type {type(values)}")
        # The concatenation is a new list holding the same Channel objects,
        # so mutating each channel updates the model's own lists in place.
        for channel in self.input_channels + self.output_channels:
            if channel.id in values:
                channel.scale_factor = values[channel.id]

    @computed_field
    @property
    def input_channel_names(self) -> list[str]:
        """list of input channel names (channel IDs)"""
        return [ch.id for ch in self.input_channels]

    @computed_field
    @property
    def output_channel_names(self) -> list[str]:
        """list of output channel names (channel IDs)"""
        return [ch.id for ch in self.output_channels]