Coverage for C:\src\imod-python\imod\typing\grid.py: 91%
193 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 11:25 +0200
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 11:25 +0200
1import pickle
2import textwrap
3from functools import wraps
4from typing import Callable, Mapping, Sequence
6import numpy as np
7import xarray as xr
8import xugrid as xu
9from fastcore.dispatch import typedispatch
11from imod.typing import GridDataArray, GridDataset, structured
12from imod.util.spatial import _polygonize
@typedispatch
def zeros_like(grid: xr.DataArray, *args, **kwargs):
    """
    Return a structured grid shaped like ``grid`` filled with zeros.

    All extra arguments are forwarded unchanged to ``xr.zeros_like``.
    """
    result = xr.zeros_like(grid, *args, **kwargs)
    return result


@typedispatch  # type: ignore[no-redef]
def zeros_like(grid: xu.UgridDataArray, *args, **kwargs):  # noqa: F811
    """
    Return an unstructured grid shaped like ``grid`` filled with zeros.

    All extra arguments are forwarded unchanged to ``xu.zeros_like``.
    """
    result = xu.zeros_like(grid, *args, **kwargs)
    return result
@typedispatch
def ones_like(grid: xr.DataArray, *args, **kwargs):
    """
    Return a structured grid shaped like ``grid`` filled with ones.

    All extra arguments are forwarded unchanged to ``xr.ones_like``.
    """
    result = xr.ones_like(grid, *args, **kwargs)
    return result


@typedispatch  # type: ignore[no-redef]
def ones_like(grid: xu.UgridDataArray, *args, **kwargs):  # noqa: F811
    """
    Return an unstructured grid shaped like ``grid`` filled with ones.

    All extra arguments are forwarded unchanged to ``xu.ones_like``.
    """
    result = xu.ones_like(grid, *args, **kwargs)
    return result
@typedispatch
def nan_like(grid: xr.DataArray, dtype=np.float32, *args, **kwargs):
    """
    Return a structured grid shaped like ``grid`` filled with NaN.

    ``dtype`` defaults to ``np.float32``. The remaining arguments are forwarded
    to ``xr.full_like``.

    NOTE(review): positional ``*args`` end up directly after ``grid`` in the
    ``full_like`` call and therefore collide with ``fill_value``. In practice
    only keyword arguments can be forwarded.
    """
    return xr.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)


@typedispatch  # type: ignore[no-redef]
def nan_like(grid: xu.UgridDataArray, dtype=np.float32, *args, **kwargs):  # noqa: F811
    """
    Return an unstructured grid shaped like ``grid`` filled with NaN.

    ``dtype`` defaults to ``np.float32``. The remaining arguments are forwarded
    to ``xu.full_like``. As with the structured variant, positional ``*args``
    would collide with ``fill_value``.
    """
    return xu.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)
@typedispatch
def is_unstructured(grid: xu.UgridDataArray | xu.UgridDataset) -> bool:
    """Return True: xugrid objects are unstructured grids."""
    return True


@typedispatch  # type: ignore[no-redef]
def is_unstructured(grid: xr.DataArray | xr.Dataset) -> bool:  # noqa: F811
    """Return False: plain xarray objects are structured grids."""
    return False


@typedispatch  # type: ignore[no-redef]
def is_unstructured(grid: object) -> bool:  # noqa: F811
    """
    Return False for any other object (scalars, numpy arrays, None, ...).

    Without this fallback, fastcore's typedispatch returns the first argument
    itself when no overload matches. A truthy scalar would then be reported as
    "unstructured", and a numpy array would raise on truth-value testing.
    This mirrors the ``object`` fallback of ``is_spatial_grid``.
    """
    return False
def _force_decreasing_y(structured_grid: xr.DataArray | xr.Dataset):
    """
    Return ``structured_grid`` with a monotonically decreasing y-coordinate.

    If the "y" index is monotonically increasing, the grid is flipped along y.
    If it is already decreasing, the grid is returned unchanged. Objects without
    a "y" index (e.g. non-spatial arrays resulting from a concat) are also
    returned unchanged.

    Raises
    ------
    RuntimeError
        If the y-coordinate is neither monotonically increasing nor
        decreasing.
    """
    y_index = structured_grid.indexes.get("y")
    if y_index is None:
        # Nothing to reorder. Indexing with ["y"] would raise a KeyError here.
        return structured_grid
    if y_index.is_monotonic_increasing:
        return structured_grid.isel(y=slice(None, None, -1))
    if not y_index.is_monotonic_decreasing:
        # xr.Dataset has no ``name`` attribute. Fall back to None so the
        # intended RuntimeError is raised instead of an AttributeError.
        name = getattr(structured_grid, "name", None)
        raise RuntimeError(f"Non-monotonous y-coordinates for grid: {name}.")
    return structured_grid
