Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\transfer_functions\io\emtfxml\metadata\data.py: 97%

143 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# -*- coding: utf-8 -*- 

2""" 

3Created on Mon Sep 6 13:53:55 2021 

4 

5@author: jpeacock 

6""" 

7from typing import Annotated, ClassVar 

8from xml.etree import cElementTree as et 

9 

10# ============================================================================= 

11# Imports 

12# ============================================================================= 

13import numpy as np 

14from loguru import logger 

15from pydantic import computed_field, Field, field_validator, ValidationInfo 

16 

17from mt_metadata.base import MetadataBase 

18from mt_metadata.base.helpers import element_to_string 

19 

20 

21# ============================================================================= 

22 

23 

class TransferFunction(MetadataBase):
    """
    Transfer function data block of an EMTF XML file.

    Holds the period axis together with the impedance (``z``) and tipper
    (``t``) estimates and their variance / covariance arrays, and converts
    them from a parsed ``data`` dictionary and to a ``<Data>`` XML element.
    """

    # channel name -> row/column index inside a per-period tensor
    _index_dict: ClassVar[dict] = {"hx": 0, "hy": 1, "ex": 0, "ey": 1, "hz": 0}
    # "complex"/"real" are XML ``type`` attributes mapped to python types
    # (used when reading); "complex128"/"float64" are numpy dtype names
    # mapped to XML ``type`` attributes (used when writing).
    _dtype_dict: ClassVar[dict] = {
        "complex": complex,
        "real": float,
        "complex128": "complex",
        "float64": "real",
    }
    # ``units`` attribute written for these array keys
    _units_dict: ClassVar[dict] = {"z": "[mV/km]/[nT]", "t": "[]"}
    # output+input channel pair -> ``name`` attribute written for a value
    _name_dict: ClassVar[dict] = {
        "exhx": "zxx",
        "exhy": "zxy",
        "eyhx": "zyx",
        "eyhy": "zyy",
        "hzhx": "tx",
        "hzhy": "ty",
    }

    # dtype each field is cast to by ``validate_array``
    _array_dtypes_dict: ClassVar[dict] = {
        "period": float,
        "z": complex,
        "z_var": float,
        "z_invsigcov": complex,
        "z_residcov": complex,
        "t": complex,
        "t_var": float,
        "t_invsigcov": complex,
        "t_residcov": complex,
    }

    period: Annotated[
        np.typing.NDArray[np.float64] | None,
        Field(
            default_factory=lambda: np.empty((0,), dtype=np.float64),
            description="periods for estimates",
            alias=None,
            json_schema_extra={
                "units": "second",
                "required": True,
                "examples": ["0.01", "0.1", "1.0"],
            },
        ),
    ]
    z: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.complex128),
            description="Estimates of the impedance tensor.",
            json_schema_extra={
                "units": "[mV/km]/[nT]",
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]

    z_var: Annotated[
        np.typing.NDArray[np.float64] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.float64),
            description="Variance estimates for the impedance tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["0.01", "0.1", "1.0"],
            },
        ),
    ]

    z_invsigcov: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.complex128),
            description="Inverse of the covariance matrix for the impedance tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]
    z_residcov: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.complex128),
            description="Residual covariance matrix for the impedance tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]
    t: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 1, 2), dtype=np.complex128),
            description="Estimates of the tipper tensor.",
            json_schema_extra={
                "units": "[]",
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]
    t_var: Annotated[
        np.typing.NDArray[np.float64] | None,
        Field(
            default_factory=lambda: np.empty((0, 1, 2), dtype=np.float64),
            description="Variance estimates for the tipper tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["0.01", "0.1", "1.0"],
            },
        ),
    ]
    t_invsigcov: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 2, 2), dtype=np.complex128),
            description="Inverse of the covariance matrix for the tipper tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]
    t_residcov: Annotated[
        np.typing.NDArray[np.complex128] | None,
        Field(
            default_factory=lambda: np.empty((0, 1, 1), dtype=np.complex128),
            description="Residual covariance matrix for the tipper tensor.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["1.0+0.0j", "0.5+0.5j"],
            },
        ),
    ]

    # channel labels for each row ("out") and column ("in") index of every
    # array; used to set the output/input attributes of written values
    _write_dict: ClassVar[dict] = {
        "z": {"out": {0: "ex", 1: "ey"}, "in": {0: "hx", 1: "hy"}},
        "z_var": {"out": {0: "ex", 1: "ey"}, "in": {0: "hx", 1: "hy"}},
        "z_invsigcov": {
            "out": {0: "hx", 1: "hy"},
            "in": {0: "hx", 1: "hy"},
        },
        "z_residcov": {
            "out": {0: "ex", 1: "ey"},
            "in": {0: "ex", 1: "ey"},
        },
        "t": {"out": {0: "hz"}, "in": {0: "hx", 1: "hy"}},
        "t_var": {"out": {0: "hz"}, "in": {0: "hx", 1: "hy"}},
        "t_invsigcov": {
            "out": {0: "hx", 1: "hy"},
            "in": {0: "hx", 1: "hy"},
        },
        "t_residcov": {"out": {0: "hz"}, "in": {0: "hz"}},
    }

    # when True, components listed in ``_derived_keys`` are ignored on read
    _skip_derived_data: ClassVar[bool] = True
    _derived_keys: ClassVar[list] = [
        "rho",
        "rho_var",
        "phs",
        "phs_var",
        "tipphs",
        "tipphs_var",
        "tipmag",
        "tipmag_var",
        "zstrike",
        "zstrike_var",
        "zskew",
        "zskew_var",
        "zellip",
        "zellip_var",
        "tstrike",
        "tstrike_var",
        "tskew",
        "tskew_var",
        "tellip",
        "tellip_var",
        "indmag",
        "indmag_var",
        "indang",
        "indang_var",
    ]

    @field_validator(
        "period",
        "z",
        "z_var",
        "z_invsigcov",
        "z_residcov",
        "t",
        "t_var",
        "t_invsigcov",
        "t_residcov",
        mode="before",
    )
    @classmethod
    def validate_array(cls, value, info: ValidationInfo) -> np.ndarray | None:
        """
        Coerce an incoming field value to a numpy array of the field's dtype.

        :param value: new value for the field
        :type value: list | tuple | np.ndarray | None
        :param info: pydantic validation info; its ``field_name`` selects
         the dtype from ``_array_dtypes_dict``
        :type info: ValidationInfo
        :return: the value as a numpy array, or None if value is None
        :rtype: np.ndarray | None
        :raises TypeError: if value is not None, a list, tuple or ndarray
        """
        if value is None:
            return None
        if isinstance(value, (list, tuple, np.ndarray)):
            return np.array(value, dtype=cls._array_dtypes_dict[info.field_name])
        msg = (
            f"input values must be an list, tuple, or np.ndarray, not {type(value)}"
        )
        logger.error(msg)
        raise TypeError(msg)

    def initialize_arrays(self, n_periods: int) -> None:
        """
        Allocate zero-filled arrays for ``n_periods`` periods.

        :param n_periods: number of periods
        :type n_periods: int
        :return: None
        :rtype: None
        """
        self.period = np.zeros(n_periods)
        self.z = np.zeros((n_periods, 2, 2), dtype=self._array_dtypes_dict["z"])
        self.z_var = np.zeros_like(self.z, dtype=self._array_dtypes_dict["z_var"])
        self.z_invsigcov = np.zeros_like(
            self.z, dtype=self._array_dtypes_dict["z_invsigcov"]
        )
        self.z_residcov = np.zeros_like(
            self.z, dtype=self._array_dtypes_dict["z_residcov"]
        )
        self.t = np.zeros((n_periods, 1, 2), dtype=self._array_dtypes_dict["t"])
        self.t_var = np.zeros_like(self.t, dtype=self._array_dtypes_dict["t_var"])
        self.t_invsigcov = np.zeros(
            (n_periods, 2, 2), dtype=self._array_dtypes_dict["t_invsigcov"]
        )
        self.t_residcov = np.zeros(
            (n_periods, 1, 1), dtype=self._array_dtypes_dict["t_residcov"]
        )

    @computed_field
    @property
    def array_dict(self) -> dict:
        """
        Map each estimate name to the array currently held by the model.

        The values are the model's own array objects (not copies), so
        item assignment through this dict updates the model's data.
        """
        return {
            "z": self.z,
            "z_var": self.z_var,
            "z_invsigcov": self.z_invsigcov,
            "z_residcov": self.z_residcov,
            "t": self.t,
            "t_var": self.t_var,
            "t_invsigcov": self.t_invsigcov,
            "t_residcov": self.t_residcov,
        }

    @computed_field
    @property
    def n_periods(self) -> int:
        """Number of periods, or 0 when ``period`` is None."""
        if self.period is not None:
            return self.period.size
        return 0

    @staticmethod
    def _parse_value(text: str, dtype) -> complex | float:
        """
        Convert the text of a ``<value>`` element to a number.

        :param text: element text; "real imag" for complex values
        :type text: str
        :param dtype: ``complex``, ``float``/``int``, or anything else, in
         which case the type is inferred from the number of tokens
        :return: parsed value
        :rtype: complex | float
        """
        if dtype is complex:
            parts = text.split()
            return complex(float(parts[0]), float(parts[1]))
        if dtype in (float, int):
            return dtype(text)
        # type attribute missing or not a python type: infer from the text
        parts = text.split()
        if len(parts) > 1:
            return complex(float(parts[0]), float(parts[1]))
        return float(parts[0])

    def read_block(self, block: dict, period_index: int) -> None:
        """
        Read a period block, which is ``root_dict["data"]["period"][ii]``.

        Each component key is normalised (``"_"`` removed, ``"."`` replaced
        by ``"_"``). Its values are written into the matching array at
        ``period_index``. Components with no backing array in
        ``array_dict`` are skipped, as are derived quantities when
        ``_skip_derived_data`` is set.

        :param block: a single period block
        :type block: dict
        :param period_index: index of the period in the data arrays
        :type period_index: int
        :return: None
        :rtype: None
        """
        # build the property once; its values are the model's own arrays
        arrays = self.array_dict

        for key in block.keys():
            comp = key.replace("_", "").replace(".", "_")
            if comp == "value":
                continue
            if self._skip_derived_data and comp in self._derived_keys:
                continue
            if comp not in arrays:
                # nothing to store this component in (e.g. derived data when
                # not skipping it); skip it instead of failing on a lookup
                logger.debug("No array for {} at period index {}, skipping.", comp, period_index)
                continue

            element = block[key]
            try:
                dtype = self._dtype_dict[element["type"]]
            except KeyError:
                dtype = "unknown"

            try:
                value_list = element["value"]
            except KeyError:
                # loguru formats messages with str.format, so use {} here
                logger.debug("No value for {} at period index {}", comp, period_index)
                continue

            # a single <value> element is not wrapped in a list
            if not isinstance(value_list, list):
                value_list = [value_list]

            array = arrays[comp]
            for item in value_list:
                index_0 = self._index_dict[item["output"].lower()]
                index_1 = self._index_dict[item["input"].lower()]
                array[period_index, index_0, index_1] = self._parse_value(
                    item["value"], dtype
                )

    def read_dict(self, root_dict: dict) -> None:
        """
        Read ``root_dict["data"]``, the main transfer function data block.

        Arrays are (re)initialised to the number of period blocks, then each
        block's period value and estimates are read.

        :param root_dict: dictionary containing the transfer function data
        :type root_dict: dict
        :return: None
        :rtype: None
        """
        if self._skip_derived_data:
            logger.debug("Skipping derived quantities.")

        data = root_dict["data"]
        period_blocks = data["period"]
        # a single <Period> element may arrive as a dict instead of a list
        if isinstance(period_blocks, dict):
            period_blocks = [period_blocks]

        try:
            n_periods = int(float(str(data["count"]).strip()))
        except KeyError:
            n_periods = len(period_blocks)

        if n_periods != len(period_blocks):
            logger.warning(
                "Data count {} does not match the number of period blocks {}; "
                "using the number of period blocks.",
                n_periods,
                len(period_blocks),
            )
            n_periods = len(period_blocks)

        self.initialize_arrays(n_periods)
        for ii, block in enumerate(period_blocks):
            self.period[ii] = float(block["value"])  # type: ignore[index]
            self.read_block(block, ii)

    def write_block(self, parent: et.Element, index: int) -> et.Element:
        """
        Write the estimates for one period as a ``<Period>`` sub-element.

        Arrays that are None or empty are skipped. NaN values pass through
        ``np.nan_to_num``. Zeros are then written as the empty value 1e32
        (``1e32 + 1e32j`` for complex arrays).

        :param parent: element the ``<Period>`` element is appended to
        :type parent: et.Element
        :param index: index of the period to write
        :type index: int
        :return: the new ``<Period>`` element
        :rtype: et.Element
        """
        period_element = et.SubElement(
            parent,
            "Period",
            {"value": f"{self.period[index]:.12e}", "units": "secs"},  # type: ignore[index]
        )

        # build the property once instead of on every access
        for key, data in self.array_dict.items():
            if data is None:
                continue
            if data.size == 0:
                logger.debug(f"No data for {key}, skipping.")
                continue
            # nan_to_num returns a copy, so the stored data is not modified
            arr = np.nan_to_num(data[index])

            # set zeros to empty value of 1E32
            if np.iscomplexobj(arr):
                arr[arr == 0] = 1e32 + 1e32j
            else:
                arr[arr == 0] = 1e32

            attr_dict = {
                "type": self._dtype_dict[arr.dtype.name],
                "size": " ".join(str(n) for n in arr.shape),
            }
            if key in self._units_dict:
                attr_dict["units"] = self._units_dict[key]

            comp_element = et.SubElement(
                period_element, key.replace("_", ".").upper(), attr_dict
            )
            idx_dict = self._write_dict[key]
            n_rows, n_cols = arr.shape[0], arr.shape[1]
            for ii in range(n_rows):
                for jj in range(n_cols):
                    ch_out = idx_dict["out"][ii]
                    ch_in = idx_dict["in"][jj]
                    a_dict = {}
                    pair = ch_out + ch_in
                    if pair in self._name_dict:
                        a_dict["name"] = self._name_dict[pair].capitalize()
                    a_dict["output"] = ch_out.capitalize()
                    a_dict["input"] = ch_in.capitalize()
                    ch_element = et.SubElement(comp_element, "value", a_dict)
                    ch_value = f"{arr[ii, jj].real:.6e}"
                    if attr_dict["type"] == "complex":
                        ch_value = f"{ch_value} {arr[ii, jj].imag:.6e}"
                    ch_element.text = ch_value

        return period_element

    def to_xml(self, string: bool = False, required: bool = True) -> et.Element | str:
        """
        Write all periods as a ``<Data>`` element.

        :param string: if True return the XML serialised by
         ``element_to_string``, otherwise return the element
        :type string: bool
        :param required: currently unused by this method
        :type required: bool
        :return: the ``<Data>`` element or its string form
        :rtype: et.Element | str
        """
        root = et.Element("Data", {"count": f"{self.n_periods:.0f}"})

        # n_periods is 0 when period is None, so nothing is written then
        for index in range(self.n_periods):
            self.write_block(root, index)

        if string:
            return element_to_string(root)
        return root