Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mth5\mth5\tables\mth5_table.py: 92%

192 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-27 20:09 -0800

1# -*- coding: utf-8 -*- 

2""" 

3MTH5 table utilities. 

4 

5This module provides the `MTH5Table` base class which wraps an HDF5 dataset 

6and offers convenience methods for row management, locating entries, and 

7exporting to `pandas.DataFrame`. 

8 

9Notes 

10----- 

11- Designed as a thin layer on top of NumPy/HDF5; for complex querying, prefer 

12 converting to a DataFrame via `to_dataframe()`. 

13- Datatypes are validated and kept consistent with the underlying dataset. 

14 

15""" 

16from __future__ import annotations 

17 

18# ============================================================================= 

19# Imports 

20# ============================================================================= 

21import weakref 

22from typing import Any, cast, Literal 

23 

24import h5py 

25import numpy as np 

26import pandas as pd 

27from loguru import logger 

28 

29from mth5.utils.exceptions import MTH5TableError 

30 

31 

32# ============================================================================= 

33# MTH5 Table Class 

34# ============================================================================= 

35 

36 

class MTH5Table:
    """
    Base wrapper around an HDF5 dataset representing a typed table.

    Provides simple NumPy-based operations including row insertion/removal,
    basic locating utilities, and conversion to `pandas.DataFrame`.

    Parameters
    ----------
    hdf5_dataset : h5py.Dataset
        The HDF5 dataset that stores the table.
    default_dtype : numpy.dtype
        The default dtype schema for the table entries. If the dataset's
        dtype differs, the dataset is converted with `update_dtype`.

    Raises
    ------
    MTH5TableError
        If `hdf5_dataset` is not an instance of `h5py.Dataset`.
    TypeError
        If `default_dtype` is not a `numpy.dtype`.

    Examples
    --------
    Create a simple table and add a row::

        >>> import h5py, numpy as np
        >>> f = h5py.File('example.h5', 'w')
        >>> dtype = np.dtype([('name', 'S16'), ('value', 'f8')])
        >>> ds = f.create_dataset('table', (1,), maxshape=(None,), dtype=dtype)
        >>> from mth5.tables.mth5_table import MTH5Table
        >>> t = MTH5Table(ds, dtype)
        >>> row = np.array([('alpha'.encode('utf-8'), 1.23)], dtype=dtype)
        >>> t.add_row(row)
        0
        >>> df = t.to_dataframe()
    """

    def __init__(self, hdf5_dataset: h5py.Dataset, default_dtype: np.dtype) -> None:
        self.logger = logger
        self._default_dtype = self._validate_dtype(default_dtype)

        if not isinstance(hdf5_dataset, h5py.Dataset):
            msg = f"Input must be a h5py.Dataset not {type(hdf5_dataset)}"
            self.logger.error(msg)
            raise MTH5TableError(msg)

        # Every method reads and writes through self.array, so hold the
        # dataset directly.
        self.array: h5py.Dataset = hdf5_dataset

        # Bring a dataset with a different on-disk dtype in line with the
        # default schema.
        if self.array.dtype != self._default_dtype:
            self.update_dtype(self._default_dtype)

    def __str__(self) -> str:
        """
        Return a string representation of the table contents.

        Returns
        -------
        str
            The table rendered through `to_dataframe`, or an empty string if
            the underlying array has no elements.
        """
        if getattr(self.array, "size", 0) > 0:
            return str(self.to_dataframe())
        return ""

    def __repr__(self) -> str:
        return self.__str__()

    def __eq__(self, other: MTH5Table | h5py.Dataset | object) -> bool:
        """
        Compare the underlying dataset with another table or dataset.

        Raises
        ------
        TypeError
            If `other` is neither an `MTH5Table` nor an `h5py.Dataset`.
        """
        if isinstance(other, MTH5Table):
            return self.array == other.array
        if isinstance(other, h5py.Dataset):
            return self.array == other
        msg = f"Cannot compare type={type(other)}"
        self.logger.error(msg)
        raise TypeError(msg)

    def __ne__(self, other: MTH5Table | h5py.Dataset | object) -> bool:
        return not self.__eq__(other)

    def __len__(self) -> int:
        return self.array.shape[0]

    @property
    def hdf5_reference(self) -> object:
        """HDF5 object reference of the dataset, or None if unavailable."""
        return getattr(self.array, "ref", None)

    @property
    def dtype(self) -> np.dtype:
        """The table's dtype schema."""
        return self._default_dtype

    @dtype.setter
    def dtype(self, value: np.dtype) -> None:
        """
        Set the table dtype, updating the underlying dataset if it differs.

        Parameters
        ----------
        value : numpy.dtype
            New dtype to apply. It should have the same field names; see
            `update_dtype` for what happens when the conversion fails.

        Raises
        ------
        TypeError
            If `value` is not an instance of `numpy.dtype`.
        """
        if not isinstance(value, np.dtype):
            raise TypeError(f"Input dtype must be np.dtype not type {type(value)}")

        if value != self._default_dtype:
            self.update_dtype(value)

    def _validate_dtype(self, value: np.dtype) -> np.dtype:
        """
        Validate that `value` is a `numpy.dtype`.

        Parameters
        ----------
        value : numpy.dtype
            Dtype to validate.

        Returns
        -------
        numpy.dtype
            The validated dtype.

        Raises
        ------
        TypeError
            If `value` is not a `numpy.dtype`.
        """
        if not isinstance(value, np.dtype):
            msg = f"Input dtype must be np.dtype not type {type(value)}"
            self.logger.error(msg)
            raise TypeError(msg)
        return value

    def _validate_dtype_names(self, value: np.dtype) -> np.dtype:
        """
        Validate that `value` has the same field names as the table dtype.

        Raises
        ------
        TypeError
            If the field names differ.
        """
        if self.dtype.names != value.names:
            msg = f"New dtype must have the same names: {self.dtype.names}"
            self.logger.error(msg)
            raise TypeError(msg)
        return value

    def check_dtypes(self, other_dtype: np.dtype) -> bool:
        """
        Check that a dtype matches the table's dtype (including field names).

        Parameters
        ----------
        other_dtype : numpy.dtype
            The dtype to compare against the table's dtype.

        Returns
        -------
        bool
            True if the dtypes match; otherwise False.

        Raises
        ------
        TypeError
            If `other_dtype` is not a `numpy.dtype`.
        """
        other_dtype = self._validate_dtype(other_dtype)
        # A name mismatch is an ordinary "no" answer here, so check it
        # directly instead of going through the error-logging validator.
        if self.dtype.names != other_dtype.names:
            return False
        return self.dtype == other_dtype

    @property
    def shape(self) -> tuple[int, ...]:
        return self.array.shape

    @property
    def nrows(self) -> int:
        return self.array.shape[0]

    def locate(
        self,
        column: str,
        value: Any,
        test: Literal["eq", "lt", "le", "gt", "ge", "be", "bt"] = "eq",
    ) -> np.ndarray:
        """
        Locate row indices where a column satisfies a comparison.

        Parameters
        ----------
        column : str
            Name of the column to test.
        value : Any
            Value to compare against. For 'be'/'bt' this is a 2-element
            list, tuple or array ``(lower, upper)``. For string columns, `str`
            values are converted to `numpy.bytes_`. For time columns
            (`start`, `end`, `start_date`, `end_date`), values (each bound for
            'be'/'bt') are coerced to `numpy.datetime64`.
        test : {'eq','lt','le','gt','ge','be','bt'}, default 'eq'
            Type of comparison to perform.
            - 'eq': equals
            - 'lt': less than
            - 'le': less than or equal to
            - 'gt': greater than
            - 'ge': greater than or equal to
            - 'be': strictly between
            - 'bt': alias for 'be'

        Returns
        -------
        numpy.ndarray
            Array of matching row indices.

        Raises
        ------
        ValueError
            If `test` is 'be'/'bt' and `value` is not a 2-element list, tuple
            or array, or if `test` is not recognized.

        Examples
        --------
        Find rows with value greater than 10::

            >>> idx = t.locate('value', 10, test='gt')
        """
        is_time_column = column in ("start", "end", "start_date", "end_date")

        def coerce(item: Any) -> Any:
            # Time columns are compared as datetime64; other str inputs are
            # encoded to match byte-string columns.
            if is_time_column:
                return np.datetime64(item)
            if isinstance(item, str):
                return np.bytes_(item)
            return item

        if is_time_column:
            test_array = self.array[column].astype(np.datetime64)
        else:
            test_array = self.array[column]

        if test in ("be", "bt"):
            if not isinstance(value, (list, tuple, np.ndarray)) or np.size(value) != 2:
                msg = "If testing for between value must be an iterable of length 2."
                self.logger.error(msg)
                raise ValueError(msg)
            lower, upper = coerce(value[0]), coerce(value[1])
            mask = (test_array > lower) & (test_array < upper)
        elif test == "eq":
            mask = test_array == coerce(value)
        elif test == "lt":
            mask = test_array < coerce(value)
        elif test == "le":
            mask = test_array <= coerce(value)
        elif test == "gt":
            mask = test_array > coerce(value)
        elif test == "ge":
            mask = test_array >= coerce(value)
        else:
            raise ValueError(f"Test {test} not understood")
        return np.where(mask)[0]

    def _first_row_is_null(self) -> bool:
        """
        Return True if row 0 equals a zero-filled row.

        Fields whose name contains "reference" are skipped.

        Raises
        ------
        TypeError
            If the table dtype has no named fields.
        """
        if self.dtype.names is None:
            raise TypeError("Table dtype must have named fields.")
        null_row = np.zeros(1, dtype=self.dtype)[0]
        # Read row 0 once rather than one column at a time.
        first_row = self.array[0]
        for name in self.dtype.names:
            if "reference" in name:
                continue
            # np.any also handles fields that hold sub-arrays.
            if np.any(first_row[name] != null_row[name]):
                return False
        return True

    def add_row(self, row: np.ndarray, index: int | None = None) -> int:
        """
        Add a row to the table.

        Parameters
        ----------
        row : numpy.ndarray
            Row to insert. Must have the same dtype as the table, or the same
            field names (in which case it is cast to the table dtype).
        index : int, optional
            Index at which to write the row. If None, the row is appended. If
            the table holds a single all-zero row (reference fields ignored),
            that placeholder row is overwritten instead.

        Returns
        -------
        int
            Index of the written row.

        Raises
        ------
        TypeError
            If `row` is not a `numpy.ndarray`.
        ValueError
            If the row's field names differ from the table's.
        """
        if not isinstance(row, np.ndarray):
            msg = f"Input must be an numpy.ndarray not {type(row)}"
            self.logger.error(msg)
            raise TypeError(msg)

        if not self.check_dtypes(row.dtype):
            if row.dtype.names != self.dtype.names:
                msg = (
                    f"Data types are not equal. Input dtypes: "
                    f"{row.dtype} Table dtypes: {self.dtype}"
                )
                self.logger.error(msg)
                raise ValueError(msg)
            # Same fields, different field types/widths: cast to the schema.
            row = row.astype(self.dtype)

        if index is None:
            index = self.nrows
            if self.nrows == 1 and self._first_row_is_null():
                # Reuse the single all-zero placeholder row.
                index = 0
            else:
                self.array.resize((self.nrows + 1, *self.shape[1:]))

        self.array[index] = row
        self.logger.debug(f"Added row as index {index} with values {row}")

        return index

    def update_row(self, entry: np.ndarray) -> int:
        """
        Update a row by locating its index and rewriting the entry.

        Parameters
        ----------
        entry : numpy.ndarray
            Entry to update, with the same dtype as the table.

        Returns
        -------
        int
            Row index that was updated, or the new row index if not found.

        Notes
        -----
        Matching by `hdf5_reference` is not reliable; this uses `add_row`
        and will append if the original row cannot be located.
        """
        try:
            row_index = self.locate("hdf5_reference", entry["hdf5_reference"])[0]
            return self.add_row(entry, index=row_index)
        except IndexError:
            self.logger.debug("Could not find row, adding a new one")
            return self.add_row(entry)

    def _null_row(self) -> np.ndarray:
        """
        Build a single-row null entry for the table dtype.

        Returns
        -------
        numpy.ndarray
            Shape ``(1,)`` array in which numeric and string fields are zero.
            Object fields (such as HDF5 reference columns) are None rather
            than the integer 0 that `np.zeros` would put there.
        """
        null_row = np.zeros((1,), dtype=self.dtype)
        for name in self.dtype.names or ():
            if self.dtype[name].kind == "O":
                null_row[name] = None
        return null_row

    def remove_row(self, index: int) -> int:
        """
        Remove a row by overwriting it with a null entry.

        Parameters
        ----------
        index : int
            Index of the row to remove.

        Returns
        -------
        int
            Index that was overwritten with a null row.

        Raises
        ------
        IndexError
            If the index is out of bounds for the current shape.

        Notes
        -----
        - There is no intrinsic index stored within the array; indexing is
          on-the-fly. Prefer using the HDF5 reference column for robust
          identification.
        - The row is not deleted. It is overwritten in place with zeros for
          numeric and string fields and None for object fields, so the table
          keeps its length.
        """
        try:
            return self.add_row(self._null_row(), index=index)
        except IndexError as error:
            msg = f"Could not find index {index} in shape {self.shape}"
            self.logger.exception(msg)
            raise IndexError(f"{error}\n{msg}") from error

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert the table into a `pandas.DataFrame`.

        Returns
        -------
        pandas.DataFrame
            DataFrame in which byte-string (``'S'``) columns are decoded as UTF-8.

        Raises
        ------
        TypeError
            If the table dtype has no named fields.

        Examples
        --------
        Convert and preview::

            >>> df = t.to_dataframe()
            >>> df.head()
        """
        if self.dtype.names is None:
            raise TypeError("Table dtype must have named fields.")

        df = pd.DataFrame(self.array[()])
        for key in self.dtype.names:
            # Only byte strings need decoding; 'U' columns already hold str,
            # and .str.decode would turn them into NaN.
            if self.dtype[key].kind == "S":
                # Item access/assignment (not getattr/setattr) so column names
                # that shadow DataFrame attributes, e.g. "min", still work.
                df[key] = df[key].str.decode("utf-8")

        return df

    def _recreate_dataset(self, dtype: np.dtype, data: np.ndarray | None = None) -> None:
        """
        Delete the backing dataset and create a replacement under the same name.

        The replacement keeps the compression, compression_opts, shuffle and
        fletcher32 settings of the old dataset and uses ``maxshape=(None,)``.

        Parameters
        ----------
        dtype : numpy.dtype
            Dtype of the replacement dataset.
        data : numpy.ndarray, optional
            Initial contents. If None, the new dataset has shape ``(1,)``.

        Raises
        ------
        TypeError
            If the dataset's parent is not an `h5py.Group` or `h5py.File`.
        """
        root = self.array.parent
        if not isinstance(root, (h5py.Group, h5py.File)):
            raise TypeError("Unexpected parent type; expected Group or File.")
        name = str(self.array.name).split("/")[-1]
        ds_options = {
            "compression": self.array.compression,
            "compression_opts": self.array.compression_opts,
            "shuffle": self.array.shuffle,
            "fletcher32": self.array.fletcher32,
        }

        del root[name]

        shape = (1,) if data is None else None
        self.array = root.create_dataset(
            name, shape, maxshape=(None,), dtype=dtype, data=data, **ds_options
        )

    def clear_table(self) -> None:
        """
        Reset the table by recreating the dataset with a single null row.

        Notes
        -----
        Deletes the current dataset and replaces it with a new dataset that
        has the same compression/checksum options and `dtype`, but shape `(1,)`.
        """
        self._recreate_dataset(self.dtype)

    def update_dtype(self, new_dtype: np.dtype) -> None:
        """
        Update the dataset's dtype while preserving data and field names.

        Parameters
        ----------
        new_dtype : numpy.dtype
            New dtype to apply. Must have identical field names.

        Notes
        -----
        Data are copied field by field into a new array, because a direct
        cast can be rejected as unsafe. The dataset is then recreated with the
        new dtype and the same dataset options. If any step fails (including
        dtype validation), a warning is logged and the table is cleared with
        `clear_table`, which keeps the current dtype.
        """
        try:
            new_dtype = self._validate_dtype_names(self._validate_dtype(new_dtype))

            # Read once, then copy per field into the new schema.
            old_data = self.array[()]
            new_array = np.ones(self.array.shape, dtype=new_dtype)
            for key in self.array.dtype.names:
                new_array[key] = old_data[key]

            self._recreate_dataset(new_dtype, data=new_array)
            self._default_dtype = new_dtype
        except Exception:
            # Clearing discards data, so log at warning level.
            self.logger.warning(
                "Could not update table dtype, likely an older file. Clearing table."
            )
            self.clear_table()