66def _get_first_item(objects: Sequence):
67 return next(iter(objects))
# Typedispatching doesn't work based on types of list elements, therefore resort to
# isinstance testing
def _type_dispatch_functions_on_grid_sequence(
    objects: Sequence[GridDataArray | GridDataset],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
) -> GridDataArray | GridDataset:
    """
    Apply a sequence-reducing function (merge, concat, ...) to a sequence of grids.

    The type of the first element selects the implementation:

    * xu.UgridDataArray / xu.UgridDataset -> ``unstructured_func``
    * xr.DataArray / xr.Dataset -> ``structured_func``, after which the result
      is passed through ``_force_decreasing_y``.

    Raises
    ------
    TypeError
        If the elements are not all instances of the first element's type, or
        if that type is neither an xarray nor a xugrid object.
    """
    head = _get_first_item(objects)
    reference_type = type(head)
    if any(not isinstance(item, reference_type) for item in objects):
        received_types = {type(item) for item in objects}
        raise TypeError(
            f"Only homogeneous sequences can be reduced, received sequence of {received_types}"
        )

    if isinstance(head, (xu.UgridDataArray, xu.UgridDataset)):
        return unstructured_func(objects, *args, **kwargs)
    if isinstance(head, (xr.DataArray, xr.Dataset)):
        reduced = structured_func(objects, *args, **kwargs)
        return _force_decreasing_y(reduced)

    raise TypeError(
        f"'{unstructured_func.__name__}' not supported for type {type(objects[0])}"
    )
# Typedispatching doesn't work based on types of dict elements, therefore resort
# to manual type testing
def _type_dispatch_functions_on_dict(
    dict_of_objects: Mapping[str, GridDataArray | float | bool | int],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
):
    """
    Typedispatch function on grid and scalar variables provided in dictionary.

    Types do not need to be homogeneous, as scalars and grids can be mixed.
    Mixing structured and unstructured grids is not allowed. The function can
    also run on a dictionary with purely scalars, in which case it calls
    ``structured_func`` (the xarray function).

    The dictionary is passed wrapped in a one-element list, so both functions
    receive ``[dict_of_objects]`` as their first argument.

    Returns
    -------
    An empty ``xr.Dataset`` if ``dict_of_objects`` is None; otherwise the
    result of ``unstructured_func`` or ``structured_func``.

    Raises
    ------
    TypeError
        If both a spatial xr.DataArray and a xu.UgridDataArray are present.
    """
    if dict_of_objects is None:
        return xr.Dataset()

    values = list(dict_of_objects.values())
    # isinstance (instead of exact type matching) so subclasses of
    # UgridDataArray are recognized too, in line with merge_unstructured_dataset.
    has_unstructured = any(isinstance(value, xu.UgridDataArray) for value in values)
    # A structured grid only counts if it is an xr.DataArray that is spatial.
    has_structured_grid = any(
        isinstance(value, xr.DataArray) and is_spatial_grid(value) for value in values
    )
    if has_structured_grid and has_unstructured:
        error_msg = textwrap.dedent(
            """
            Received both structured grid (xr.DataArray) and xu.UgridDataArray. This
            means structured grids as well as unstructured grids were provided.
            """
        )
        raise TypeError(error_msg)
    if has_unstructured:
        return unstructured_func([dict_of_objects], *args, **kwargs)

    return structured_func([dict_of_objects], *args, **kwargs)
