Coverage for C:\src\imod-python\imod\mf6\regridding_utils.py: 95%
56 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 15:36 +0100
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 15:36 +0100
1import abc
2from enum import Enum
3from typing import Any, Optional, Tuple, Union
5import xarray as xr
6import xugrid as xu
7from xugrid.regrid.regridder import BaseRegridder
9from imod.typing.grid import GridDataArray
class RegridderType(Enum):
    """
    Enumerator referring to regridder types in ``xugrid``.
    These can be used safely in scripts, remaining backwards compatible for
    when it is decided to rename regridders in ``xugrid``. For an explanation
    of what each regridder type does, we refer to the `xugrid documentation <https://deltares.github.io/xugrid/examples/regridder_overview.html>`_

    The ``value`` of each member is the corresponding ``xugrid`` regridder
    class itself (not an instance); it is instantiated later, once a source
    and target grid are known.
    """

    # Member values are xugrid regridder classes, referenced without calling them.
    CENTROIDLOCATOR = xu.CentroidLocatorRegridder
    BARYCENTRIC = xu.BarycentricInterpolator
    OVERLAP = xu.OverlapRegridder
    RELATIVEOVERLAP = xu.RelativeOverlapRegridder
class RegridderInstancesCollection:
    """
    This class stores any number of regridders that can regrid a single source
    grid to a single target grid. By storing the regridders, we make sure the
    regridders can be re-used for different arrays on the same grid. This is
    important because computing the regridding weights is a costly affair.

    Regridders are cached per ``(regridder class, method)`` pair, so asking
    twice for the same class and method returns the same instance.
    """

    def __init__(
        self,
        source_grid: Union[xr.DataArray, xu.UgridDataArray],
        target_grid: Union[xr.DataArray, xu.UgridDataArray],
    ) -> None:
        # Cache keyed on (regridder class, method). ``method`` is None for
        # regridders that were requested without a method.
        self.regridder_instances: dict[
            Tuple[type[BaseRegridder], Optional[str]], BaseRegridder
        ] = {}
        self._source_grid = source_grid
        self._target_grid = target_grid

    def __has_regridder(
        self, regridder_type: type[BaseRegridder], method: Optional[str] = None
    ) -> bool:
        """Return True when a regridder for this (class, method) pair is cached."""
        return (regridder_type, method) in self.regridder_instances

    def __get_existing_regridder(
        self, regridder_type: type[BaseRegridder], method: Optional[str]
    ) -> BaseRegridder:
        """
        Return the cached regridder for this (class, method) pair.

        Raises
        ------
        ValueError
            If no such regridder has been created yet.
        """
        if self.__has_regridder(regridder_type, method):
            return self.regridder_instances[(regridder_type, method)]
        raise ValueError("no existing regridder of type " + str(regridder_type))

    def __create_regridder(
        self, regridder_type: type[BaseRegridder], method: Optional[str]
    ) -> BaseRegridder:
        """Instantiate a regridder for the stored grids, cache it and return it."""
        # ``method`` is forwarded as a third positional argument only when it
        # was given; otherwise the class is called with (source, target) only.
        method_args = () if method is None else (method,)
        regridder = regridder_type(self._source_grid, self._target_grid, *method_args)
        self.regridder_instances[(regridder_type, method)] = regridder
        return regridder

    def __get_regridder_class(
        self, regridder_type: Union[RegridderType, type[BaseRegridder]]
    ) -> type[BaseRegridder]:
        """
        Resolve ``regridder_type`` to a regridder class.

        A RegridderType member resolves to its ``value``; an (ABC-based) class
        is returned as-is after checking it derives from BaseRegridder.

        Raises
        ------
        ValueError
            If ``regridder_type`` is an ABC-based class that does not derive
            from BaseRegridder, or is neither such a class nor a
            RegridderType member.
        """
        # BaseRegridder subclasses are created through abc.ABCMeta, so an
        # ABCMeta instance here means "a class was passed in".
        if isinstance(regridder_type, abc.ABCMeta):
            if not issubclass(regridder_type, BaseRegridder):
                raise ValueError(
                    "only derived types of BaseRegridder can be instantiated"
                )
            return regridder_type
        elif isinstance(regridder_type, RegridderType):
            return regridder_type.value

        raise ValueError("invalid type for regridder")

    def get_regridder(
        self,
        regridder_type: Union[RegridderType, type[BaseRegridder]],
        method: Optional[str] = None,
    ) -> BaseRegridder:
        """
        returns a regridder of the specified type and with the specified method.
        The desired type can be passed through the argument "regridder_type" as an enumerator or
        as a class.
        The following two are equivalent:
        instancesCollection.get_regridder(RegridderType.OVERLAP, "mean")
        instancesCollection.get_regridder(xu.OverlapRegridder, "mean")

        A regridder is created only the first time a given (type, method)
        combination is requested; later calls return the cached instance.

        Parameters
        ----------
        regridder_type: RegridderType or regridder class
            indicates the desired regridder type
        method: str or None
            indicates the method the regridder should apply

        Returns
        -------
        a regridder of the specified characteristics

        Raises
        ------
        ValueError
            If ``regridder_type`` is not a RegridderType member or a
            BaseRegridder subclass.
        """
        regridder_class = self.__get_regridder_class(regridder_type)

        if self.__has_regridder(regridder_class, method):
            return self.__get_existing_regridder(regridder_class, method)
        return self.__create_regridder(regridder_class, method)
