Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\transfer_functions\io\emtfxml\metadata\data.py: 97%
143 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# -*- coding: utf-8 -*-
2"""
3Created on Mon Sep 6 13:53:55 2021
5@author: jpeacock
6"""
7from typing import Annotated, ClassVar
8from xml.etree import cElementTree as et
10# =============================================================================
11# Imports
12# =============================================================================
13import numpy as np
14from loguru import logger
15from pydantic import computed_field, Field, field_validator, ValidationInfo
17from mt_metadata.base import MetadataBase
18from mt_metadata.base.helpers import element_to_string
21# =============================================================================
class TransferFunction(MetadataBase):
    """
    Container for the transfer function estimates held in the ``<Data>``
    block of an EMTF XML file.

    Stores the period axis plus impedance (``z``) and tipper (``t``)
    estimates and their error statistics as numpy arrays indexed
    ``[period, output, input]``, and converts between those arrays and the
    period-by-period XML (or XML-as-dict) representation.
    """

    # Channel name -> row/column index inside one period's matrix.
    _index_dict: ClassVar[dict] = {"hx": 0, "hy": 1, "ex": 0, "ey": 1, "hz": 0}
    # Reading: XML ``type`` attribute -> Python type used to parse values.
    # Writing: numpy dtype name -> XML ``type`` attribute.
    _dtype_dict: ClassVar[dict] = {
        "complex": complex,
        "real": float,
        "complex128": "complex",
        "float64": "real",
    }
    # ``units`` attribute written for these array keys.
    _units_dict: ClassVar[dict] = {"z": "[mV/km]/[nT]", "t": "[]"}
    # Output+input channel pair -> component name written as ``name``.
    _name_dict: ClassVar[dict] = {
        "exhx": "zxx",
        "exhy": "zxy",
        "eyhx": "zyx",
        "eyhy": "zyy",
        "hzhx": "tx",
        "hzhy": "ty",
    }

    # dtype each array field is coerced to by ``validate_array``.
    _array_dtypes_dict: ClassVar[dict] = {
        "period": float,
        "z": complex,
        "z_var": float,
        "z_invsigcov": complex,
        "z_residcov": complex,
        "t": complex,
        "t_var": float,
        "t_invsigcov": complex,
        "t_residcov": complex,
    }

    period: Annotated[
        np.typing.NDArray[np.float64] | None,
        Field(
            default_factory=lambda: np.empty((0,), dtype=np.float64),
            description="periods for estimates",
            alias=None,
            json_schema_extra={
                "units": "second",
                "required": True,
                "examples": ["0.01", "0.1", "1.0"],
            },
        ),
    ]
    z: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.complex128),
            description="Estimates of the impedance tensor.",
            json_schema_extra={
                "units": "[mV/km]/[nT]",
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]

    z_var: Annotated[
        np.typing.NDArray[np.float64] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.float64),
            description="Variance estimates for the impedance tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["0.01", "0.1", "1.0"],
            },
        ),
    ]

    z_invsigcov: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.complex128),
            description="Inverse of the covariance matrix for the impedance tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]
    z_residcov: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.complex128),
            description="Residual covariance matrix for the impedance tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]
    t: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 1, 2), dtype=np.complex128),
            description="Estimates of the tipper tensor.",
            json_schema_extra={
                "units": "[]",
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]
    t_var: Annotated[
        np.typing.NDArray[np.float64] | None,
        Field(
            default_factory=lambda: np.empty((0, 1, 2), dtype=np.float64),
            description="Variance estimates for the tipper tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["0.01", "0.1", "1.0"],
            },
        ),
    ]
    t_invsigcov: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.complex128),
            description="Inverse of the covariance matrix for the tipper tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]
    t_residcov: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 1, 1), dtype=np.complex128),
            description="Residual covariance matrix for the tipper tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]

    # For each array key: matrix row index -> output channel name and
    # matrix column index -> input channel name, used when writing XML.
    _write_dict: ClassVar[dict] = {
        "z": {"out": {0: "ex", 1: "ey"}, "in": {0: "hx", 1: "hy"}},
        "z_var": {"out": {0: "ex", 1: "ey"}, "in": {0: "hx", 1: "hy"}},
        "z_invsigcov": {
            "out": {0: "hx", 1: "hy"},
            "in": {0: "hx", 1: "hy"},
        },
        "z_residcov": {
            "out": {0: "ex", 1: "ey"},
            "in": {0: "ex", 1: "ey"},
        },
        "t": {"out": {0: "hz"}, "in": {0: "hx", 1: "hy"}},
        "t_var": {"out": {0: "hz"}, "in": {0: "hx", 1: "hy"}},
        "t_invsigcov": {
            "out": {0: "hx", 1: "hy"},
            "in": {0: "hx", 1: "hy"},
        },
        "t_residcov": {"out": {0: "hz"}, "in": {0: "hz"}},
    }

    # When True, the derived quantities listed below are ignored on read.
    _skip_derived_data: ClassVar[bool] = True
    _derived_keys: ClassVar[list] = [
        "rho",
        "rho_var",
        "phs",
        "phs_var",
        "tipphs",
        "tipphs_var",
        "tipmag",
        "tipmag_var",
        "zstrike",
        "zstrike_var",
        "zskew",
        "zskew_var",
        "zellip",
        "zellip_var",
        "tstrike",
        "tstrike_var",
        "tskew",
        "tskew_var",
        "tellip",
        "tellip_var",
        "indmag",
        "indmag_var",
        "indang",
        "indang_var",
    ]

    @field_validator(
        "period",
        "z",
        "z_var",
        "z_invsigcov",
        "z_residcov",
        "t",
        "t_var",
        "t_invsigcov",
        "t_residcov",
        mode="before",
    )
    @classmethod
    def validate_array(cls, value, info: ValidationInfo) -> np.ndarray | None:
        """
        Coerce an array field to a numpy array of its declared dtype.

        :param value: new value for the field
        :type value: list, tuple, np.ndarray or None
        :param info: pydantic validation info; ``info.field_name`` selects
            the target dtype from ``_array_dtypes_dict``
        :return: the value as a numpy array, or None
        :rtype: np.ndarray | None
        :raises TypeError: if value is not None, a list, a tuple or an ndarray
        """
        if value is None:
            return None
        if isinstance(value, (list, tuple, np.ndarray)):
            return np.array(value, dtype=cls._array_dtypes_dict[info.field_name])
        msg = f"input values must be an list, tuple, or np.ndarray, not {type(value)}"
        logger.error(msg)
        raise TypeError(msg)

    def initialize_arrays(self, n_periods: int) -> None:
        """
        Allocate zero-filled data arrays sized for ``n_periods`` periods.

        :param n_periods: number of periods
        :type n_periods: int
        :return: None
        :rtype: None
        """
        self.period = np.zeros(n_periods)
        self.z = np.zeros((n_periods, 2, 2), dtype=self._array_dtypes_dict["z"])
        self.z_var = np.zeros_like(self.z, dtype=self._array_dtypes_dict["z_var"])
        self.z_invsigcov = np.zeros_like(
            self.z, dtype=self._array_dtypes_dict["z_invsigcov"]
        )
        self.z_residcov = np.zeros_like(
            self.z, dtype=self._array_dtypes_dict["z_residcov"]
        )
        self.t = np.zeros((n_periods, 1, 2), dtype=self._array_dtypes_dict["t"])
        self.t_var = np.zeros_like(self.t, dtype=self._array_dtypes_dict["t_var"])
        self.t_invsigcov = np.zeros(
            (n_periods, 2, 2), dtype=self._array_dtypes_dict["t_invsigcov"]
        )
        self.t_residcov = np.zeros(
            (n_periods, 1, 1), dtype=self._array_dtypes_dict["t_residcov"]
        )

    @computed_field
    @property
    def array_dict(self) -> dict:
        """Map of array key (e.g. ``"z_var"``) to the current array (not a copy)."""
        return {
            "z": self.z,
            "z_var": self.z_var,
            "z_invsigcov": self.z_invsigcov,
            "z_residcov": self.z_residcov,
            "t": self.t,
            "t_var": self.t_var,
            "t_invsigcov": self.t_invsigcov,
            "t_residcov": self.t_residcov,
        }

    @computed_field
    @property
    def n_periods(self) -> int:
        """Number of periods, or 0 when ``period`` is None."""
        if self.period is not None:
            return self.period.size
        return 0

    @staticmethod
    def _parse_value(text: str, dtype) -> complex | float:
        """Convert the text of one ``<value>`` element to a number of ``dtype``."""
        if dtype is complex:
            parts = text.split()
            return complex(float(parts[0]), float(parts[1]))
        if dtype in (float, int):
            return dtype(text)
        # Type attribute missing or unrecognized: infer complex vs. real from
        # the number of whitespace-separated tokens.
        parts = text.split()
        if len(parts) > 1:
            return complex(float(parts[0]), float(parts[1]))
        return float(parts[0])

    def read_block(self, block: dict, period_index: int) -> None:
        """
        Read one period block (``root_dict["data"]["period"][ii]``) into
        row ``period_index`` of the data arrays.

        Keys are normalized by dropping underscores and turning ``.`` into
        ``_`` (``"z.var"`` -> ``"z_var"``). Keys that do not name one of the
        arrays in :attr:`array_dict` (the period ``value``, derived
        quantities, other attributes) are skipped.

        :param block: dictionary for a single period
        :type block: dict
        :param period_index: index of the period in the data arrays
        :type period_index: int
        :return: None
        :rtype: None
        """
        arrays = self.array_dict  # built once; the property makes a new dict
        for key, entry in block.items():
            comp = key.replace("_", "").replace(".", "_")
            if comp in ["value"]:
                continue
            if self._skip_derived_data and comp in self._derived_keys:
                continue
            if comp not in arrays:
                # Not a stored array; indexing arrays[comp] would raise.
                logger.debug("Skipping unrecognized key {}", key)
                continue

            try:
                dtype = self._dtype_dict[entry["type"]]
            except KeyError:
                dtype = "unknown"

            try:
                value_list = entry["value"]
            except KeyError:
                # loguru formats with str.format-style braces, not %-style.
                logger.debug(
                    "No value for {} at period index {}", comp, period_index
                )
                continue

            # A single <value> element comes through as a dict, not a list.
            if not isinstance(value_list, list):
                value_list = [value_list]
            target = arrays[comp]
            for item in value_list:
                index_0 = self._index_dict[item["output"].lower()]
                index_1 = self._index_dict[item["input"].lower()]
                target[period_index, index_0, index_1] = self._parse_value(
                    item["value"], dtype
                )

    def read_dict(self, root_dict: dict) -> None:
        """
        Read ``root_dict["data"]``, the main transfer function data block.

        Arrays are sized from the number of period blocks actually present.
        A ``count`` attribute that disagrees with that number is reported
        with a warning rather than trusted.

        :param root_dict: dictionary containing the transfer function data
        :type root_dict: dict
        :return: None
        :rtype: None
        """
        if self._skip_derived_data:
            logger.debug("Skipping derived quantities.")

        data = root_dict["data"]
        periods = data["period"]
        # Presumably from an XML-to-dict conversion: a lone <Period> arrives
        # as a dict instead of a one-element list (same as <value> in
        # read_block).
        if isinstance(periods, dict):
            periods = [periods]
        n_periods = len(periods)

        try:
            count = int(float(str(data["count"]).strip()))
        except KeyError:
            count = None
        except ValueError:
            logger.warning("Could not parse data count {!r}", data["count"])
            count = None
        if count is not None and count != n_periods:
            logger.warning(
                "Data count {} does not match number of period blocks {}; "
                "using {}",
                count,
                n_periods,
                n_periods,
            )

        self.initialize_arrays(n_periods)
        for ii, block in enumerate(periods):
            self.period[ii] = float(block["value"])  # type: ignore[assignment]
            self.read_block(block, ii)

    def write_block(self, parent: et.Element, index: int) -> et.Element:
        """
        Append a ``<Period>`` element for period ``index`` to ``parent``.

        Every non-empty array becomes a child element (``"z_var"`` ->
        ``<Z.VAR>``) holding one ``<value>`` per (output, input) channel
        pair. NaN/inf are replaced via ``np.nan_to_num`` and exact zeros are
        written as the 1e32 empty value.

        :param parent: element the new ``<Period>`` element is appended to
        :type parent: xml.etree.ElementTree.Element
        :param index: index of the period to write
        :type index: int
        :return: the new ``<Period>`` element
        :rtype: xml.etree.ElementTree.Element
        """
        period_element = et.SubElement(
            parent,
            "Period",
            {"value": f"{self.period[index]:.12e}", "units": "secs"},  # type: ignore[arg-type]
        )

        arrays = self.array_dict  # built once; the property makes a new dict
        for key, data in arrays.items():
            if data is None:
                continue
            if data.size == 0:
                logger.debug(f"No data for {key}, skipping.")
                continue
            # nan_to_num returns a copy, so the fill below leaves self intact.
            arr = np.nan_to_num(data[index])

            # set zeros to empty value of 1E32
            if arr.dtype == complex:
                arr[np.where(arr == 0)] = 1e32 + 1e32j
            else:
                arr[np.where(arr == 0)] = 1e32

            attr_dict = {
                "type": self._dtype_dict[arr.dtype.name],
                # e.g. (2, 2) -> "2 2"
                "size": str(arr.shape)[1:-1].replace(",", ""),
            }
            if key in self._units_dict:
                attr_dict["units"] = self._units_dict[key]

            comp_element = et.SubElement(
                period_element, key.replace("_", ".").upper(), attr_dict
            )
            idx_dict = self._write_dict[key]
            is_complex = attr_dict["type"] in ["complex"]
            for ii in range(arr.shape[0]):
                ch_out = idx_dict["out"][ii]
                for jj in range(arr.shape[1]):
                    ch_in = idx_dict["in"][jj]
                    a_dict = {}
                    name = self._name_dict.get(ch_out + ch_in)
                    if name is not None:
                        a_dict["name"] = name.capitalize()
                    a_dict["output"] = ch_out.capitalize()
                    a_dict["input"] = ch_in.capitalize()
                    ch_element = et.SubElement(comp_element, "value", a_dict)
                    ch_value = f"{arr[ii, jj].real:.6e}"
                    if is_complex:
                        ch_value = f"{ch_value} {arr[ii, jj].imag:.6e}"
                    ch_element.text = ch_value

        return period_element

    def to_xml(self, string: bool = False, required: bool = True) -> et.Element | str:
        """
        Build the ``<Data>`` element holding one ``<Period>`` per period.

        :param string: if True, return the XML as a string instead of an
            element
        :type string: bool
        :param required: accepted for signature compatibility; not used here
        :type required: bool
        :return: the ``<Data>`` element, or its string form
        :rtype: xml.etree.ElementTree.Element | str
        """
        root = et.Element("Data", {"count": f"{self.n_periods:.0f}"})

        # n_periods is 0 when period is None, so no attribute error here.
        for index in range(self.n_periods):
            self.write_block(root, index)

        if string:
            return element_to_string(root)
        return root