Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mth5\mth5\groups\transfer_function.py: 61%
257 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-27 20:09 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-27 20:09 -0800
1# -*- coding: utf-8 -*-
2from __future__ import annotations
5"""Transfer function HDF5 helpers for MTH5."""
7from typing import Any, Iterable
9# =============================================================================
10# Imports
11# =============================================================================
12import numpy as np
13import pandas as pd
14import xarray as xr
16from mth5.groups import BaseGroup, EstimateDataset
17from mth5.helpers import from_numpy_type, validate_name
18from mth5.utils.exceptions import MTH5Error
21def _check_channel_in_output(
22 output_channels: Iterable[str] | None, channel: str
23) -> bool:
24 """Return ``True`` if ``channel`` is present in an output list.
26 Handles both normal lists and corrupted serialization from HDF5 attributes
27 (for example ``['"ex"', '"ey"']``).
29 Parameters
30 ----------
31 output_channels : Iterable[str] or None
32 Output channel names, potentially serialized oddly in HDF5 attributes.
33 channel : str
34 Channel name to search for.
36 Returns
37 -------
38 bool
39 ``True`` when the channel is detected, otherwise ``False``.
41 Examples
42 --------
43 >>> _check_channel_in_output(["ex", "ey"], "ex")
44 True
45 >>> _check_channel_in_output(['"ex"', '"ey"'], "ex")
46 True
47 >>> _check_channel_in_output([], "hx")
48 False
49 """
50 if not output_channels:
51 return False
53 # Handle normal case
54 if channel in output_channels:
55 return True
57 # Handle corrupted HDF5 attribute serialization case
58 # where ['ex', 'ey', 'hz'] becomes ['["ex"', '"ey"', '"hz"]']
59 for item in output_channels:
60 if isinstance(item, str):
61 # Check if the channel appears in the corrupted string
62 if f'"{channel}"' in item or f"'{channel}'" in item:
63 return True
64 # Also check for cases where the quotes are missing
65 if channel in item:
66 return True
68 return False
71from mt_metadata.timeseries import Electric, Magnetic, Run
72from mt_metadata.transfer_functions.core import TF
73from mt_metadata.transfer_functions.tf.statistical_estimate import StatisticalEstimate
76# =============================================================================
77# Transfer Functions Group
78# =============================================================================
class TransferFunctionsGroup(BaseGroup):
    """Container for transfer functions under a station.

    Each child group is a single transfer function estimation managed by
    :class:`TransferFunctionGroup`.

    Examples
    --------
    >>> from mth5 import mth5
    >>> m5 = mth5.MTH5()
    >>> _ = m5.open_mth5("/tmp/example.mth5", mode="a")
    >>> station = m5.stations_group.add_station("mt01")
    >>> tf_group = station.transfer_functions_group
    >>> tf_group.groups_list
    []
    """

    def __init__(self, group: Any, **kwargs: Any) -> None:
        super().__init__(group, **kwargs)

    def tf_summary(self, as_dataframe: bool = True) -> pd.DataFrame | np.ndarray:
        """Summarize transfer functions stored for the station.

        Parameters
        ----------
        as_dataframe : bool, default True
            If ``True`` return a pandas DataFrame, otherwise a NumPy structured array.

        Returns
        -------
        pandas.DataFrame or numpy.ndarray
            Summary rows including station reference, location, and TF metadata.

        Examples
        --------
        >>> summary = tf_group.tf_summary()
        >>> summary.columns[:4].tolist()  # doctest: +SKIP
        ['station_hdf5_reference', 'station', 'latitude', 'longitude']
        """
        # The parent of this container is the station group; its reference,
        # id and location are copied into every summary row.
        station_group = self.hdf5_group.parent
        station_attrs = station_group.attrs

        tf_list = []
        for tf_id in self.groups_list:
            tf_entry = self.get_transfer_function(tf_id).tf_entry
            tf_entry["station_hdf5_reference"][:] = station_group.ref
            tf_entry["station"][:] = station_attrs["id"]
            tf_entry["latitude"][:] = station_attrs["location.latitude"]
            tf_entry["longitude"][:] = station_attrs["location.longitude"]
            tf_entry["elevation"][:] = station_attrs["location.elevation"]
            tf_list.append(tf_entry)
        tf_array = np.array(tf_list)

        if as_dataframe:
            return pd.DataFrame(tf_array.flatten())
        return tf_array

    def _update_time_period_from_tf(self, tf_object: TF) -> None:
        """Widen the station time period so it covers the TF's time period.

        A time whose text contains ``"1980"`` is treated as unset (presumably
        the metadata default -- confirm against mt_metadata). An unset TF time
        is ignored, and an unset station time is replaced outright. Otherwise
        the station start only ever moves earlier and the station end only
        ever moves later.

        Parameters
        ----------
        tf_object : TF
            Transfer function whose ``station_metadata.time_period`` is merged
            into the parent station group's ``time_period.*`` attributes.
        """
        station_attrs = self.hdf5_group.parent.attrs
        tf_period = tf_object.station_metadata.time_period

        # (attribute key, TF time, "station value must be replaced" predicate)
        bounds = (
            ("time_period.start", tf_period.start, lambda station, tf: station > tf),
            ("time_period.end", tf_period.end, lambda station, tf: station < tf),
        )
        for key, tf_time, should_replace in bounds:
            if "1980" in tf_time:
                continue
            station_time = station_attrs[key]
            if "1980" in station_time or (
                station_time != tf_time and should_replace(station_time, tf_time)
            ):
                station_attrs[key] = tf_time.isoformat()

    def add_transfer_function(
        self, name: str, tf_object: TF | None = None
    ) -> "TransferFunctionGroup":
        """Add a transfer function group under this station.

        If a group called ``name`` already exists it is returned unchanged
        instead of raising.

        Parameters
        ----------
        name : str
            Transfer function identifier.
        tf_object : TF, optional
            Transfer function instance to seed metadata and datasets. When
            given, the station time period is widened to cover it.

        Returns
        -------
        TransferFunctionGroup
            Wrapper for the created or existing transfer function.

        Examples
        --------
        >>> tf_group = station.transfer_functions_group
        >>> _ = tf_group.add_transfer_function("mt01_4096")
        """
        name = validate_name(name)

        if tf_object is not None:
            self._update_time_period_from_tf(tf_object)

        try:
            new_group = self.hdf5_group.create_group(name)
        except ValueError:
            # h5py raises ValueError when the name is already taken.
            msg = f"Transfer function {name} already exists, returning existing group."
            self.logger.info(msg)
            return self.get_transfer_function(name)

        if tf_object is None:
            return TransferFunctionGroup(new_group, **self.dataset_options)

        tf_group = TransferFunctionGroup(
            new_group,
            group_metadata=tf_object.station_metadata.transfer_function,
            **self.dataset_options,
        )
        tf_group.from_tf_object(tf_object, update_metadata=False)
        return tf_group

    def get_transfer_function(self, tf_id: str) -> "TransferFunctionGroup":
        """Return an existing transfer function by id.

        Parameters
        ----------
        tf_id : str
            Name of the transfer function.

        Returns
        -------
        TransferFunctionGroup
            Wrapper for the requested transfer function.

        Raises
        ------
        MTH5Error
            If the transfer function does not exist.

        Examples
        --------
        >>> existing = station.transfer_functions_group.get_transfer_function("mt01_4096")
        >>> existing.name  # doctest: +SKIP
        'mt01_4096'
        """
        tf_id = validate_name(tf_id)
        try:
            return TransferFunctionGroup(self.hdf5_group[tf_id], **self.dataset_options)
        except KeyError:
            msg = f"{tf_id} does not exist, check groups_list for existing names"
            self.logger.debug(f"Error: {msg}")
            raise MTH5Error(msg)

    def remove_transfer_function(self, tf_id: str) -> None:
        """Delete a transfer function reference from the station.

        Parameters
        ----------
        tf_id : str
            Transfer function name.

        Raises
        ------
        MTH5Error
            If the transfer function does not exist.

        Notes
        -----
        HDF5 deletion removes the reference only; storage is not reclaimed.

        Examples
        --------
        >>> tf_group.remove_transfer_function("mt01_4096")
        """
        tf_id = validate_name(tf_id)
        try:
            del self.hdf5_group[tf_id]
        except KeyError:
            msg = f"{tf_id} does not exist, check groups_list for existing names"
            self.logger.debug(f"Error: {msg}")
            raise MTH5Error(msg)
        self.logger.info(
            "Deleting a transfer function does not reduce the HDF5 "
            "file size, it simply removes the reference. If "
            "file size reduction is your goal, simply copy"
            " what you want into another file."
        )

    def get_tf_object(self, tf_id: str) -> TF:
        """Return a populated :class:`mt_metadata.transfer_functions.core.TF`.

        Parameters
        ----------
        tf_id : str
            Transfer function name to convert.

        Returns
        -------
        mt_metadata.transfer_functions.core.TF
            Transfer function populated with metadata and estimates.

        Raises
        ------
        MTH5Error
            If the transfer function does not exist.

        Examples
        --------
        >>> tf_obj = tf_group.get_tf_object("mt01_4096")  # doctest: +SKIP
        """
        return self.get_transfer_function(tf_id).to_tf_object()
