Coverage for C:\src\imod-python\imod\mf6\pkgbase.py: 98%

48 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-08 13:27 +0200

import abc
import numbers
import pathlib
from typing import Any, Mapping

import numpy as np
import xarray as xr
import xugrid as xu

import imod
from imod.mf6.interfaces.ipackagebase import IPackageBase
from imod.typing.grid import GridDataArray, GridDataset, merge_with_dictionary

# Short names of the packages classified as transport packages.
TRANSPORT_PACKAGES = (
    "adv",
    "dsp",
    "ssm",
    "mst",
    "ist",
    "src",
)
# Short names of the packages classified as exchange packages.
EXCHANGE_PACKAGES = (
    "gwfgwf",
    "gwfgwt",
    "gwtgwt",
)


class PackageBase(IPackageBase, abc.ABC):
    """
    Base class storing a collection of xarray DataArrays or xugrid
    UgridDataArrays in a single dataset.

    The dataset is available through the :attr:`dataset` property. Use
    :meth:`to_netcdf` to write it to file (applying the package-level
    encoding) and :meth:`from_file` to load a package from netCDF or zarr.
    """

    # This method has been added to allow mock.patch to mock created objects
    # https://stackoverflow.com/questions/64737213/how-to-patch-the-new-method-of-a-class
    def __new__(cls, *_, **__):
        return super(PackageBase, cls).__new__(cls)

    def __init__(
        self, variables_to_merge: Mapping[str, GridDataArray | float | int | bool | str]
    ) -> None:
        # Merge variables, perform exact join to verify if coordinates values
        # are consistent amongst variables.
        self.__dataset = merge_with_dictionary(variables_to_merge, join="exact")

    @property
    def dataset(self) -> GridDataset:
        """The dataset holding all variables of this package."""
        return self.__dataset

    @dataset.setter
    def dataset(self, value: GridDataset) -> None:
        self.__dataset = value

    def __getitem__(self, key):
        """Return variable ``key`` from the underlying dataset."""
        return self.dataset[key]

    def __setitem__(self, key, value) -> None:
        """Set variable ``key`` on the underlying dataset."""
        self.dataset[key] = value

    def to_netcdf(self, *args, **kwargs: Any) -> None:
        """
        Write dataset contents to a netCDF file.

        Custom encoding rules can be provided on package level by overriding
        ``_netcdf_encoding`` in the package. An ``encoding`` keyword argument
        given by the caller is merged on top of the package encoding, so the
        caller's entries take precedence for the variables they mention
        (previously a caller-supplied ``encoding`` was silently discarded).

        Parameters
        ----------
        *args, **kwargs
            Forwarded to the ``to_netcdf`` method of the underlying dataset.
        """
        encoding = self._netcdf_encoding()
        user_encoding = kwargs.pop("encoding", None)
        if user_encoding:
            encoding = {**encoding, **user_encoding}
        kwargs["encoding"] = encoding
        self.dataset.to_netcdf(*args, **kwargs)

    def _netcdf_encoding(self):
        """
        Return the encoding used by :meth:`to_netcdf`.

        Override this to provide custom encoding rules. The default is no
        custom encoding (an empty dict).
        """
        return {}

    @classmethod
    def from_file(cls, path: str | pathlib.Path, **kwargs: Any):
        """
        Loads an imod mf6 package from a file (currently only netcdf and zarr are supported).
        Note that it is expected that this file was saved with imod.mf6.Package.dataset.to_netcdf(),
        as the checks upon package initialization are not done again!

        Scalar variables that hold NaN in the file are set to None on the
        loaded package.

        Parameters
        ----------
        path : str, pathlib.Path
            Path to the file. Files with a ``.zip`` or ``.zarr`` suffix are
            opened as zarr; anything else is opened with ``xarray.open_dataset()``.
        **kwargs : keyword arguments
            Arbitrary keyword arguments forwarded to ``xarray.open_dataset()``, or
            ``xarray.open_zarr()``.
            Refer to the examples.

        Returns
        -------
        package : imod.mf6.Package
            Returns a package with data loaded from file.

        Examples
        --------

        To load a package from a file, e.g. a River package:

        >>> river = imod.mf6.River.from_file("river.nc")

        For large datasets, you likely want to process it in chunks. You can
        forward keyword arguments to ``xarray.open_dataset()`` or
        ``xarray.open_zarr()``:

        >>> river = imod.mf6.River.from_file("river.nc", chunks={"time": 1})

        Refer to the xarray documentation for the possible keyword arguments.
        """
        path = pathlib.Path(path)
        if path.suffix in (".zip", ".zarr"):
            # TODO: seems like a bug? Remove str() call if fixed in xarray/zarr
            dataset = xr.open_zarr(str(path), **kwargs)
        else:
            dataset = xr.open_dataset(path, **kwargs)

        # Files carrying a UGRID topology are wrapped as UgridDataset and
        # converted back from the MDAL-compliant layout.
        if dataset.ugrid_roles.topology:
            dataset = xu.UgridDataset(dataset)
            dataset = imod.util.spatial.from_mdal_compliant_ugrid2d(dataset)

        # Replace scalar NaNs by None (presumably optional settings that were
        # None when the package was written; confirm against the writers).
        # Only 0-d variables can match, so check ``ndim`` before touching
        # ``.values``. Reading ``.values`` of a full grid would load (and for
        # dask-chunked data, compute) the whole array into memory. Keys are
        # collected first so the dataset is not modified while iterating it.
        nan_keys = []
        for key, value in dataset.items():
            if value.ndim != 0:
                continue
            scalar = value.values[()]
            if isinstance(scalar, numbers.Real) and np.isnan(scalar):
                nan_keys.append(key)
        for key in nan_keys:
            dataset[key] = None

        # Bypass __init__: the stored dataset is trusted and not re-validated.
        instance = cls.__new__(cls)
        instance.dataset = dataset
        return instance