Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mth5\mth5\tables\mth5_table.py: 92%
192 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-27 20:09 -0800
1# -*- coding: utf-8 -*-
2"""
3MTH5 table utilities.
5This module provides the `MTH5Table` base class which wraps an HDF5 dataset
6and offers convenience methods for row management, locating entries, and
7exporting to `pandas.DataFrame`.
9Notes
10-----
11- Designed as a thin layer on top of NumPy/HDF5; for complex querying, prefer
12 converting to a DataFrame via `to_dataframe()`.
13- Datatypes are validated and kept consistent with the underlying dataset.
15"""
16from __future__ import annotations
18# =============================================================================
19# Imports
20# =============================================================================
21import weakref
22from typing import Any, cast, Literal
24import h5py
25import numpy as np
26import pandas as pd
27from loguru import logger
29from mth5.utils.exceptions import MTH5TableError
32# =============================================================================
33# MTH5 Table Class
34# =============================================================================
class MTH5Table:
    """
    Base wrapper around an HDF5 dataset representing a typed table.

    Provides simple NumPy-based operations including row insertion/removal,
    basic locating utilities, and conversion to `pandas.DataFrame`.

    Parameters
    ----------
    hdf5_dataset : h5py.Dataset
        The HDF5 dataset that stores the table.
    default_dtype : numpy.dtype
        The dtype schema for the table entries. If the dataset was stored
        with a different dtype it is rewritten with this one (or cleared if
        its contents cannot be copied over).

    Raises
    ------
    TypeError
        If `default_dtype` is not a `numpy.dtype`.
    MTH5TableError
        If `hdf5_dataset` is not an instance of `h5py.Dataset`.

    Notes
    -----
    The class defines ``__eq__`` without ``__hash__``, so instances are not
    hashable.

    Examples
    --------
    Create a simple table and add a row::

        >>> import h5py, numpy as np
        >>> f = h5py.File('example.h5', 'w')
        >>> dtype = np.dtype([('name', 'S16'), ('value', 'f8')])
        >>> ds = f.create_dataset('table', (1,), maxshape=(None,), dtype=dtype)
        >>> from mth5.tables.mth5_table import MTH5Table
        >>> t = MTH5Table(ds, dtype)
        >>> row = np.array([('alpha'.encode('utf-8'), 1.23)], dtype=dtype)
        >>> t.add_row(row)  # the single all-zero placeholder row is reused
        0
        >>> df = t.to_dataframe()

    """

    # Columns holding time strings; `locate` parses these with
    # numpy.datetime64 so ordering comparisons are chronological.
    _TIME_COLUMNS = ("start", "end", "start_date", "end_date")

    def __init__(self, hdf5_dataset: h5py.Dataset, default_dtype: np.dtype) -> None:
        self.logger = logger
        self._default_dtype = self._validate_dtype(default_dtype)

        if not isinstance(hdf5_dataset, h5py.Dataset):
            msg = f"Input must be a h5py.Dataset not {type(hdf5_dataset)}"
            self.logger.error(msg)
            raise MTH5TableError(msg)

        # The table is meaningless without its dataset, so hold it directly.
        self.array: h5py.Dataset = hdf5_dataset
        # Bring a dataset stored with another schema in line with the
        # requested dtype (update_dtype clears it if that is impossible).
        if self.array.dtype != self._default_dtype:
            self.update_dtype(self._default_dtype)

    def __str__(self) -> str:
        """
        Return a string representation of the table contents.

        Returns
        -------
        str
            The string form of `to_dataframe()`, or an empty string if the
            dataset has no elements.
        """
        if getattr(self.array, "size", 0) > 0:
            return str(self.to_dataframe())
        return ""

    def __repr__(self) -> str:
        return self.__str__()

    def __eq__(self, other: MTH5Table | h5py.Dataset | object) -> bool:
        """
        Compare the underlying datasets.

        Delegates to ``==`` between the h5py objects.

        Raises
        ------
        TypeError
            If `other` is neither an `MTH5Table` nor an `h5py.Dataset`.
        """
        if isinstance(other, MTH5Table):
            return self.array == other.array
        if isinstance(other, h5py.Dataset):
            return self.array == other
        msg = f"Cannot compare type={type(other)}"
        self.logger.error(msg)
        raise TypeError(msg)

    def __ne__(self, other: MTH5Table | h5py.Dataset | object) -> bool:
        return not self.__eq__(other)

    def __len__(self) -> int:
        return self.nrows

    @property
    def hdf5_reference(self) -> object:
        """HDF5 object reference of the dataset, or None if unavailable."""
        return getattr(self.array, "ref", None)

    @property
    def dtype(self) -> np.dtype:
        """The table's dtype schema."""
        return self._default_dtype

    @dtype.setter
    def dtype(self, value: np.dtype) -> None:
        """
        Set the table dtype, updating the underlying dataset if it differs.

        Parameters
        ----------
        value : numpy.dtype
            New dtype to apply. Must have the same field names.

        Raises
        ------
        TypeError
            If `value` is not a `numpy.dtype` or its field names differ from
            the current ones.
        """
        if not isinstance(value, np.dtype):
            raise TypeError(f"Input dtype must be np.dtype not type {type(value)}")

        if value != self._default_dtype:
            self.update_dtype(value)

    def _validate_dtype(self, value: np.dtype) -> np.dtype:
        """
        Validate that `value` is a `numpy.dtype`.

        Parameters
        ----------
        value : numpy.dtype
            Dtype to validate.

        Returns
        -------
        numpy.dtype
            The validated dtype.

        Raises
        ------
        TypeError
            If `value` is not a `numpy.dtype`.
        """
        if not isinstance(value, np.dtype):
            msg = f"Input dtype must be np.dtype not type {type(value)}"
            self.logger.error(msg)
            raise TypeError(msg)
        return value

    def _validate_dtype_names(self, value: np.dtype) -> np.dtype:
        """
        Validate that `value` has the same field names as the table dtype.

        Raises
        ------
        TypeError
            If the field names differ.
        """
        if self.dtype.names != value.names:
            msg = f"New dtype must have the same names: {self.dtype.names}"
            self.logger.error(msg)
            raise TypeError(msg)
        return value

    def check_dtypes(self, other_dtype: np.dtype) -> bool:
        """
        Check that a dtype matches the table's dtype (including field names).

        Parameters
        ----------
        other_dtype : numpy.dtype
            The dtype to compare against the table's dtype.

        Returns
        -------
        bool
            True if the dtypes match; otherwise False.

        Raises
        ------
        TypeError
            If `other_dtype` is not a `numpy.dtype`.
        """
        other_dtype = self._validate_dtype(other_dtype)
        # A name mismatch is an ordinary "no" answer here, not an error.
        if other_dtype.names != self.dtype.names:
            return False
        return self.dtype == other_dtype

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the underlying dataset."""
        return self.array.shape

    @property
    def nrows(self) -> int:
        """Number of rows (length of the first axis) in the dataset."""
        return self.array.shape[0]

    def _coerce_locate_value(self, column: str, value: Any) -> Any:
        """
        Convert a comparison value into the form used for `column` in `locate`.

        Time columns get ``numpy.datetime64``. Plain ``str`` values become
        ``numpy.bytes_`` to compare against byte-string columns. Anything
        else is returned unchanged.
        """
        if column in self._TIME_COLUMNS:
            return np.datetime64(value)
        if isinstance(value, str):
            return np.bytes_(value)
        return value

    def locate(
        self,
        column: str,
        value: Any,
        test: Literal["eq", "lt", "le", "gt", "ge", "be", "bt"] = "eq",
    ) -> np.ndarray:
        """
        Locate row indices where a column satisfies a comparison.

        Parameters
        ----------
        column : str
            Name of the column to test.
        value : Any
            Value to compare against. For string columns, a `str` is converted
            to a `numpy.bytes_`. For time columns (`start`, `end`,
            `start_date`, `end_date`), values are coerced to `numpy.datetime64`.
            For 'be'/'bt', a length-2 sequence ``(low, high)`` whose items are
            each coerced the same way.
        test : {'eq','lt','le','gt','ge','be','bt'}, default 'eq'
            Type of comparison to perform.
            - 'eq': equals
            - 'lt': less than
            - 'le': less than or equal to
            - 'gt': greater than
            - 'ge': greater than or equal to
            - 'be': strictly between
            - 'bt': alias for 'be'

        Returns
        -------
        numpy.ndarray
            Array of matching row indices.

        Raises
        ------
        ValueError
            If `test` is not recognised, or if `test` is 'be'/'bt' and
            `value` is not a length-2 list, tuple or array.

        Examples
        --------
        Find rows with value greater than 10::

            >>> idx = t.locate('value', 10, test='gt')
        """
        comparisons = {
            "eq": lambda column_values, target: column_values == target,
            "lt": lambda column_values, target: column_values < target,
            "le": lambda column_values, target: column_values <= target,
            "gt": lambda column_values, target: column_values > target,
            "ge": lambda column_values, target: column_values >= target,
        }
        if test == "bt":
            test = "be"
        if test != "be" and test not in comparisons:
            raise ValueError(f"Test {test} not understood")

        column_values = self.array[column]
        if column in self._TIME_COLUMNS:
            # Times are stored as strings; parse them so ordering tests are
            # chronological rather than byte-wise.
            column_values = column_values.astype(np.datetime64)

        if test == "be":
            is_sequence = isinstance(value, (list, tuple)) or (
                isinstance(value, np.ndarray) and value.ndim > 0
            )
            if not is_sequence or len(value) != 2:
                msg = "If testing for between value must be an iterable of length 2."
                self.logger.error(msg)
                raise ValueError(msg)
            # Coerce each bound individually (a list cannot be fed to
            # np.datetime64 or np.bytes_ as a whole).
            low, high = (self._coerce_locate_value(column, bound) for bound in value)
            mask = (column_values > low) & (column_values < high)
        else:
            target = self._coerce_locate_value(column, value)
            mask = comparisons[test](column_values, target)
        return np.where(mask)[0]

    def _first_row_is_null(self) -> bool:
        """
        Return True if row 0 equals an all-zero row of the table dtype.

        Fields whose name contains "reference" are ignored.

        Raises
        ------
        TypeError
            If the table dtype has no named fields.
        """
        if self.dtype.names is None:
            raise TypeError("Table dtype must have named fields.")
        null_row = np.zeros(1, dtype=self.dtype)[0]
        first_row = self.array[0]  # read the row once, not once per field
        for name in self.dtype.names:
            if "reference" in name:
                continue
            # array_equal also handles sub-array fields, where a plain `!=`
            # yields an array with an ambiguous truth value.
            if not np.array_equal(first_row[name], null_row[name]):
                return False
        return True

    def _next_row_index(self) -> int:
        """
        Return the index for an appended row, growing the dataset if needed.

        A table holding a single all-null placeholder row reuses index 0.
        Otherwise the dataset is resized by one row along the first axis.
        """
        if self.nrows == 1 and self._first_row_is_null():
            return 0
        index = self.nrows
        self.array.resize((self.nrows + 1, *self.shape[1:]))
        return index

    def add_row(self, row: np.ndarray, index: int | None = None) -> int:
        """
        Add a row to the table.

        Parameters
        ----------
        row : numpy.ndarray
            Row to insert. Must have the same dtype (or same field names,
            in which case it is cast with `astype`) as the table.
        index : int, optional
            Index at which to write the row. If None, the row is appended;
            a table whose only row is all-null has that row overwritten.

        Returns
        -------
        int
            Index of the inserted row.

        Raises
        ------
        TypeError
            If `row` is not a `numpy.ndarray`.
        ValueError
            If the row's field names differ from the table's.
        """
        if not isinstance(row, np.ndarray):
            msg = f"Input must be an numpy.ndarray not {type(row)}"
            self.logger.error(msg)
            raise TypeError(msg)

        if not self.check_dtypes(row.dtype):
            if row.dtype.names != self.dtype.names:
                msg = (
                    f"Data types are not equal. Input dtypes: "
                    f"{row.dtype} Table dtypes: {self.dtype}"
                )
                self.logger.error(msg)
                raise ValueError(msg)
            # Same columns, different field types/widths: cast to the schema.
            row = row.astype(self.dtype)

        if index is None:
            index = self._next_row_index()

        self.array[index] = row
        self.logger.debug(f"Added row as index {index} with values {row}")

        return index

    def update_row(self, entry: np.ndarray) -> int:
        """
        Update a row by locating its index and rewriting the entry.

        Parameters
        ----------
        entry : numpy.ndarray
            Entry to update, with the same dtype as the table. It must have
            an ``hdf5_reference`` field, which is used to find the row.

        Returns
        -------
        int
            Row index that was updated, or the new row index if not found.

        Notes
        -----
        Matching by `hdf5_reference` is not reliable; this uses `add_row`
        and will append if the original row cannot be located.
        """
        try:
            row_index = self.locate("hdf5_reference", entry["hdf5_reference"])[0]
            return self.add_row(entry, index=row_index)
        except IndexError:
            self.logger.debug("Could not find row, adding a new one")
            return self.add_row(entry)

    def _null_rows(
        self, shape: int | tuple[int, ...], dtype: np.dtype | None = None
    ) -> np.ndarray:
        """
        Build an array of null rows.

        Numeric fields are zero and byte-string fields empty. Object fields
        (e.g. HDF5 references) are set to ``None`` rather than the integer
        ``0`` that ``np.zeros`` would put there.

        Parameters
        ----------
        shape : int or tuple of int
            Shape of the array to build.
        dtype : numpy.dtype, optional
            Dtype to use; defaults to the table dtype.
        """
        dtype = self.dtype if dtype is None else dtype
        rows = np.zeros(shape, dtype=dtype)
        for name in dtype.names or ():
            if dtype[name].kind == "O":
                rows[name] = None
        return rows

    def remove_row(self, index: int) -> int:
        """
        Remove a row by replacing it with a null entry.

        Parameters
        ----------
        index : int
            Index of the row to remove.

        Returns
        -------
        int
            Index that was updated with a null row.

        Raises
        ------
        IndexError
            If the index is out of bounds for the current shape.

        Notes
        -----
        - There is no intrinsic index stored within the array; indexing is
          on-the-fly. Prefer using the HDF5 reference column for robust
          identification.
        - The row is overwritten in place with zeros/empty strings (and
          ``None`` for object fields); the table does not shrink.
        """
        # np.empty would leave numeric/string fields uninitialised, so build
        # an explicit null row instead.
        null_array = self._null_rows(1)
        try:
            return self.add_row(null_array, index=index)
        except IndexError as error:
            msg = f"Could not find index {index} in shape {self.shape}"
            self.logger.exception(msg)
            raise IndexError(f"{error}\n{msg}") from error

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert the table into a `pandas.DataFrame`.

        Returns
        -------
        pandas.DataFrame
            DataFrame with byte-string columns decoded as UTF-8.

        Raises
        ------
        TypeError
            If the table dtype has no named fields.

        Examples
        --------
        Convert and preview::

            >>> df = t.to_dataframe()
            >>> df.head()
        """
        if self.dtype.names is None:
            raise TypeError("Table dtype must have named fields.")

        df = pd.DataFrame(self.array[()])
        for name in self.dtype.names:
            if self.dtype[name].kind == "S":
                # Index by key: attribute access breaks for column names that
                # shadow DataFrame attributes (e.g. "size", "count").
                df[name] = df[name].str.decode("utf-8")

        return df

    def _recreate_dataset(self, **create_kwargs: Any) -> None:
        """
        Delete the backing dataset and create a replacement in its place.

        The replacement has the same name and parent, the same compression,
        shuffle and fletcher32 settings, and ``maxshape=(None,)``.
        `create_kwargs` are forwarded to ``create_dataset`` (e.g. ``shape``,
        ``data``, ``dtype``).

        Raises
        ------
        TypeError
            If the dataset's parent is not an h5py Group or File.
        """
        root = self.array.parent
        if not isinstance(root, (h5py.Group, h5py.File)):
            raise TypeError("Unexpected parent type; expected Group or File.")
        name = str(self.array.name).split("/")[-1]
        ds_options = {
            "compression": self.array.compression,
            "compression_opts": self.array.compression_opts,
            "shuffle": self.array.shuffle,
            "fletcher32": self.array.fletcher32,
        }

        del root[name]

        self.array = root.create_dataset(
            name, maxshape=(None,), **ds_options, **create_kwargs
        )

    def clear_table(self) -> None:
        """
        Reset the table by recreating the dataset with a single null row.

        Notes
        -----
        Deletes the current dataset and replaces it with a new dataset with
        the same compression/options and `dtype`, but shape `(1,)`.
        """
        self._recreate_dataset(shape=(1,), dtype=self.dtype)

    def update_dtype(self, new_dtype: np.dtype) -> None:
        """
        Update the dataset's dtype while preserving data and field names.

        Parameters
        ----------
        new_dtype : numpy.dtype
            New dtype to apply. Must have identical field names.

        Raises
        ------
        TypeError
            If `new_dtype` is not a `numpy.dtype` or its field names differ
            from the table's.

        Notes
        -----
        The data is copied field by field into a new array (a direct cast
        can fail with an unsafe-casting error), then the dataset is
        re-created with the new dtype and the same dataset options. Fields of
        `new_dtype` that are absent from the stored data are left null.
        If the copy or re-creation fails (for example a file written with an
        older schema), the table is cleared with `clear_table` instead.
        """
        # Validate outside the best-effort block: a bad request must raise,
        # not silently wipe the table.
        new_dtype = self._validate_dtype_names(self._validate_dtype(new_dtype))

        try:
            stored = self.array[()]
            new_array = self._null_rows(self.array.shape, new_dtype)
            for key in stored.dtype.names:
                new_array[key] = stored[key]

            self._recreate_dataset(data=new_array, dtype=new_dtype)
            self._default_dtype = new_dtype
        except Exception as error:
            self.logger.info(
                "Could not update table dtype, likely an older file. Clearing table."
            )
            self.logger.debug(f"update_dtype failure: {error!r}")
            self.clear_table()