Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mth5\mth5\groups\transfer_function.py: 61%

257 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-27 20:09 -0800

1# -*- coding: utf-8 -*- 

2from __future__ import annotations 

3 

4 

5"""Transfer function HDF5 helpers for MTH5.""" 

6 

7from typing import Any, Iterable 

8 

9# ============================================================================= 

10# Imports 

11# ============================================================================= 

12import numpy as np 

13import pandas as pd 

14import xarray as xr 

15 

16from mth5.groups import BaseGroup, EstimateDataset 

17from mth5.helpers import from_numpy_type, validate_name 

18from mth5.utils.exceptions import MTH5Error 

19 

20 

21def _check_channel_in_output( 

22 output_channels: Iterable[str] | None, channel: str 

23) -> bool: 

24 """Return ``True`` if ``channel`` is present in an output list. 

25 

26 Handles both normal lists and corrupted serialization from HDF5 attributes 

27 (for example ``['"ex"', '"ey"']``). 

28 

29 Parameters 

30 ---------- 

31 output_channels : Iterable[str] or None 

32 Output channel names, potentially serialized oddly in HDF5 attributes. 

33 channel : str 

34 Channel name to search for. 

35 

36 Returns 

37 ------- 

38 bool 

39 ``True`` when the channel is detected, otherwise ``False``. 

40 

41 Examples 

42 -------- 

43 >>> _check_channel_in_output(["ex", "ey"], "ex") 

44 True 

45 >>> _check_channel_in_output(['"ex"', '"ey"'], "ex") 

46 True 

47 >>> _check_channel_in_output([], "hx") 

48 False 

49 """ 

50 if not output_channels: 

51 return False 

52 

53 # Handle normal case 

54 if channel in output_channels: 

55 return True 

56 

57 # Handle corrupted HDF5 attribute serialization case 

58 # where ['ex', 'ey', 'hz'] becomes ['["ex"', '"ey"', '"hz"]'] 

59 for item in output_channels: 

60 if isinstance(item, str): 

61 # Check if the channel appears in the corrupted string 

62 if f'"{channel}"' in item or f"'{channel}'" in item: 

63 return True 

64 # Also check for cases where the quotes are missing 

65 if channel in item: 

66 return True 

67 

68 return False 

69 

70 

71from mt_metadata.timeseries import Electric, Magnetic, Run 

72from mt_metadata.transfer_functions.core import TF 

73from mt_metadata.transfer_functions.tf.statistical_estimate import StatisticalEstimate 

74 

75 

76# ============================================================================= 

77# Transfer Functions Group 

78# ============================================================================= 