def merge(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataset:
    """
    Merge a homogeneous sequence of structured or unstructured grids.

    Uses ``xu.merge`` for xugrid objects and ``xr.merge`` for xarray objects.
    Extra arguments are forwarded to the selected function.
    """
    unstructured_merge, structured_merge = xu.merge, xr.merge
    return _type_dispatch_functions_on_grid_sequence(
        objects, unstructured_merge, structured_merge, *args, **kwargs
    )
def merge_partitions(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """
    Merge a homogeneous sequence of grid partitions back into one grid.

    Uses ``xu.merge_partitions`` for xugrid objects and
    ``structured.merge_partitions`` for xarray objects. Extra arguments are
    forwarded to the selected function.
    """
    unstructured_merge = xu.merge_partitions
    structured_merge = structured.merge_partitions
    return _type_dispatch_functions_on_grid_sequence(
        objects, unstructured_merge, structured_merge, *args, **kwargs
    )
def concat(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """
    Concatenate a homogeneous sequence of structured or unstructured grids.

    Uses ``xu.concat`` for xugrid objects and ``xr.concat`` for xarray objects.
    Extra arguments (e.g. the concatenation dimension) are forwarded to the
    selected function.
    """
    unstructured_concat, structured_concat = xu.concat, xr.concat
    return _type_dispatch_functions_on_grid_sequence(
        objects, unstructured_concat, structured_concat, *args, **kwargs
    )
def merge_unstructured_dataset(variables_to_merge: list[dict], *args, **kwargs):
    """
    Work around xugrid issue https://github.com/Deltares/xugrid/issues/179

    Expects exactly one dictionary in the list. A list is used to have the same
    API as xr.merge().

    Merges unstructured grids first, then manually assigns scalar variables.

    Parameters
    ----------
    variables_to_merge: list of dict
        List with a single mapping of variable name to value. Values that are
        xu.UgridDataArray are merged with ``xu.merge``. All other values are
        assigned to the merged dataset afterwards.
    *args, **kwargs
        Passed on to ``xu.merge``.

    Returns
    -------
    The dataset returned by ``xu.merge``, with its grid list reduced to a
    single grid and the non-grid values added as variables.

    Raises
    ------
    ValueError
        If the list does not contain exactly one element, if the mapping holds
        no xu.UgridDataArray, or if the merged grids are not all identical.
    TypeError
        If the element of the list is not a mapping.
    """
    # "!= 1" so that an empty list gets this ValueError instead of an
    # IndexError on the line below.
    if len(variables_to_merge) != 1:
        raise ValueError(
            f"Only one dict of variables expected, got {len(variables_to_merge)}"
        )

    variables_to_merge_dict = variables_to_merge[0]

    # Accept any Mapping, matching the Mapping type accepted by
    # merge_with_dictionary.
    if not isinstance(variables_to_merge_dict, Mapping):
        raise TypeError(f"Expected dict, got {type(variables_to_merge_dict)}")

    # Separate variables into list of grids and dict of scalar variables
    grids_ls = []
    scalar_dict = {}
    for name, variable in variables_to_merge_dict.items():
        if isinstance(variable, xu.UgridDataArray):
            grids_ls.append(variable.rename(name))
        else:
            scalar_dict[name] = variable

    # Without any grid there is nothing to attach the scalars to. Fail with a
    # clear message rather than later with a less informative error.
    if not grids_ls:
        raise ValueError(
            "No xu.UgridDataArray provided, please provide at least one grid"
        )

    # Merge grids
    dataset = xu.merge(grids_ls, *args, **kwargs)

    # Temporarily work around this xugrid issue, until fixed:
    # https://github.com/Deltares/xugrid/issues/206
    # Grids are compared through the hash of their pickled representation.
    grid_hashes = {hash(pickle.dumps(grid)) for grid in dataset.ugrid.grids}
    if len(grid_hashes) > 1:
        raise ValueError(
            "Multiple grids provided, please provide data on one unique grid"
        )
    # Possibly won't work anymore if this ever gets implemented:
    # https://github.com/Deltares/xugrid/issues/195
    dataset._grids = [dataset.grids[0]]

    # Assign scalar variables manually
    for name, variable in scalar_dict.items():
        dataset[name] = variable

    return dataset
def merge_with_dictionary(
    variables_to_merge: Mapping[str, GridDataArray | float | bool | int],
    *args,
    **kwargs,
):
    """
    Merge a mapping of grids and scalars into a single dataset.

    If the mapping contains a xu.UgridDataArray, ``merge_unstructured_dataset``
    is used; otherwise ``xr.merge`` is used. Extra arguments are forwarded to
    the selected function.
    """
    unstructured_merge, structured_merge = merge_unstructured_dataset, xr.merge
    return _type_dispatch_functions_on_dict(
        variables_to_merge, unstructured_merge, structured_merge, *args, **kwargs
    )
@typedispatch
def bounding_polygon(active: xr.DataArray):
    """
    Return the polygons covering the active cells of a structured grid.

    Inactive cells are masked to NaN before polygonizing. Only the polygons
    whose value equals 1.0 are kept. This assumes ``active`` is boolean, so
    that True becomes 1.0 after masking (TODO confirm with callers).
    """
    masked = active.where(active, other=np.nan)
    polygons = _polygonize(masked)
    # Drop polygons of inactive (NaN) cells.
    keep = polygons["value"] == 1.0
    return polygons.loc[keep]


@typedispatch  # type: ignore[no-redef]
def bounding_polygon(active: xu.UgridDataArray):  # noqa: F811
    """
    Return the bounding polygon of the active faces of an unstructured grid.

    Faces where ``active > 0`` are selected along the face dimension, and the
    bounding polygon of the resulting clipped grid is returned.
    NOTE(review): ``np.where(...)[0]`` takes the indices of the first axis, so
    this presumably expects a 1D (face-only) array.
    """
    face_dimension = active.ugrid.grid.face_dimension
    active_faces = np.where(active > 0)[0]
    clipped = active.isel(
        {f"{face_dimension}": active_faces}, missing_dims="ignore"
    )
    return clipped.ugrid.grid.bounding_polygon()
@typedispatch
def is_spatial_grid(array: xr.DataArray | xr.Dataset) -> bool:
    """Return True if both "x" and "y" are dimensions and coordinates of the array"""
    return all(
        axis in array.coords and axis in array.dims for axis in ("x", "y")
    )


@typedispatch  # type: ignore[no-redef]
def is_spatial_grid(array: xu.UgridDataArray | xu.UgridDataset) -> bool:  # noqa: F811
    """Return True if the grid's face dimension is both a dimension and a coordinate"""
    face_dim = array.ugrid.grid.face_dimension
    return face_dim in array.dims and face_dim in array.coords


@typedispatch  # type: ignore[no-redef]
def is_spatial_grid(_: object) -> bool:  # noqa: F811
    """Any other object is not a spatial grid."""
    return False
@typedispatch
def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool:
    """Return True if both arrays are ``.equals`` and live on equal grids."""
    if not array1.equals(array2):
        return False
    return array1.ugrid.grid.equals(array2.ugrid.grid)


@typedispatch  # type: ignore[no-redef]
def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool:  # noqa: F811
    """Return True if both structured arrays are ``.equals``."""
    same = array1.equals(array2)
    return same


@typedispatch  # type: ignore[no-redef]
def is_equal(array1: object, array2: object) -> bool:  # noqa: F811
    """Any other combination of types is considered unequal."""
    return False
@typedispatch
def is_same_domain(grid1: xu.UgridDataArray, grid2: xu.UgridDataArray) -> bool:
    """Return True if coordinates and the underlying ugrid topology are equal."""
    same_coords = grid1.coords.equals(grid2.coords)
    return same_coords and grid1.ugrid.grid.equals(grid2.ugrid.grid)


@typedispatch  # type: ignore[no-redef]
def is_same_domain(grid1: xr.DataArray, grid2: xr.DataArray) -> bool:  # noqa: F811
    """Return True if the coordinates of both structured grids are equal."""
    coords1, coords2 = grid1.coords, grid2.coords
    return coords1.equals(coords2)


@typedispatch  # type: ignore[no-redef]
def is_same_domain(grid1: object, grid2: object) -> bool:  # noqa: F811
    """Any other combination of types never shares a domain."""
    return False
@typedispatch
def get_spatial_dimension_names(grid: xr.DataArray) -> list[str]:
    """
    Return the names treated as spatial for a structured grid.

    Besides "x", "y" and "layer", this includes "dx" and "dy" (presumably the
    cell-size coordinates). A new list is returned on every call.
    """
    horizontal = ["x", "y"]
    return horizontal + ["layer", "dx", "dy"]


@typedispatch  # type: ignore[no-redef]
def get_spatial_dimension_names(grid: xu.UgridDataArray) -> list[str]:  # noqa: F811
    """Return the face dimension of the grid's topology followed by "layer"."""
    face_dimension = grid.ugrid.grid.face_dimension
    return [face_dimension, "layer"]


@typedispatch  # type: ignore[no-redef]
def get_spatial_dimension_names(grid: object) -> list[str]:  # noqa: F811
    """Other objects have no spatial dimensions."""
    return []
@typedispatch
def get_grid_geometry_hash(grid: xr.DataArray) -> tuple[int, int]:
    """
    Return a tuple of hashes of the x and y coordinate values of a structured grid.

    Hashes are computed from the pickled numpy arrays. ``hash()`` of bytes is
    salted per interpreter process (see PYTHONHASHSEED), so results are only
    comparable within the same process.
    """
    hash_x = hash(pickle.dumps(grid["x"].values))
    hash_y = hash(pickle.dumps(grid["y"].values))
    return (hash_x, hash_y)


@typedispatch  # type: ignore[no-redef]
def get_grid_geometry_hash(grid: xu.UgridDataArray) -> tuple[int, int, int]:  # noqa: F811
    """
    Return hashes of node x, node y and node-face connectivity of the topology.

    Like the structured variant, the hashes are only comparable within the same
    interpreter process.
    """
    hash_x = hash(pickle.dumps(grid.ugrid.grid.node_x))
    hash_y = hash(pickle.dumps(grid.ugrid.grid.node_y))
    hash_connectivity = hash(pickle.dumps(grid.ugrid.grid.node_face_connectivity))
    return (hash_x, hash_y, hash_connectivity)


@typedispatch  # type: ignore[no-redef]
def get_grid_geometry_hash(grid: object) -> tuple[int, ...]:  # noqa: F811
    """
    Raise for unsupported objects.

    Raises
    ------
    ValueError
        Always; only xr.DataArray and xu.UgridDataArray are supported.
    """
    raise ValueError("get_grid_geometry_hash not supported for this object.")
@typedispatch
def enforce_dim_order(grid: xr.DataArray) -> xr.DataArray:
    """
    Transpose to the iMOD Python order: species, time, layer, y, x.

    Dimensions in this list that are absent from ``grid`` are ignored.
    """
    dim_order = ("species", "time", "layer", "y", "x")
    return grid.transpose(*dim_order, missing_dims="ignore")


@typedispatch  # type: ignore[no-redef]
def enforce_dim_order(grid: xu.UgridDataArray) -> xu.UgridDataArray:  # noqa: F811
    """
    Transpose to the iMOD Python order: species, time, layer, face dimension.

    Dimensions in this list that are absent from ``grid`` are ignored.
    """
    dim_order = ("species", "time", "layer", grid.ugrid.grid.face_dimension)
    return grid.transpose(*dim_order, missing_dims="ignore")
def _enforce_unstructured(obj: GridDataArray, ugrid2d=xu.Ugrid2d) -> xu.UgridDataArray:
    """
    Wrap ``obj`` as a xu.UgridDataArray on the topology ``ugrid2d``.

    ``obj`` is first converted with ``xr.DataArray(obj)``.

    NOTE(review): the default for ``ugrid2d`` is the ``xu.Ugrid2d`` class
    itself, not an instance. Callers presumably always pass a grid instance;
    confirm before relying on the default.
    """
    as_dataarray = xr.DataArray(obj)
    return xu.UgridDataArray(as_dataarray, ugrid2d)
def preserve_gridtype(func):
    """
    Decorator to preserve gridtype, this is to work around the following xugrid
    behavior::

        UgridDataArray() * DataArray() -> UgridDataArray
        DataArray() * UgridDataArray() -> DataArray

    with this decorator::

        UgridDataArray() * DataArray() -> UgridDataArray
        DataArray() * UgridDataArray() -> UgridDataArray

    If any positional argument is a xu.UgridDataArray or xu.UgridDataset, the
    result of ``func`` is wrapped as a xu.UgridDataArray on the grid of the last
    such argument. If the result is a tuple, each element is wrapped.
    Keyword arguments are not inspected.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        unstructured = False
        grid = None
        for arg in args:
            # Use isinstance instead of is_unstructured. Its typedispatch has
            # no fallback for arbitrary objects and then returns the argument
            # itself, which breaks for scalars (truthy) and numpy arrays
            # (ambiguous truth value).
            if isinstance(arg, (xu.UgridDataArray, xu.UgridDataset)):
                unstructured = True
                grid = arg.ugrid.grid

        result = func(*args, **kwargs)

        if not unstructured:
            return result
        # Multiple grids returned
        if isinstance(result, tuple):
            return tuple(_enforce_unstructured(item, grid) for item in result)
        return _enforce_unstructured(result, grid)

    return wrapper