Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mth5 \ mth5 \ groups \ estimate_dataset.py: 51%
87 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-27 20:09 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-27 20:09 -0800
1# -*- coding: utf-8 -*-
2"""
3Created on Thu Mar 10 09:02:16 2022
5@author: jpeacock
6"""
8# =============================================================================
9# Imports
10# =============================================================================
11from __future__ import annotations
13import weakref
14from typing import Any
16import h5py
17import numpy as np
18import xarray as xr
19from loguru import logger
20from mt_metadata.transfer_functions.tf.statistical_estimate import StatisticalEstimate
21from mt_metadata.utils.validators import validate_attribute
23from mth5.helpers import add_attributes_to_metadata_class_pydantic, to_numpy_type
24from mth5.utils.exceptions import MTH5Error
27# =============================================================================
30class EstimateDataset:
31 """
32 Container for statistical estimates of transfer functions.
34 This class holds multi-dimensional statistical estimates for transfer
35 functions with full metadata management. Estimates are stored as HDF5
36 datasets with dimensions for period, output channels, and input channels.
38 Parameters
39 ----------
40 dataset : h5py.Dataset
41 HDF5 dataset containing the statistical estimate data.
42 dataset_metadata : mt_metadata.transfer_functions.tf.StatisticalEstimate, optional
43 Metadata object for the estimate. If provided and write_metadata is True,
44 the metadata will be written to the HDF5 attributes. Defaults to None.
45 write_metadata : bool, optional
46 If True, write metadata to the HDF5 dataset attributes. Defaults to True.
47 **kwargs : Any
48 Additional keyword arguments (reserved for future use).
50 Attributes
51 ----------
52 hdf5_dataset : h5py.Dataset
53 Weak reference to the HDF5 dataset.
54 metadata : StatisticalEstimate
55 Metadata container for the estimate.
56 logger : loguru.logger
57 Logger instance for reporting messages.
59 Raises
60 ------
61 MTH5Error
62 If dataset_metadata is provided but is not of type StatisticalEstimate
63 or a compatible metadata class.
64 TypeError
65 If input data cannot be converted to numpy array or has wrong dtype/shape.
67 Notes
68 -----
69 The estimate data is stored in 3D form with shape:
70 (n_periods, n_output_channels, n_input_channels)
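72 On initialization, existing HDF5 attributes are loaded into the metadata object
73 and written back when requested or missing; later changes to ``metadata`` reach the file only through ``write_metadata()`` (which ``from_xarray()`` calls).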
75 Examples
76 --------
77 Create an estimate dataset from an HDF5 group:
79 >>> import h5py
80 >>> import numpy as np
81 >>> from mt_metadata.transfer_functions.tf.statistical_estimate import StatisticalEstimate
82 >>> # Create HDF5 file with estimate dataset
83 >>> with h5py.File('estimate.h5', 'w') as f:
84 ... # Create dataset with shape (10 periods, 2 outputs, 2 inputs)
85 ... data = np.random.rand(10, 2, 2)
86 ... dset = f.create_dataset('estimate', data=data)
87 ... # Create EstimateDataset
88 ... est = EstimateDataset(dset, write_metadata=True)
90 Convert estimate to xarray and back:
92 >>> periods = np.logspace(-3, 3, 10) # 10 periods from 1e-3 to 1e3 s
93 >>> xr_data = est.to_xarray(periods)
94 >>> # Relabel the channel coordinates; keep the 'output' and 'input' dimension names, from_xarray() reads them
95 >>> new_xr = xr_data.assign_coords(output=['ex', 'ey'], input=['hx', 'hy'])
96 >>> est.from_xarray(new_xr) # Load modified data back
98 Access estimate data in different formats:
100 >>> # Get numpy array
101 >>> np_data = est.to_numpy()
102 >>> print(np_data.shape) # (10, 2, 2)
103 >>> # Get xarray with proper coordinates
104 >>> xr_data = est.to_xarray(periods)
105 >>> print(xr_data.dims) # ('period', 'output', 'input')
107 """
def __init__(
    self,
    dataset: h5py.Dataset,
    dataset_metadata: StatisticalEstimate | None = None,
    write_metadata: bool = True,
    **kwargs: Any,
) -> None:
    """
    Bind an HDF5 dataset to a ``StatisticalEstimate`` metadata object.

    Parameters
    ----------
    dataset : h5py.Dataset
        HDF5 dataset holding the estimate data.
    dataset_metadata : StatisticalEstimate | None, optional
        Metadata used to update the internal metadata object.
        Defaults to None.
    write_metadata : bool, optional
        When ``dataset_metadata`` is given, write the updated metadata to
        the HDF5 attributes. Defaults to True.
    **kwargs : Any
        Accepted for signature compatibility; not used by this method.

    Raises
    ------
    MTH5Error
        If the internal metadata object is not an instance of
        ``type(dataset_metadata)``.

    Notes
    -----
    Regardless of ``write_metadata``, the metadata is written once when the
    dataset attributes do not yet contain an ``mth5_type`` key.

    Examples
    --------
    >>> import h5py
    >>> import numpy as np
    >>> with h5py.File('estimate.h5', 'w') as f:
    ...     dset = f.create_dataset('estimate', data=np.random.rand(5, 2, 2))
    ...     est = EstimateDataset(dset)

    """
    # NOTE(review): when ``dataset`` is not an h5py.Dataset neither
    # ``hdf5_dataset`` nor ``logger`` is set, and the first use of
    # ``self.hdf5_dataset`` below raises AttributeError.
    if isinstance(dataset, h5py.Dataset):
        # ``weakref.ref(dataset)()`` dereferences immediately, so what is
        # stored here is the dataset object itself.
        self.hdf5_dataset = weakref.ref(dataset)()
        self.logger = logger

    # Build the metadata container and stamp it with the MTH5 bookkeeping
    # attributes (HDF5 reference and type name).
    self.metadata = add_attributes_to_metadata_class_pydantic(StatisticalEstimate)
    self.metadata.hdf5_reference = self.hdf5_dataset.ref
    self.metadata.mth5_type = validate_attribute(self._class_name)

    # If the dataset already carries attributes (an ``mth5_type`` key is
    # present), read them in without writing anything back.
    if "mth5_type" in self.hdf5_dataset.attrs:
        self.metadata.from_dict(
            {self.hdf5_dataset.attrs["mth5_type"]: dict(self.hdf5_dataset.attrs)}
        )

    # If metadata is input, make sure that it is the same class type and
    # write it to the HDF5 dataset.
    if dataset_metadata is not None:
        if not isinstance(self.metadata, type(dataset_metadata)):
            # Both halves must be f-strings, otherwise the offending type is
            # reported as the literal text "{type(dataset_metadata)}".
            msg = (
                f"metadata must be type metadata.{self._class_name} not "
                f"{type(dataset_metadata)}"
            )
            self.logger.error(msg)
            raise MTH5Error(msg)

        # Update rather than replace so the extra MTH5 attributes set above
        # stay on the metadata object.
        self.metadata.update(dataset_metadata)

        # Write out metadata to make sure that it is in the file.
        if write_metadata:
            self.write_metadata()

    # If the attrs don't have the proper metadata keys yet, write them.
    if "mth5_type" not in self.hdf5_dataset.attrs:
        self.write_metadata()
188 def __str__(self) -> str:
189 """
190 Return string representation of the estimate as JSON.
192 Returns
193 -------
194 str
195 JSON representation of the estimate metadata.
197 Examples
198 --------
199 >>> est_str = str(est)
200 >>> print(est_str[:50]) # Print first 50 characters
201 {"estimate": {"name": "estimate", ...
203 """
204 return self.metadata.to_json()
206 def __repr__(self) -> str:
207 """
208 Return official string representation of the estimate.
210 Returns
211 -------
212 str
213 JSON representation of the estimate metadata.
215 Examples
216 --------
217 >>> repr(est) == str(est)
218 True
220 """
221 return self.__str__()
223 @property
224 def _class_name(self) -> str:
225 """
226 Extract the class name without 'Dataset' suffix.
228 Returns
229 -------
230 str
231 Class name with 'Dataset' suffix removed.
233 Examples
234 --------
235 >>> est._class_name
236 'Estimate'
238 """
239 return self.__class__.__name__.split("Dataset")[0]
def read_metadata(self) -> None:
    """
    Load the HDF5 dataset attributes into the metadata object.

    The attributes are copied into a plain dict, passed through
    ``read_attrs_to_dict`` together with the current metadata object, and
    the result is loaded with ``self.metadata.from_dict`` under the
    ``_class_name`` key. When the helper returns nothing, a debug message is
    logged and the metadata is left untouched.

    Returns
    -------
    None

    Notes
    -----
    ``_has_read_metadata`` is set to True only after a successful load.

    Examples
    --------
    Reload metadata from HDF5 after the attributes were changed on file:

    >>> est.read_metadata()
    >>> print(est.metadata.name)

    """
    # ``read_attrs_to_dict`` is not among the names this module imports at
    # file level, so the original call raised NameError. Import it locally.
    # NOTE(review): assumed to live in ``mth5.helpers`` next to
    # ``to_numpy_type`` - confirm.
    from mth5.helpers import read_attrs_to_dict

    meta_dict = read_attrs_to_dict(dict(self.hdf5_dataset.attrs), self.metadata)
    # Defensive check: skip if meta_dict is empty
    if not meta_dict:
        self.logger.debug(
            f"No metadata found for {self._class_name}, skipping from_dict."
        )
        return
    self.metadata.from_dict({self._class_name: meta_dict})
    self._has_read_metadata = True
def write_metadata(self) -> None:
    """
    Copy every metadata field onto the HDF5 dataset as an attribute.

    The metadata object is dumped with ``to_dict()``, the section stored
    under the lower-cased metadata class name is selected, and each value is
    passed through ``to_numpy_type`` before being handed to
    ``hdf5_dataset.attrs.create``.

    Returns
    -------
    None

    Examples
    --------
    Persist a metadata change:

    >>> est.metadata.name = "Updated Estimate"
    >>> est.write_metadata()

    """
    full_dict = self.metadata.to_dict()
    section_key = self.metadata._class_name.lower()
    attributes = full_dict[section_key]
    for attr_name, raw_value in attributes.items():
        converted = to_numpy_type(raw_value)
        self.hdf5_dataset.attrs.create(attr_name, converted)
def replace_dataset(self, new_data_array: np.ndarray) -> None:
    """
    Replace the entire dataset content with new data.

    Input that is not already a numpy array is converted with ``np.array``.
    When the new shape differs from the stored shape the HDF5 dataset is
    resized first, then all data is overwritten.

    Parameters
    ----------
    new_data_array : np.ndarray
        New estimate data. Anything ``np.array`` accepts is allowed.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If the input cannot be converted to a numpy array. The original
        conversion error is attached as ``__cause__``.

    Notes
    -----
    NOTE(review): ``resize`` is delegated to the dataset; h5py documents
    that only chunked datasets can be resized - confirm how the dataset is
    created by the caller.

    Examples
    --------
    >>> import numpy as np
    >>> est.replace_dataset(np.random.rand(10, 2, 2))
    >>> est.to_numpy().shape
    (10, 2, 2)

    """
    if not isinstance(new_data_array, np.ndarray):
        try:
            new_data_array = np.array(new_data_array)
        except (ValueError, TypeError) as error:
            msg = f"{error} Input must be a numpy array not {type(new_data_array)}"
            self.logger.exception(msg)
            # Chain explicitly so the conversion failure stays visible as
            # the direct cause of the TypeError.
            raise TypeError(msg) from error
    if new_data_array.shape != self.hdf5_dataset.shape:
        self.hdf5_dataset.resize(new_data_array.shape)
    self.hdf5_dataset[...] = new_data_array
def to_xarray(self, period: np.ndarray | list) -> xr.DataArray:
    """
    Build an xarray DataArray from the stored estimate.

    The full dataset is read into memory and labelled with the dimensions
    ``period``, ``output`` and ``input``. The ``output`` and ``input``
    coordinates come from ``metadata.output_channels`` and
    ``metadata.input_channels``; the metadata dictionary
    (``to_dict(single=True)``) is attached as attributes.

    Parameters
    ----------
    period : np.ndarray | list
        Values for the ``period`` coordinate.

    Returns
    -------
    xr.DataArray
        DataArray named after ``metadata.name`` with dimensions
        (period, output, input).

    Examples
    --------
    >>> import numpy as np
    >>> periods = np.logspace(-2, 3, 10)
    >>> xr_data = est.to_xarray(periods)
    >>> print(xr_data.dims)
    ('period', 'output', 'input')

    """
    values = self.hdf5_dataset[()]
    dimension_names = ["period", "output", "input"]
    estimate_name = self.metadata.name
    coordinates = [
        ("period", period),
        ("output", self.metadata.output_channels),
        ("input", self.metadata.input_channels),
    ]
    attributes = self.metadata.to_dict(single=True)
    return xr.DataArray(
        data=values,
        dims=dimension_names,
        name=estimate_name,
        coords=coordinates,
        attrs=attributes,
    )
def to_numpy(self) -> np.ndarray:
    """
    Read the whole dataset and return it.

    Returns
    -------
    np.ndarray
        Result of indexing the HDF5 dataset with ``[()]``, i.e. the complete
        content loaded into memory.

    Examples
    --------
    >>> data = est.to_numpy()
    >>> subset = data[:5, 0, 1]  # first 5 periods, output 0, input 1

    """
    contents = self.hdf5_dataset[()]
    return contents
def from_numpy(self, new_estimate: np.ndarray) -> None:
    """
    Load estimate data from a numpy array.

    Input that is not already a numpy array is converted with ``np.array``.
    The array dtype must equal the dtype of the HDF5 dataset. When the shape
    differs the dataset is resized before the data is written.

    Parameters
    ----------
    new_estimate : np.ndarray
        Estimate data to store. Anything ``np.array`` accepts is allowed.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If the input cannot be converted to a numpy array (the conversion
        error is attached as ``__cause__``), or if its dtype differs from
        the dtype of the HDF5 dataset.

    Notes
    -----
    No dtype conversion is attempted; a mismatch is an error and leaves the
    dataset unchanged.

    Examples
    --------
    >>> import numpy as np
    >>> est.from_numpy(np.random.rand(5, 2, 2))
    >>> est.to_numpy().shape
    (5, 2, 2)

    """
    if not isinstance(new_estimate, np.ndarray):
        try:
            new_estimate = np.array(new_estimate)
        except (ValueError, TypeError) as error:
            msg = f"{error} Input must be a numpy array not {type(new_estimate)}"
            self.logger.exception(msg)
            raise TypeError(msg) from error
    if new_estimate.dtype != self.hdf5_dataset.dtype:
        # Report the dtype the dataset requires first and the rejected input
        # dtype second; the original message had the two swapped.
        msg = (
            f"Input array must be type {self.hdf5_dataset.dtype} "
            f"not {new_estimate.dtype}"
        )
        self.logger.error(msg)
        raise TypeError(msg)
    if new_estimate.shape != self.hdf5_dataset.shape:
        self.hdf5_dataset.resize(new_estimate.shape)
    self.hdf5_dataset[...] = new_estimate
def from_xarray(self, data: xr.DataArray) -> None:
    """
    Load estimate data and channel metadata from an xarray DataArray.

    The ``output`` and ``input`` coordinates, the array name and the dtype
    name are copied into the metadata object, the metadata is written to the
    HDF5 attributes, and finally the array values are stored through
    ``from_numpy``.

    Parameters
    ----------
    data : xr.DataArray
        DataArray holding the estimate. It must provide ``output`` and
        ``input`` coordinates.

    Returns
    -------
    None

    Notes
    -----
    Metadata is written before the data is stored, so an error raised by
    ``from_numpy`` leaves the already-written metadata in place.

    Examples
    --------
    >>> xr_data = est.to_xarray(periods)
    >>> est.from_xarray(xr_data * 2)

    """
    metadata = self.metadata

    output_names = data.coords["output"].values.tolist()
    metadata.output_channels = output_names
    input_names = data.coords["input"].values.tolist()
    metadata.input_channels = input_names
    metadata.name = data.name
    metadata.data_type = data.dtype.name

    self.write_metadata()

    values = data.to_numpy()
    self.from_numpy(values)