class TransferFunctionsGroup(BaseGroup):
    """Container for transfer functions under a station.

    Each child group is a single transfer function estimation managed by
    :class:`TransferFunctionGroup`.

    Examples
    --------
    >>> from mth5 import mth5
    >>> m5 = mth5.MTH5()
    >>> _ = m5.open_mth5("/tmp/example.mth5", mode="a")
    >>> station = m5.stations_group.add_station("mt01")
    >>> tf_group = station.transfer_functions_group
    >>> tf_group.groups_list
    []
    """

    def __init__(self, group: Any, **kwargs: Any) -> None:
        super().__init__(group, **kwargs)

    def tf_summary(self, as_dataframe: bool = True) -> pd.DataFrame | np.ndarray:
        """Summarize transfer functions stored for the station.

        Each child's ``tf_entry`` is filled in with the parent station's HDF5
        reference, id, latitude, longitude, and elevation before being
        collected.

        Parameters
        ----------
        as_dataframe : bool, default True
            If ``True`` return a pandas DataFrame, otherwise a NumPy array of
            the collected entries.

        Returns
        -------
        pandas.DataFrame or numpy.ndarray
            Summary rows including station reference, location, and TF metadata.

        Examples
        --------
        >>> summary = tf_group.tf_summary()
        >>> summary.columns[:4].tolist()  # doctest: +SKIP
        ['station_hdf5_reference', 'station', 'latitude', 'longitude']
        """
        station = self.hdf5_group.parent

        entries = []
        for tf_id in self.groups_list:
            tf_entry = self.get_transfer_function(tf_id).tf_entry

            # Station level values are shared by every transfer function row.
            tf_entry["station_hdf5_reference"][:] = station.ref
            tf_entry["station"][:] = station.attrs["id"]
            tf_entry["latitude"][:] = station.attrs["location.latitude"]
            tf_entry["longitude"][:] = station.attrs["location.longitude"]
            tf_entry["elevation"][:] = station.attrs["location.elevation"]

            entries.append(tf_entry)

        summary = np.array(entries)
        if as_dataframe:
            return pd.DataFrame(summary.flatten())
        return summary

    def _merge_time_bound(
        self, attr_key: str, tf_time: Any, keep_earliest: bool
    ) -> None:
        """Widen one station time bound so that it covers ``tf_time``.

        Parameters
        ----------
        attr_key : str
            Station attribute to update (``"time_period.start"`` or
            ``"time_period.end"``).
        tf_time : Any
            Time taken from the TF station metadata; it must support ``in``,
            comparison against the stored attribute, and ``isoformat()``.
        keep_earliest : bool
            ``True`` for a start bound (keep the earlier time), ``False`` for
            an end bound (keep the later time).
        """
        # A value containing "1980" is treated as "not set" -- presumably the
        # default placeholder date; TODO confirm against mt_metadata.
        if "1980" in tf_time:
            return

        station_attrs = self.hdf5_group.parent.attrs
        stored = station_attrs[attr_key]

        if "1980" in stored:
            station_attrs[attr_key] = tf_time.isoformat()
            return

        if stored != tf_time:
            if keep_earliest:
                widens = stored > tf_time
            else:
                widens = stored < tf_time
            if widens:
                station_attrs[attr_key] = tf_time.isoformat()

    def _update_time_period_from_tf(self, tf_object: TF) -> None:
        """Propagate time bounds from a TF object into the station attributes.

        The station start is moved earlier and the station end is moved later
        when the TF station metadata extends beyond the stored range.  Bounds
        that contain ``"1980"`` are treated as unset.

        Parameters
        ----------
        tf_object : TF
            Transfer function whose ``station_metadata.time_period`` is read.
        """
        time_period = tf_object.station_metadata.time_period
        self._merge_time_bound(
            "time_period.start", time_period.start, keep_earliest=True
        )
        # The end bound grows only when the TF ends later than what is stored.
        self._merge_time_bound(
            "time_period.end", time_period.end, keep_earliest=False
        )

    def add_transfer_function(
        self, name: str, tf_object: TF | None = None
    ) -> "TransferFunctionGroup":
        """Add a transfer function group under this station.

        Parameters
        ----------
        name : str
            Transfer function identifier.
        tf_object : TF, optional
            Transfer function instance to seed metadata and datasets.  When
            given, the station time period is widened to cover it first.

        Returns
        -------
        TransferFunctionGroup
            Wrapper for the created transfer function.

        Examples
        --------
        >>> tf_group = station.transfer_functions_group
        >>> _ = tf_group.add_transfer_function("mt01_4096")
        """
        name = validate_name(name)

        if tf_object is None:
            return TransferFunctionGroup(
                self.hdf5_group.create_group(name), **self.dataset_options
            )

        self._update_time_period_from_tf(tf_object)
        tf_group = TransferFunctionGroup(
            self.hdf5_group.create_group(name),
            group_metadata=tf_object.station_metadata.transfer_function,
            **self.dataset_options,
        )
        # Metadata was already supplied through ``group_metadata`` above.
        tf_group.from_tf_object(tf_object, update_metadata=False)
        return tf_group

    def get_transfer_function(self, tf_id: str) -> "TransferFunctionGroup":
        """Return an existing transfer function by id.

        Parameters
        ----------
        tf_id : str
            Name of the transfer function.

        Returns
        -------
        TransferFunctionGroup
            Wrapper for the requested transfer function.

        Raises
        ------
        MTH5Error
            If the transfer function does not exist.

        Examples
        --------
        >>> existing = station.transfer_functions_group.get_transfer_function("mt01_4096")
        >>> existing.name  # doctest: +SKIP
        'mt01_4096'
        """
        tf_id = validate_name(tf_id)
        try:
            return TransferFunctionGroup(self.hdf5_group[tf_id], **self.dataset_options)
        except KeyError as error:
            msg = f"{tf_id} does not exist, check groups_list for existing names"
            self.logger.debug("Error: " + msg)
            raise MTH5Error(msg) from error

    def remove_transfer_function(self, tf_id: str) -> None:
        """Delete a transfer function reference from the station.

        Parameters
        ----------
        tf_id : str
            Transfer function name.

        Raises
        ------
        MTH5Error
            If the transfer function does not exist.

        Notes
        -----
        HDF5 deletion removes the reference only; storage is not reclaimed.

        Examples
        --------
        >>> tf_group.remove_transfer_function("mt01_4096")
        """
        tf_id = validate_name(tf_id)
        try:
            del self.hdf5_group[tf_id]
        except KeyError as error:
            msg = f"{tf_id} does not exist, check groups_list for existing names"
            self.logger.debug("Error: " + msg)
            raise MTH5Error(msg) from error
        else:
            self.logger.info(
                "Deleting a transfer function does not reduce the HDF5 "
                "file size, it simply removes the reference. If "
                "file size reduction is your goal, simply copy "
                "what you want into another file."
            )

    def get_tf_object(self, tf_id: str) -> TF:
        """Return a populated :class:`mt_metadata.transfer_functions.core.TF`.

        Parameters
        ----------
        tf_id : str
            Transfer function name to convert.

        Returns
        -------
        mt_metadata.transfer_functions.core.TF
            Transfer function populated with metadata and estimates.

        Raises
        ------
        MTH5Error
            If the transfer function does not exist.

        Examples
        --------
        >>> tf_obj = tf_group.get_tf_object("mt01_4096")  # doctest: +SKIP
        """
        return self.get_transfer_function(tf_id).to_tf_object()