class TransferFunctionGroup(BaseGroup):
    """Wrapper for a single transfer function estimation.

    Each statistical estimate (transfer function, errors, covariances, ...) is
    an HDF5 dataset in this group, stored next to a resizable one-dimensional
    ``period`` dataset.
    """

    def __init__(self, group: Any, **kwargs: Any) -> None:
        super().__init__(group, **kwargs)

        # Estimates ``from_tf_object`` copies from a TF object. Only the first
        # four are written when the TF has a full transfer function, because
        # impedance and tipper (and their errors) are redundant in that case.
        self._accepted_estimates = [
            "transfer_function",
            "transfer_function_error",
            "inverse_signal_power",
            "residual_covariance",
            "impedance",
            "impedance_error",
            "tipper",
            "tipper_error",
        ]

        # Periods are durations, so they are recorded in seconds.
        self._period_metadata = StatisticalEstimate(
            name="period",
            data_type="real",
            description="Periods at which transfer function is estimated",
            units="second",
        )

    def has_estimate(self, estimate: str) -> bool:
        """Return ``True`` if an estimate exists and is populated.

        Parameters
        ----------
        estimate : str
            Name of a stored estimate dataset, or one of the derived names
            ``"impedance"``, ``"tipper"`` or ``"covariance"``.

        Returns
        -------
        bool
            ``False`` when the estimate is missing or is only the ``(1, 1, 1)``
            placeholder written by :meth:`add_statistical_estimate`.
            ``"impedance"`` and ``"tipper"`` are derived from a populated
            ``transfer_function`` whose outputs include ``ex`` and ``ey``
            (impedance) or ``hz`` (tipper). ``"covariance"`` requires both
            ``residual_covariance`` and ``inverse_signal_power``.
        """
        placeholder = (1, 1, 1)

        if estimate in self.groups_list:
            return self.get_estimate(estimate).hdf5_dataset.shape != placeholder

        if estimate in ("impedance", "tipper"):
            # Without a stored transfer function there is nothing to derive
            # from; report "absent" instead of raising.
            try:
                tf_estimate = self.get_estimate("transfer_function")
            except MTH5Error:
                return False
            if tf_estimate.hdf5_dataset.shape == placeholder:
                return False
            outputs = tf_estimate.metadata.output_channels
            if estimate == "impedance":
                return _check_channel_in_output(
                    outputs, "ex"
                ) and _check_channel_in_output(outputs, "ey")
            return _check_channel_in_output(outputs, "hz")

        if estimate == "covariance":
            try:
                res = self.get_estimate("residual_covariance")
                isp = self.get_estimate("inverse_signal_power")
            except (KeyError, MTH5Error):
                return False
            return (
                res.hdf5_dataset.shape != placeholder
                and isp.hdf5_dataset.shape != placeholder
            )

        return False

    @property
    def period(self) -> np.ndarray | None:
        """Return period array stored in ``period`` dataset, if present."""
        try:
            return self.hdf5_group["period"][()]
        except KeyError:
            return None

    @period.setter
    def period(self, period: Any) -> None:
        # Store ``period`` as float values, creating or overwriting the
        # ``period`` dataset. ``add_statistical_estimate`` returns an existing
        # dataset unchanged, so an existing dataset is handled here first.
        if period is not None:
            period = np.array(period, dtype=float)

        if "period" in self.hdf5_group:
            if period is None:
                self.logger.debug("period already exists and new period is None, keeping it")
                return
            dataset = self.hdf5_group["period"]
            if dataset.shape == period.shape:
                self.logger.debug("period already exists, overwriting")
                dataset[...] = period
                return
            # A different number of periods cannot be written in place;
            # replace the dataset.
            self.logger.debug("period already exists with a different shape, replacing")
            del self.hdf5_group["period"]

        _ = self.add_statistical_estimate(
            "period",
            estimate_data=period,
            estimate_metadata=self._period_metadata,
            chunks=True,
            max_shape=(None,),
        )

    def add_statistical_estimate(
        self,
        estimate_name: str,
        estimate_data: np.ndarray | xr.DataArray | None = None,
        estimate_metadata: StatisticalEstimate | None = None,
        max_shape: tuple[int | None, ...] = (None, None, None),
        chunks: bool = True,
        **kwargs: Any,
    ) -> EstimateDataset:
        """Add a statistical estimate dataset.

        If a dataset called ``estimate_name`` cannot be created (for example
        because it already exists), the existing dataset is returned unchanged.

        Parameters
        ----------
        estimate_name : str
            Dataset name.
        estimate_data : numpy.ndarray or xarray.DataArray, optional
            Estimate values. If ``None``, a ``(1, 1, 1)`` complex placeholder
            array is created. For a DataArray, the ``output``/``input``
            coordinates, name and dtype are copied into the metadata.
        estimate_metadata : StatisticalEstimate, optional
            Metadata describing the estimate.
        max_shape : tuple of int or None, default (None, None, None)
            Maximum shape for resizable datasets.
        chunks : bool, default True
            Chunking flag forwarded to HDF5 dataset creation.

        Returns
        -------
        EstimateDataset
            Wrapper combining dataset and metadata.

        Raises
        ------
        TypeError
            If ``estimate_data`` is not array-like.

        Examples
        --------
        >>> est = tf_group.add_statistical_estimate("transfer_function")
        >>> isinstance(est, EstimateDataset)
        True
        """
        estimate_name = validate_name(estimate_name)

        if estimate_metadata is None:
            estimate_metadata = StatisticalEstimate()
            estimate_metadata.name = estimate_name

        if estimate_data is not None:
            if not isinstance(estimate_data, (np.ndarray, xr.DataArray)):
                msg = f"Need to input a numpy or xarray.DataArray not {type(estimate_data)}"
                self.logger.error(msg)
                raise TypeError(msg)
            if isinstance(estimate_data, xr.DataArray):
                estimate_metadata.output_channels = estimate_data.coords[
                    "output"
                ].values.tolist()
                estimate_metadata.input_channels = estimate_data.coords[
                    "input"
                ].values.tolist()
                estimate_metadata.name = validate_name(estimate_data.name)
                estimate_metadata.data_type = estimate_data.dtype.name
                estimate_data = estimate_data.to_numpy()
            dtype = estimate_data.dtype
        else:
            dtype = complex
            chunks = True
            estimate_data = np.zeros((1, 1, 1), dtype=dtype)

        try:
            dataset = self.hdf5_group.create_dataset(
                estimate_name,
                data=estimate_data,
                dtype=dtype,
                chunks=chunks,
                maxshape=max_shape,
                **self.dataset_options,
            )
            return EstimateDataset(dataset, dataset_metadata=estimate_metadata)
        except (OSError, RuntimeError, ValueError) as error:
            self.logger.error(error)
            msg = f"estimate {estimate_name} already exists, returning existing group."
            self.logger.debug(msg)
        # Look up by the dataset key; the metadata name may have been replaced
        # by the DataArray's name above.
        return self.get_estimate(estimate_name)

    def get_estimate(self, estimate_name: str) -> EstimateDataset:
        """Return a statistical estimate dataset by name.

        Raises
        ------
        MTH5Error
            If the dataset does not exist or cannot be read.
        """
        estimate_name = validate_name(estimate_name)

        try:
            estimate_dataset = self.hdf5_group[estimate_name]
            estimate_metadata = StatisticalEstimate(**dict(estimate_dataset.attrs))
            return EstimateDataset(estimate_dataset, dataset_metadata=estimate_metadata)
        except KeyError:
            msg = (
                f"{estimate_name} does not exist, "
                "check groups_list for existing names"
            )
            self.logger.error(msg)
            raise MTH5Error(msg)
        except OSError as error:
            self.logger.error(error)
            raise MTH5Error(error)

    def remove_estimate(self, estimate_name: str) -> None:
        """Remove a statistical estimate dataset reference.

        Raises
        ------
        MTH5Error
            If the dataset does not exist.
        """
        estimate_name = validate_name(estimate_name.lower())

        try:
            del self.hdf5_group[estimate_name]
        except KeyError:
            msg = (
                f"{estimate_name} does not exist, "
                "check groups_list for existing names"
            )
            self.logger.error(msg)
            raise MTH5Error(msg)
        self.logger.info(
            "Deleting an estimate does not reduce the HDF5 "
            "file size, it simply removes the reference. If "
            "file size reduction is your goal, simply copy"
            " what you want into another file."
        )

    @staticmethod
    def _attrs_to_dict(hdf5_object: Any) -> dict[str, Any]:
        """Return an HDF5 object's attributes converted to plain Python values."""
        return {
            key: from_numpy_type(value) for key, value in hdf5_object.attrs.items()
        }

    def _read_run(self, station_group: Any, run_id: str) -> Run:
        """Build a ``Run`` with its electric/magnetic channels from a run group.

        Channels whose ``type`` attribute is neither ``"electric"`` nor
        ``"magnetic"`` (or is missing) are skipped. Raises ``KeyError`` when
        the run group does not exist.
        """
        run = station_group[validate_name(run_id)]
        run_obj = Run(**self._attrs_to_dict(run))

        for ch_id in run.keys():
            ch_dict = self._attrs_to_dict(run[validate_name(ch_id)])
            ch_type = ch_dict.get("type")
            if ch_type == "electric":
                ch_obj = Electric(**ch_dict)
            elif ch_type == "magnetic":
                ch_obj = Magnetic(**ch_dict)
            else:
                self.logger.debug(
                    f"Skipping channel {ch_id} of type {ch_type} in run {run_id}"
                )
                continue
            run_obj.add_channel(ch_obj)
        return run_obj

    def to_tf_object(self) -> TF:
        """Convert this group into a populated :class:`TF` object.

        Returns
        -------
        mt_metadata.transfer_functions.core.TF
            TF instance with survey, station, runs, channels, period, and
            estimate datasets applied.

        Raises
        ------
        ValueError
            If no period dataset is present.

        Examples
        --------
        >>> tf_obj = tf_group.to_tf_object()  # doctest: +SKIP
        """
        tf_obj = TF()

        # Two levels up is the station group; four levels up is the survey.
        station_group = self.hdf5_group.parent.parent
        survey_group = station_group.parent.parent

        tf_obj.survey_metadata.from_dict({"survey": self._attrs_to_dict(survey_group)})
        tf_obj.station_metadata.from_dict(
            {"station": self._attrs_to_dict(station_group)}
        )
        tf_obj.station_metadata.transfer_function.from_dict(
            {"transfer_function": self._attrs_to_dict(self.hdf5_group)}
        )

        # Rebuild run and channel metadata for the runs used in processing.
        tf_obj.station_metadata.runs = []
        for run_id in tf_obj.station_metadata.transfer_function.runs_processed:
            if run_id in ["", None, "None"]:
                continue
            try:
                tf_obj.station_metadata.add_run(self._read_run(station_group, run_id))
            except KeyError:
                self.logger.info(f"Could not get run {run_id} for transfer function")

        period = self.period
        if period is None:
            msg = "Period must not be None to create a transfer function object"
            self.logger.error(msg)
            raise ValueError(msg)
        tf_obj.period = period

        for estimate_name in self.groups_list:
            if estimate_name in ["period"]:
                continue
            estimate = self.get_estimate(estimate_name)
            try:
                setattr(tf_obj, estimate_name, estimate.to_numpy())
            except AttributeError as error:
                self.logger.exception(error)

        # need to update time periods
        tf_obj.station_metadata.update_time_period()
        tf_obj.survey_metadata.update_time_period()
        return tf_obj

    def from_tf_object(self, tf_obj: TF, update_metadata: bool = True) -> None:
        """Populate datasets from a :class:`TF` object.

        The period and every accepted estimate present in ``tf_obj`` are
        written. Existing datasets of the same name are replaced, so the
        stored period and estimates stay consistent.

        Parameters
        ----------
        tf_obj : TF
            Transfer function object containing estimates and metadata.
        update_metadata : bool, default True
            If ``True`` write transfer function metadata to HDF5.

        Raises
        ------
        ValueError
            If ``tf_obj`` is not a ``TF`` instance.

        Examples
        --------
        >>> tf_group.from_tf_object(tf_obj)  # doctest: +SKIP
        """
        if not isinstance(tf_obj, TF):
            msg = f"Input must be a TF object not {type(tf_obj)}"
            self.logger.error(msg)
            raise ValueError(msg)
        self.period = tf_obj.period
        if update_metadata:
            self.metadata.update(tf_obj.station_metadata.transfer_function)
            self.write_metadata()

        # if transfer function is available then impedance and tipper are
        # redundant.
        if tf_obj.has_transfer_function():
            accepted_estimates = self._accepted_estimates[0:4]
        else:
            accepted_estimates = self._accepted_estimates

        for estimate_name in accepted_estimates:
            try:
                estimate = getattr(tf_obj, estimate_name)
                if estimate is None:
                    self.logger.debug(f"Did not find {estimate_name} in TF. Skipping")
                    continue
                if estimate_name in self.hdf5_group:
                    # add_statistical_estimate would return the old dataset
                    # unchanged, so drop it to write the new values.
                    del self.hdf5_group[estimate_name]
                _ = self.add_statistical_estimate(estimate_name, estimate)
            except AttributeError:
                self.logger.debug(f"Did not find {estimate_name} in TF. Skipping")
661 self.logger.debug(f"Did not find {estimate_name} in TF. Skipping")