def get_non_grid_data(package, grid_names: list[str]) -> dict[str, Any]:
    """
    This function copies the variables of a package's dataset that are not
    grids, such as options.

    Variables that have a ``time`` coordinate are copied as they are; all
    other variables are reduced to their underlying data via ``.values[()]``
    (for a 0-d numpy array this yields a scalar).

    parameters
    ----------
    package:
        object with a ``dataset`` attribute that supports ``keys()`` and
        item access; presumably an imod package wrapping an
        ``xarray.Dataset`` (verify against callers).
    grid_names: list of str
        the names of the variables of the dataset that are grids. Names that
        are not present in the dataset are ignored.

    returns
    -------
    dict mapping every non-grid variable name to its copied value, in the
    order in which the variables appear in the dataset.
    """
    # A set gives O(1) membership tests, instead of one list.remove() scan
    # per grid name.
    excluded = set(grid_names)
    result: dict[str, Any] = {}
    for name in package.dataset.keys():
        if name in excluded:
            continue
        variable = package.dataset[name]
        if "time" in variable.coords:
            # Time-dependent data is kept as-is, not reduced to values.
            result[name] = variable
        else:
            result[name] = variable.values[()]
    return result
def assign_coord_if_present(
    coordname: str, target_grid: GridDataArray, maybe_has_coords_attr: Any
) -> Any:
    """
    If ``maybe_has_coords_attr`` has a ``coords`` attribute and if
    ``coordname`` is a coordinate of ``target_grid``, copy that coordinate's
    value from ``target_grid`` onto ``maybe_has_coords_attr``.

    Parameters
    ----------
    coordname: str
        name of the coordinate to copy.
    target_grid: GridDataArray
        grid from which the coordinate value is taken.
    maybe_has_coords_attr: Any
        object to receive the coordinate; left untouched if it has no
        ``coords`` attribute.

    Returns
    -------
    The result of ``maybe_has_coords_attr.assign_coords(...)`` when both
    conditions hold, otherwise ``maybe_has_coords_attr`` itself.
    """
    if coordname in target_grid.coords and hasattr(maybe_has_coords_attr, "coords"):
        # ``.values[()]`` extracts the raw coordinate data, so only the value
        # is copied rather than the coordinate object of target_grid.
        maybe_has_coords_attr = maybe_has_coords_attr.assign_coords(
            {coordname: target_grid.coords[coordname].values[()]}
        )
    return maybe_has_coords_attr