305 

306 

class TransferFunctionGroup(BaseGroup):
    """Wrapper for a single transfer function estimation.

    The HDF5 group holds a ``period`` dataset plus one dataset per statistical
    estimate.  An estimate created without data is stored as a placeholder of
    shape ``(1, 1, 1)`` and is reported as missing by :meth:`has_estimate`.
    """

    # Shape of the placeholder written when an estimate is created without data.
    _PLACEHOLDER_SHAPE = (1, 1, 1)

    def __init__(self, group: Any, **kwargs: Any) -> None:
        super().__init__(group, **kwargs)

        # Order matters: ``from_tf_object`` uses only the first four entries
        # when the TF object carries a full transfer function.
        self._accepted_estimates = [
            "transfer_function",
            "transfer_function_error",
            "inverse_signal_power",
            "residual_covariance",
            "impedance",
            "impedance_error",
            "tipper",
            "tipper_error",
        ]

        # NOTE(review): "samples per second" describes a rate, not a period --
        # confirm whether the stored units should be "seconds".
        self._period_metadata = StatisticalEstimate(
            name="period",
            data_type="real",
            description="Periods at which transfer function is estimated",
            units="samples per second",
        )

    def _populated_estimate(self, estimate_name: str) -> EstimateDataset | None:
        """Return an estimate, or ``None`` if it is missing or a placeholder."""
        try:
            estimate = self.get_estimate(estimate_name)
        except (KeyError, MTH5Error):
            return None
        if estimate.hdf5_dataset.shape == self._PLACEHOLDER_SHAPE:
            return None
        return estimate

    def has_estimate(self, estimate: str) -> bool:
        """Return ``True`` if an estimate exists and is populated.

        Besides names stored directly in the group, three derived names are
        understood: ``"impedance"`` (the transfer function has ``ex`` and
        ``ey`` outputs), ``"tipper"`` (it has an ``hz`` output), and
        ``"covariance"`` (both ``residual_covariance`` and
        ``inverse_signal_power`` are populated).

        Parameters
        ----------
        estimate : str
            Estimate name to look for.

        Returns
        -------
        bool
            ``False`` for unknown names, missing datasets, and placeholders.
        """
        if estimate in self.groups_list:
            stored = self.get_estimate(estimate)
            return stored.hdf5_dataset.shape != self._PLACEHOLDER_SHAPE

        if estimate in ("impedance", "tipper"):
            # Both are derived from the full transfer function; a missing
            # transfer function means the estimate is simply not available.
            tf_estimate = self._populated_estimate("transfer_function")
            if tf_estimate is None:
                return False
            outputs = tf_estimate.metadata.output_channels
            required = ("ex", "ey") if estimate == "impedance" else ("hz",)
            return all(
                _check_channel_in_output(outputs, channel) for channel in required
            )

        if estimate == "covariance":
            return all(
                self._populated_estimate(name) is not None
                for name in ("residual_covariance", "inverse_signal_power")
            )

        return False

    @property
    def period(self) -> np.ndarray | None:
        """Return period array stored in ``period`` dataset, if present."""
        try:
            return self.hdf5_group["period"][()]
        except KeyError:
            return None

    @period.setter
    def period(self, period: Any) -> None:
        """Write the period array, replacing any existing one.

        ``None`` is ignored.  An existing dataset is resized when the new
        array has a different shape.
        """
        if period is None:
            return
        period = np.array(period, dtype=float)

        # ``add_statistical_estimate`` returns the existing dataset untouched
        # when the name is taken, so an overwrite has to be done here.
        if "period" in self.hdf5_group:
            self.logger.debug("period already exists, overwriting")
            dataset = self.hdf5_group["period"]
            if dataset.shape != period.shape:
                dataset.resize(period.shape)
            dataset[...] = period
            return

        _ = self.add_statistical_estimate(
            "period",
            estimate_data=period,
            estimate_metadata=self._period_metadata,
            chunks=True,
            max_shape=(None,),
        )

    def add_statistical_estimate(
        self,
        estimate_name: str,
        estimate_data: np.ndarray | xr.DataArray | None = None,
        estimate_metadata: StatisticalEstimate | None = None,
        max_shape: tuple[int | None, int | None, int | None] = (None, None, None),
        chunks: bool = True,
        **kwargs: Any,
    ) -> EstimateDataset:
        """Add a statistical estimate dataset.

        Parameters
        ----------
        estimate_name : str
            Dataset name.
        estimate_data : numpy.ndarray or xarray.DataArray, optional
            Estimate values; if ``None`` a complex ``(1, 1, 1)`` placeholder
            array is created.  For a DataArray the ``output`` and ``input``
            coordinates, the array name, and the dtype are copied into the
            metadata.
        estimate_metadata : StatisticalEstimate, optional
            Metadata describing the estimate.  The object is updated in place.
        max_shape : tuple of int or None, default (None, None, None)
            Maximum shape for resizable datasets.
        chunks : bool, default True
            Chunking flag forwarded to HDF5 dataset creation; forced to
            ``True`` for the placeholder.
        **kwargs : Any
            Accepted for compatibility; not used here.

        Returns
        -------
        EstimateDataset
            Wrapper combining dataset and metadata.  If the dataset could not
            be created, the existing dataset of that name is returned instead.

        Raises
        ------
        TypeError
            If ``estimate_data`` is neither a NumPy array nor a DataArray.
        MTH5Error
            If creation failed and no dataset of that name exists.

        Examples
        --------
        >>> est = tf_group.add_statistical_estimate("transfer_function")
        >>> isinstance(est, EstimateDataset)
        True
        """
        estimate_name = validate_name(estimate_name)

        if estimate_metadata is None:
            estimate_metadata = StatisticalEstimate()
        estimate_metadata.name = estimate_name

        if estimate_data is None:
            dtype = complex
            chunks = True
            estimate_data = np.zeros(self._PLACEHOLDER_SHAPE, dtype=dtype)
        else:
            if not isinstance(estimate_data, (np.ndarray, xr.DataArray)):
                msg = f"Need to input a numpy or xarray.DataArray not {type(estimate_data)}"
                self.logger.error(msg)
                raise TypeError(msg)
            if isinstance(estimate_data, xr.DataArray):
                coords = estimate_data.coords
                estimate_metadata.output_channels = coords["output"].values.tolist()
                estimate_metadata.input_channels = coords["input"].values.tolist()
                estimate_metadata.name = validate_name(estimate_data.name)
                estimate_metadata.data_type = estimate_data.dtype.name

                estimate_data = estimate_data.to_numpy()
            dtype = estimate_data.dtype

        try:
            dataset = self.hdf5_group.create_dataset(
                estimate_name,
                data=estimate_data,
                dtype=dtype,
                chunks=chunks,
                maxshape=max_shape,
                **self.dataset_options,
            )
            return EstimateDataset(dataset, dataset_metadata=estimate_metadata)
        except (OSError, RuntimeError, ValueError) as error:
            self.logger.error(error)
            msg = f"estimate {estimate_metadata.name} already exists, returning existing group."
            self.logger.debug(msg)

            # Look up the key the dataset was created under; the metadata name
            # may have been replaced by the DataArray name above.
            return self.get_estimate(estimate_name)

    def get_estimate(self, estimate_name: str) -> EstimateDataset:
        """Return a statistical estimate dataset by name.

        Parameters
        ----------
        estimate_name : str
            Dataset name.

        Returns
        -------
        EstimateDataset
            Dataset wrapped with metadata rebuilt from its HDF5 attributes.

        Raises
        ------
        MTH5Error
            If the dataset does not exist or reading it raises ``OSError``.
        """
        estimate_name = validate_name(estimate_name)

        try:
            dataset = self.hdf5_group[estimate_name]
            metadata = StatisticalEstimate(**dict(dataset.attrs))
            return EstimateDataset(dataset, dataset_metadata=metadata)
        except KeyError as error:
            msg = (
                f"{estimate_name} does not exist, "
                "check groups_list for existing names"
            )
            self.logger.error(msg)
            raise MTH5Error(msg) from error
        except OSError as error:
            self.logger.error(error)
            raise MTH5Error(error) from error

    def remove_estimate(self, estimate_name: str) -> None:
        """Remove a statistical estimate dataset reference.

        The name is lower-cased before validation.  HDF5 deletion removes the
        reference only; storage is not reclaimed.

        Parameters
        ----------
        estimate_name : str
            Dataset name.

        Raises
        ------
        MTH5Error
            If the dataset does not exist.
        """
        estimate_name = validate_name(estimate_name.lower())

        try:
            del self.hdf5_group[estimate_name]
        except KeyError as error:
            msg = (
                f"{estimate_name} does not exist, "
                "check groups_list for existing names"
            )
            self.logger.error(msg)
            raise MTH5Error(msg) from error
        else:
            self.logger.info(
                "Deleting an estimate does not reduce the HDF5 "
                "file size, it simply removes the reference. If "
                "file size reduction is your goal, simply copy "
                "what you want into another file."
            )

    @staticmethod
    def _attrs_as_dict(hdf5_object: Any) -> dict[str, Any]:
        """Return HDF5 attributes as a dict of plain Python values."""
        return {
            key: from_numpy_type(value) for key, value in hdf5_object.attrs.items()
        }

    def _run_from_hdf5(self, run_group: Any) -> Run:
        """Build a ``Run`` with its electric and magnetic channels."""
        run_obj = Run(**self._attrs_as_dict(run_group))

        for ch_id in run_group.keys():
            ch_dict = self._attrs_as_dict(run_group[validate_name(ch_id)])
            ch_type = ch_dict.get("type")
            if ch_type == "electric":
                ch_obj = Electric(**ch_dict)
            elif ch_type == "magnetic":
                ch_obj = Magnetic(**ch_dict)
            else:
                # Nothing to build for other types; skipping avoids re-adding
                # the previous channel or using an undefined one.
                self.logger.debug(
                    f"Skipping channel {ch_id} with unsupported type {ch_type}"
                )
                continue
            run_obj.add_channel(ch_obj)
        return run_obj

    def to_tf_object(self) -> TF:
        """Convert this group into a populated :class:`TF` object.

        Survey, station, and run metadata are read from the HDF5 groups above
        this one; runs listed in ``runs_processed`` that cannot be read are
        logged and skipped.

        Returns
        -------
        mt_metadata.transfer_functions.core.TF
            TF instance with survey, station, runs, channels, period, and
            estimate datasets applied.

        Raises
        ------
        ValueError
            If no period dataset is present.

        Examples
        --------
        >>> tf_obj = tf_group.to_tf_object()  # doctest: +SKIP
        """
        tf_obj = TF()

        # Hierarchy: survey / stations group / station / transfer functions
        # group / this group.
        station_group = self.hdf5_group.parent.parent
        survey_group = station_group.parent.parent

        tf_obj.survey_metadata.from_dict(
            {"survey": self._attrs_as_dict(survey_group)}
        )
        tf_obj.station_metadata.from_dict(
            {"station": self._attrs_as_dict(station_group)}
        )

        # The station attributes do not carry this group's TF metadata.
        tf_obj.station_metadata.transfer_function.from_dict(
            {"transfer_function": self._attrs_as_dict(self.hdf5_group)}
        )

        # add run and channel metadata
        tf_obj.station_metadata.runs = []
        for run_id in tf_obj.station_metadata.transfer_function.runs_processed:
            if run_id in ["", None, "None"]:
                continue
            try:
                run_obj = self._run_from_hdf5(station_group[validate_name(run_id)])
                tf_obj.station_metadata.add_run(run_obj)
            except KeyError:
                self.logger.info(f"Could not get run {run_id} for transfer function")

        # Read the dataset once; every property access hits the file.
        period = self.period
        if period is None:
            msg = "Period must not be None to create a transfer function object"
            self.logger.error(msg)
            raise ValueError(msg)
        tf_obj.period = period

        for estimate_name in self.groups_list:
            if estimate_name == "period":
                continue
            estimate = self.get_estimate(estimate_name)

            try:
                setattr(tf_obj, estimate_name, estimate.to_numpy())
            except AttributeError as error:
                self.logger.exception(error)

        # need to update time periods
        tf_obj.station_metadata.update_time_period()
        tf_obj.survey_metadata.update_time_period()
        return tf_obj

    def from_tf_object(self, tf_obj: TF, update_metadata: bool = True) -> None:
        """Populate datasets from a :class:`TF` object.

        Parameters
        ----------
        tf_obj : TF
            Transfer function object containing estimates and metadata.
        update_metadata : bool, default True
            If ``True`` write transfer function metadata to HDF5.

        Raises
        ------
        ValueError
            If ``tf_obj`` is not a ``TF`` instance.

        Examples
        --------
        >>> tf_group.from_tf_object(tf_obj)  # doctest: +SKIP
        """
        if not isinstance(tf_obj, TF):
            msg = f"Input must be a TF object not {type(tf_obj)}"
            self.logger.error(msg)
            raise ValueError(msg)

        self.period = tf_obj.period
        if update_metadata:
            self.metadata.update(tf_obj.station_metadata.transfer_function)
            self.write_metadata()

        # if transfer function is available then impedance and tipper are
        # redundant.
        if tf_obj.has_transfer_function():
            accepted_estimates = self._accepted_estimates[0:4]
        else:
            accepted_estimates = self._accepted_estimates

        for estimate_name in accepted_estimates:
            try:
                estimate = getattr(tf_obj, estimate_name)
                if estimate is not None:
                    _ = self.add_statistical_estimate(estimate_name, estimate)
                else:
                    self.logger.debug(f"Did not find {estimate_name} in TF. Skipping")
            except AttributeError:
                self.logger.debug(f"Did not find {estimate_name} in TF. Skipping")