Coverage for C:\src\imod-python\imod\util\spatial.py: 93%
263 statements
coverage.py v7.5.1, created at 2024-05-16 11:25 +0200
"""
Utility functions for dealing with the spatial
location of rasters: :func:`imod.util.spatial.coord_reference`,
:func:`imod.util.spatial.spatial_reference` and
:func:`imod.util.spatial.transform`. These are used internally, but are not
private since they may be useful to users as well.
"""
8import collections
9import re
10from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
12import affine
13import numpy as np
14import pandas as pd
15import xarray as xr
16import xugrid as xu
18from imod.typing import FloatArray, GridDataset, IntArray
19from imod.util.imports import MissingOptionalModule
21# since rasterio, shapely, rioxarray, and geopandas are a big dependencies that are
22# sometimes hard to install and not always required, we made this an optional
23# dependency
24try:
25 import rasterio
26except ImportError:
27 rasterio = MissingOptionalModule("rasterio")
29try:
30 import shapely
31except ImportError:
32 shapely = MissingOptionalModule("shapely")
34if TYPE_CHECKING:
35 import geopandas as gpd
36else:
37 try:
38 import geopandas as gpd
39 except ImportError:
40 gpd = MissingOptionalModule("geopandas")
# rioxarray is imported for its ``.rio`` accessor. When it is missing, bind the
# *rioxarray* name (not ``rasterio``) to a placeholder, so that functions can
# check ``isinstance(rioxarray, MissingOptionalModule)`` and ``rasterio`` is
# left untouched.
try:
    import rioxarray
except ImportError:
    rioxarray = MissingOptionalModule("rioxarray")
def _xycoords(bounds, cellsizes) -> Dict[str, Any]:
    """
    Based on bounds and cellsizes, construct coords with spatial information.

    Parameters
    ----------
    bounds: tuple
        (xmin, xmax, ymin, ymax)
    cellsizes: tuple
        (dx, dy). Each may be a scalar (equidistant spacing) or a 1D array of
        cell sizes (non-equidistant spacing). Any 0-dimensional value counts
        as a scalar, including numpy scalars of any dtype. dy is expected to be
        negative, so that y decreases from ymax towards ymin.

    Returns
    -------
    coords: OrderedDict
        Entries "x" and "y" hold midpoint coordinates. When the spacing is
        uniform along both axes, "dx" and "dy" are 0-d float arrays. Otherwise
        they are ``("x", dx)`` and ``("y", dy)`` tuples of float64 arrays.
    """
    # unpack tuples
    xmin, xmax, ymin, ymax = bounds
    dx, dy = cellsizes
    coords: collections.OrderedDict[str, Any] = collections.OrderedDict()

    if np.ndim(dx) == 0 and np.ndim(dy) == 0:  # equidistant
        coords["x"] = np.arange(xmin + dx / 2.0, xmax, dx)
        coords["y"] = np.arange(ymax + dy / 2.0, ymin, dy)
        coords["dx"] = np.array(float(dx))
        coords["dy"] = np.array(float(dy))
        return coords

    # nonequidistant (possibly along one axis only)
    # even though IDF may store them as float32, we always convert them to float64
    dx = _cellsize_array(dx, xmin, xmax)
    dy = _cellsize_array(dy, ymax, ymin)
    coords["x"] = xmin + np.cumsum(dx) - 0.5 * dx
    coords["y"] = ymax + np.cumsum(dy) - 0.5 * dy
    if np.allclose(dx, dx[0]) and np.allclose(dy, dy[0]):
        coords["dx"] = np.array(float(dx[0]))
        coords["dy"] = np.array(float(dy[0]))
    else:
        coords["dx"] = ("x", dx)
        coords["dy"] = ("y", dy)
    return coords


def _cellsize_array(cellsize, start, end) -> np.ndarray:
    """Return cellsize as a new 1D float64 array; a scalar is repeated once per cell between start and end."""
    if np.ndim(cellsize) == 0:
        # Same number of cells the equidistant branch of _xycoords would make.
        ncell = np.arange(start + cellsize / 2.0, end, cellsize).size
        return np.full(ncell, float(cellsize))
    # np.array copies, so the returned coords never alias the caller's array.
    return np.array(cellsize, dtype=np.float64)
def coord_reference(da_coord) -> Tuple[Union[float, np.ndarray], float, float]:
    """
    Extract the cell size and bounds (dx, xmin, xmax) of a coordinate
    DataArray, where x is any coordinate (e.g. "x" or "y").

    xarray stores midpoint coordinates, so the bounds are obtained by
    extending the outermost midpoints by half a cell.

    If a coordinate named ``d{name}`` (e.g. ``dx``) is present, it provides
    the cell sizes. When it holds one value per cell and those values are not
    all close, dx is returned as a 1D ndarray instead of a float.

    Parameters
    ----------
    da_coord : xarray.DataArray
        Coordinate DataArray, e.g. ``da["x"]``.

    Returns
    -------
    tuple
        (dx, xmin, xmax) for the coordinate.

    Raises
    ------
    ValueError
        If the coordinate has size 1 and no ``d{name}`` coordinate is present,
        or if it is not equidistant and no ``d{name}`` coordinate is present.
    """
    x = da_coord.values

    # Possibly non-equidistant: cell sizes provided as a "d{name}" coordinate.
    dx_string = f"d{da_coord.name}"
    if dx_string in da_coord.coords:
        dx = da_coord.coords[dx_string]
        if (dx.shape == x.shape) and (dx.size != 1):
            # One cell size per cell. For a decreasing coordinate (negative
            # cell sizes) the minimum lies at the last cell and the maximum at
            # the first cell, so use the matching cell sizes.
            if dx[0] < 0.0:
                start, end = -1, 0
            else:
                start, end = 0, -1
            dx = dx.values.astype(np.float64)
            xmin = float(x.min()) - 0.5 * abs(dx[start])
            xmax = float(x.max()) + 0.5 * abs(dx[end])
            # As a single value if equidistant
            if np.allclose(dx, dx[0]):
                dx = dx[0]
        else:
            # Otherwise the coordinate is expected to hold a single cell size.
            dx = float(dx)
            xmin = float(x.min()) - 0.5 * abs(dx)
            xmax = float(x.max()) + 0.5 * abs(dx)
        return dx, xmin, xmax

    if x.size == 1:
        raise ValueError(
            f"DataArray has size 1 along {da_coord.name}, so cellsize must be provided"
            f" as a coordinate named d{da_coord.name}."
        )

    # Equidistant
    # TODO: decide on decent criterium for what equidistant means
    # make use of floating point epsilon? E.g:
    # https://github.com/ioam/holoviews/issues/1869#issuecomment-353115449
    dxs = np.diff(x.astype(np.float64))
    dx = dxs[0]
    # NOTE: this tolerance has always been passed as the *relative* tolerance
    # of np.allclose (its third positional parameter). The allowed deviation
    # is therefore about 1e-4 * dx**2 (plus allclose's default atol). The
    # keyword makes this explicit. The criterion is kept unchanged for
    # backward compatibility, e.g. with float32-stored coordinates.
    rtol = abs(1.0e-4 * dx)
    if not np.allclose(dxs, dx, rtol=rtol):
        raise ValueError(
            f"DataArray has to be equidistant along {da_coord.name}, or cellsizes"
            f" must be provided as a coordinate named d{da_coord.name}."
        )

    # as xarray uses midpoint coordinates
    xmin = float(x.min()) - 0.5 * abs(dx)
    xmax = float(x.max()) + 0.5 * abs(dx)
    return dx, xmin, xmax
def spatial_reference(
    a: xr.DataArray,
) -> Tuple[float, float, float, float, float, float]:
    """
    Extract the spatial reference of a DataArray from its x and y coordinates.

    When a coordinate is non-equidistant, its cell size (dx or dy) is returned
    as a 1D ndarray rather than a float.

    Parameters
    ----------
    a : xarray.DataArray
        DataArray with "x" and "y" coordinates.

    Returns
    -------
    tuple
        (dx, xmin, xmax, dy, ymin, ymax)
    """
    x_reference = coord_reference(a["x"])
    y_reference = coord_reference(a["y"])
    return (*x_reference, *y_reference)
def transform(a: xr.DataArray) -> affine.Affine:
    """
    Build an affine.Affine from the DataArray's x and y coordinates, for
    writing to e.g. rasterio supported formats.

    Parameters
    ----------
    a : xarray.DataArray

    Returns
    -------
    affine.Affine

    Raises
    ------
    ValueError
        If the cell sizes differ along x or y, if dx is negative, or if dy is
        positive.
    """
    dx, xmin, _, dy, _, ymax = spatial_reference(a)

    def _single_cellsize(cellsize, name):
        # An affine transform supports a single cell size per axis. Array cell
        # sizes are only accepted if all their values are identical.
        if not isinstance(cellsize, np.ndarray):
            return cellsize
        if np.unique(cellsize).size != 1:
            raise ValueError(f"DataArray is not equidistant along {name}")
        return cellsize[0]

    dx = _single_cellsize(dx, "x")
    dy = _single_cellsize(dy, "y")

    if dx < 0.0:
        raise ValueError("dx must be positive")
    if dy > 0.0:
        raise ValueError("dy must be negative")
    return affine.Affine(dx, 0.0, xmin, 0.0, dy, ymax)
def ugrid2d_data(da: xr.DataArray, face_dim: str) -> xr.DataArray:
    """
    Reshape a structured (y, x) DataArray into unstructured (face) form.

    Leading dimensions are maintained: e.g. (time, layer, y, x) becomes
    (time, layer, face).

    Parameters
    ----------
    da: xr.DataArray
        Structured DataArray with last two dimensions ("y", "x").
    face_dim: str
        Name of the dimension that replaces ("y", "x").

    Returns
    -------
    Unstructured DataArray with dimensions ("y", "x") replaced by (face_dim,).
    Only the coordinates of the leading dimensions are carried over.

    Raises
    ------
    ValueError
        If the last two dimensions of ``da`` are not ("y", "x").
    """
    if da.dims[-2:] != ("y", "x"):
        raise ValueError('Last two dimensions of da must be ("y", "x")')

    leading_dims = da.dims[:-2]
    # Collapse the trailing (y, x) axes into a single face axis.
    flat_shape = (*da.shape[:-2], -1)
    return xr.DataArray(
        da.data.reshape(flat_shape),
        coords={dim: da.coords[dim] for dim in leading_dims},
        dims=[*leading_dims, face_dim],
        name=da.name,
    )
def unstack_dim_into_variable(dataset: GridDataset, dim: str) -> GridDataset:
    """
    Unstack each variable containing ``dim`` into separate variables.

    Every data variable that has ``dim`` among its dimensions is replaced by
    one variable per value along ``dim``, named ``{variable}_{dim}_{value}``.
    Other variables are left as they are. Finally, the ``dim`` coordinate is
    dropped if present.

    Parameters
    ----------
    dataset: GridDataset
    dim: str
        Name of the dimension to unstack.

    Returns
    -------
    unstacked: GridDataset
        Built from a copy of ``dataset``.
    """
    result = dataset.copy()

    stacked_names = [
        name for name in dataset.data_vars if dim in dataset[name].dims
    ]
    for name in stacked_names:
        stacked = result[name]
        result = result.drop_vars(name)  # type: ignore
        for value in stacked[dim].values:
            selection = stacked.sel(indexers={dim: value}, drop=True)
            result[f"{name}_{dim}_{value}"] = selection

    if dim in result.coords:
        result = result.drop_vars(dim)
    return result
def mdal_compliant_ugrid2d(
    dataset: xr.Dataset, crs: Optional[Any] = None
) -> xr.Dataset:
    """
    Ensures the xarray Dataset will be written to a UGRID netCDF that will be
    accepted by MDAL.

    * Unstacks variables with a layer dimension into separate variables.
    * Drops coordinates on the UGRID node, edge, and face dimensions.
    * Removes absent entries from the mesh topology attributes.
    * Sets the encoding dtype to float64 for datetime coordinates.
    * Optionally writes a CRS (once, for the whole dataset).

    Parameters
    ----------
    dataset: xarray.Dataset
        Dataset to make compliant with MDAL
    crs: Any, Optional
        Anything accepted by rasterio.crs.CRS.from_user_input
        Requires ``rioxarray`` installed.

    Returns
    -------
    unstacked: xr.Dataset

    Raises
    ------
    ModuleNotFoundError
        If ``crs`` is provided but ``rioxarray`` is not installed.
    """
    ds = unstack_dim_into_variable(dataset, "layer")

    # Find topology variables
    topology_variables = [
        variable
        for variable in ds.data_vars
        if ds[variable].attrs.get("cf_role") == "mesh_topology"
    ]

    for variable in topology_variables:
        # Possible attributes:
        #
        # "cf_role"
        # "long_name"
        # "topology_dimension"
        # "node_dimension": required
        # "node_coordinates": required
        # "edge_dimension": optional
        # "edge_node_connectivity": optional
        # "face_dimension": required
        # "face_node_connectivity": required
        # "max_face_nodes_dimension": required
        # "face_coordinates": optional
        attrs = ds[variable].attrs
        ugrid_dims = (
            attrs.get("node_dimension"),
            attrs.get("edge_dimension"),
            attrs.get("face_dimension"),
        )

        # Drop the coordinates on the UGRID dimensions
        to_drop = [dim for dim in ugrid_dims if dim is not None and dim in ds.coords]
        ds = ds.drop_vars(to_drop)

        # Re-fetch the attrs from the updated dataset, so the edits below are
        # applied to the topology variable of the dataset that is returned.
        attrs = ds[variable].attrs

        edge_dim = attrs.get("edge_dimension")
        if edge_dim and edge_dim not in ds.dims:
            attrs.pop("edge_dimension")

        # NOTE(review): UGRID allows face_coordinates to list several
        # space-separated names; this test compares the whole string against
        # the coordinate names — confirm this is intended.
        face_coords = attrs.get("face_coordinates")
        if face_coords and face_coords not in ds.coords:
            attrs.pop("face_coordinates")

        edge_nodes = attrs.get("edge_node_connectivity")
        if edge_nodes and edge_nodes not in ds:
            attrs.pop("edge_node_connectivity")

    # The CRS applies to the whole dataset: write it once. It is no longer
    # silently ignored when no mesh topology variable is present.
    if crs is not None:
        if isinstance(rioxarray, MissingOptionalModule):
            raise ModuleNotFoundError("rioxarray is required for this functionality")
        ds.rio.write_crs(crs, inplace=True)

    # Make sure time is encoded as a float for MDAL
    # TODO: MDAL requires all data variables to be float (this excludes the UGRID topology data)
    for var in ds.coords:
        if np.issubdtype(ds[var].dtype, np.datetime64):
            ds[var].encoding["dtype"] = np.float64

    return ds
def from_mdal_compliant_ugrid2d(dataset: xu.UgridDataset) -> xu.UgridDataset:
    """
    Undo some of the changes of ``mdal_compliant_ugrid2d``: re-stack the
    layers.

    Data variables named ``{name}_layer_{number}`` are concatenated along a
    "layer" dimension (coordinate values taken from ``number``) into a single
    variable ``{name}``. All other data variables are kept as they are.

    Parameters
    ----------
    dataset: xugrid.UgridDataset

    Returns
    -------
    restacked: xugrid.UgridDataset
        ``dataset`` itself if no variable matches the pattern.
    """
    ds = dataset.ugrid.obj
    # fullmatch with ``.+``: the whole name must follow the pattern, so names
    # containing non-word characters (e.g. "k-value_layer_1") are not
    # truncated, and names with trailing text are not restacked. The greedy
    # ``.+`` splits on the last "_layer_", as the unstacking naming produces.
    pattern = re.compile(r"(.+)_layer_(\d+)")
    matches = {}
    for variable in ds.data_vars:
        match = pattern.fullmatch(variable) if isinstance(variable, str) else None
        if match is not None:
            matches[variable] = match
    if not matches:
        return dataset

    # First deal with the variables that may remain untouched (dataset order).
    other_vars = [variable for variable in ds.data_vars if variable not in matches]
    restacked = ds[other_vars]

    # Next group by name, which will be the output dataset variable name.
    grouped = collections.defaultdict(list)
    for variable, match in matches.items():
        name, layer_str = match.groups()
        layer = int(layer_str)
        grouped[name].append((layer, ds[variable].assign_coords(layer=layer)))

    # Concatenate, and make sure the dimension order is natural.
    ugrid_dims = {dim for grid in dataset.ugrid.grids for dim in grid.dimensions}
    for name, layered in grouped.items():
        layered.sort(key=lambda item: item[0])
        da = xr.concat([layer_da for _, layer_da in layered], dim="layer")
        newdims = list(da.dims)
        newdims.remove("layer")
        # If it's a spatial dataset, the layer should be second last.
        if ugrid_dims.intersection(newdims):
            newdims.insert(-1, "layer")
        # If not, the layer should be last.
        else:
            newdims.append("layer")
        if tuple(newdims) != da.dims:
            da = da.transpose(*newdims)

        restacked[name] = da

    return xu.UgridDataset(restacked, grids=dataset.ugrid.grids)
def to_ugrid2d(data: Union[xr.DataArray, xr.Dataset]) -> xr.Dataset:
    """
    Convert a structured DataArray or Dataset into its UGRID-2D quadrilateral
    equivalent.

    See:
    https://ugrid-conventions.github.io/ugrid-conventions/#2d-flexible-mesh-mixed-triangles-quadrilaterals-etc-topology

    Parameters
    ----------
    data: Union[xr.DataArray, xr.Dataset]
        Dataset or DataArray with last two dimensions ("y", "x").
        In case of a Dataset, the 2D topology is defined once and variables are
        added one by one.
        In case of a DataArray, a name is required; a name can be set with:
        ``da.name = "..."``

    Returns
    -------
    ugrid2d_dataset: xr.Dataset
        The equivalent data, in UGRID-2D quadrilateral form, passed through
        ``mdal_compliant_ugrid2d``.

    Raises
    ------
    TypeError
        If data is neither a DataArray nor a Dataset.
    ValueError
        If data is a DataArray without a name.
    """
    if not isinstance(data, (xr.DataArray, xr.Dataset)):
        raise TypeError("data must be xarray.DataArray or xr.Dataset")
    # Validate before the (potentially expensive) grid construction.
    if isinstance(data, xr.DataArray) and data.name is None:
        raise ValueError(
            "A name is required for the DataArray. "
            'It can be set with ``da.name = "..."``'
        )

    grid = xu.Ugrid2d.from_structured(data)
    ds = grid.to_dataset()

    if isinstance(data, xr.Dataset):
        for variable in data.data_vars:
            ds[variable] = ugrid2d_data(data[variable], grid.face_dimension)
    else:
        ds[data.name] = ugrid2d_data(data, grid.face_dimension)
    return mdal_compliant_ugrid2d(ds)
def gdal_compliant_grid(
    data: Union[xr.DataArray, xr.Dataset],
    crs: Optional[Any] = None,
) -> Union[xr.DataArray, xr.Dataset]:
    """
    Assign CF attributes to the x and y coordinates so GDAL accepts the data.

    Parameters
    ----------
    data: xr.DataArray | xr.Dataset
        Structured data with a x and y coordinate.
    crs: Any, Optional
        Anything accepted by rasterio.crs.CRS.from_user_input
        Requires ``rioxarray`` installed.

    Returns
    -------
    data with attributes to be accepted by GDAL. The input object is not
    modified; the CRS (if any) is written to the returned object.

    Raises
    ------
    ValueError
        If an "x" or "y" dimension is missing, or if the data already has a
        CRS that differs from ``crs``.
    ModuleNotFoundError
        If ``crs`` is provided but ``rioxarray`` is not installed.
    """
    # Use of ``dims`` in xarray currently inconsistent between DataArray and
    # Dataset, therefore use .sizes.keys() to force getting the same thing.
    # FUTURE: change this to set(data.dims) when made consistent.
    present_dims = {str(dim) for dim in data.sizes.keys()}
    missing_dims = {"x", "y"} - present_dims
    if missing_dims:
        raise ValueError(f"Missing dimensions: {missing_dims}")

    axis_attrs = {
        "x": {
            "axis": "X",
            "long_name": "x coordinate of projection",
            "standard_name": "projection_x_coordinate",
        },
        "y": {
            "axis": "Y",
            "long_name": "y coordinate of projection",
            "standard_name": "projection_y_coordinate",
        },
    }
    data_gdal = data.assign_coords(
        x=data.coords["x"].assign_attrs(axis_attrs["x"]),
        y=data.coords["y"].assign_attrs(axis_attrs["y"]),
    )

    if crs is None:
        return data_gdal

    if isinstance(rioxarray, MissingOptionalModule):
        raise ModuleNotFoundError("rioxarray is required for this functionality")
    existing_crs = data_gdal.rio.crs
    if existing_crs is not None and existing_crs != crs:
        raise ValueError(
            "Grid already has CRS different then provided CRS. "
            f"Grid has {existing_crs}, got {crs}."
        )
    data_gdal.rio.write_crs(crs, inplace=True)
    return data_gdal
def empty_2d(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
) -> xr.DataArray:
    """
    Create an empty 2D DataArray with dimensions ("y", "x").

    ``dx`` and ``dy`` may be provided as:

    * scalar: for equidistant spacing
    * array: for non-equidistant spacing

    The signs of ``dx`` and ``dy`` are ignored: x increases from ``xmin``
    and y decreases from ``ymax``.

    Note that xarray (and netCDF4) uses midpoint coordinates. ``xmin`` and
    ``xmax`` (likewise ``ymin`` and ``ymax``) are used to generate the
    appropriate midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    grid_coords = _xycoords((xmin, xmax, ymin, ymax), (np.abs(dx), -np.abs(dy)))
    nrow, ncol = grid_coords["y"].size, grid_coords["x"].size
    nan_values = np.full((nrow, ncol), np.nan)
    return xr.DataArray(data=nan_values, coords=grid_coords, dims=["y", "x"])
def empty_3d(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    layer: Union[int, Sequence[int], IntArray],
) -> xr.DataArray:
    """
    Create an empty 3D DataArray with dimensions ("layer", "y", "x").

    ``dx`` and ``dy`` may be provided as:

    * scalar: for equidistant spacing
    * array: for non-equidistant spacing

    The signs of ``dx`` and ``dy`` are ignored: x increases from ``xmin``
    and y decreases from ``ymax``.

    Note that xarray (and netCDF4) uses midpoint coordinates. ``xmin`` and
    ``xmax`` (likewise ``ymin`` and ``ymax``) are used to generate the
    appropriate midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    layer: int, sequence of integers, 1d array of integers
        Must be (convertible to) a 1d array, otherwise ValueError is raised.

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    grid_coords = _xycoords((xmin, xmax, ymin, ymax), (np.abs(dx), -np.abs(dy)))
    nrow, ncol = grid_coords["y"].size, grid_coords["x"].size
    layers = _layer(layer)
    grid_coords["layer"] = layers
    nan_values = np.full((layers.size, nrow, ncol), np.nan)
    return xr.DataArray(
        data=nan_values,
        coords=grid_coords,
        dims=["layer", "y", "x"],
    )
def empty_2d_transient(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    time: Any,
) -> xr.DataArray:
    """
    Create an empty transient 2D DataArray with dimensions ("time", "y", "x").

    ``dx`` and ``dy`` may be provided as:

    * scalar: for equidistant spacing
    * array: for non-equidistant spacing

    The signs of ``dx`` and ``dy`` are ignored: x increases from ``xmin``
    and y decreases from ``ymax``.

    Note that xarray (and netCDF4) uses midpoint coordinates. ``xmin`` and
    ``xmax`` (likewise ``ymin`` and ``ymax``) are used to generate the
    appropriate midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    time: Any
        One or more of: str, numpy datetime64, pandas Timestamp. Converted
        with ``pd.to_datetime``; must be (convertible to) 1d.

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    grid_coords = _xycoords((xmin, xmax, ymin, ymax), (np.abs(dx), -np.abs(dy)))
    nrow, ncol = grid_coords["y"].size, grid_coords["x"].size
    times = _time(time)
    grid_coords["time"] = times
    nan_values = np.full((times.size, nrow, ncol), np.nan)
    return xr.DataArray(
        data=nan_values,
        coords=grid_coords,
        dims=["time", "y", "x"],
    )
def empty_3d_transient(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    layer: Union[int, Sequence[int], IntArray],
    time: Any,
) -> xr.DataArray:
    """
    Create an empty transient 3D DataArray with dimensions
    ("time", "layer", "y", "x").

    ``dx`` and ``dy`` may be provided as:

    * scalar: for equidistant spacing
    * array: for non-equidistant spacing

    The signs of ``dx`` and ``dy`` are ignored: x increases from ``xmin``
    and y decreases from ``ymax``.

    Note that xarray (and netCDF4) uses midpoint coordinates. ``xmin`` and
    ``xmax`` (likewise ``ymin`` and ``ymax``) are used to generate the
    appropriate midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    layer: int, sequence of integers, 1d array of integers
    time: Any
        One or more of: str, numpy datetime64, pandas Timestamp

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    grid_coords = _xycoords((xmin, xmax, ymin, ymax), (np.abs(dx), -np.abs(dy)))
    nrow, ncol = grid_coords["y"].size, grid_coords["x"].size
    layers = _layer(layer)
    grid_coords["layer"] = layers
    times = _time(time)
    grid_coords["time"] = times
    nan_values = np.full((times.size, layers.size, nrow, ncol), np.nan)
    return xr.DataArray(
        data=nan_values,
        coords=grid_coords,
        dims=["time", "layer", "y", "x"],
    )
def _layer(layer: Union[int, Sequence[int], IntArray]) -> IntArray:
    """Return ``layer`` as an array of at least one dimension; raise ValueError if it is not 1d."""
    layer_array = np.atleast_1d(layer)
    # np.atleast_1d guarantees ndim >= 1, so only too many dimensions can fail.
    if layer_array.ndim != 1:
        raise ValueError("layer must be 1d")
    return layer_array
def _time(time: Any) -> Any:
    """Return ``time`` converted with ``pd.to_datetime`` after promoting it to 1d; raise ValueError if it is not 1d."""
    times = np.atleast_1d(time)
    # np.atleast_1d guarantees ndim >= 1, so only too many dimensions can fail.
    if times.ndim != 1:
        raise ValueError("time must be 1d")
    return pd.to_datetime(times)
def is_divisor(numerator: Union[float, FloatArray], denominator: float) -> bool:
    """
    Check whether ``denominator`` divides ``numerator`` (every element, if an
    array) up to floating point tolerance. Signs are ignored.

    A remainder close to either 0 or to ``|denominator|`` counts as an exact
    division, because floating point modulo may return a value just below the
    divisor (e.g. ``0.3 % 0.1``).

    Parameters
    ----------
    numerator: np.array of floats or float
    denominator: float

    Returns
    -------
    is_divisor: bool
    """
    step = np.abs(denominator)
    remainder = np.mod(np.abs(numerator), step)
    divides = np.logical_or(np.isclose(remainder, 0.0), np.isclose(remainder, step))
    return bool(np.all(divides))
def _polygonize(da: xr.DataArray) -> "gpd.GeoDataFrame":
    """
    Polygonize a 2D-DataArray into a GeoDataFrame of polygons.

    Uses ``rasterio.features.shapes`` to find the polygons, and stores each
    polygon's cell value in the "value" column. The CRS is taken from
    ``da.attrs["crs"]`` (None if absent).

    Private method located in util.spatial to work around circular imports.

    Parameters
    ----------
    da: xr.DataArray
        2D DataArray with dimensions ("y", "x"), equidistant along both axes.

    Returns
    -------
    gdf: geopandas.GeoDataFrame

    Raises
    ------
    ValueError
        If the dimensions are not ("y", "x"), or (from ``transform``) if the
        grid is not equidistant, dx is negative, or dy is positive.
    """
    if da.dims != ("y", "x"):
        raise ValueError('Dimensions must be ("y", "x")')

    values = da.values
    # rasterio.features.shapes does not accept float64 input.
    if values.dtype == np.float64:
        values = values.astype(np.float32)

    affine_transform = transform(da)
    shapes = rasterio.features.shapes(values, transform=affine_transform)

    geometries = []
    colvalues = []
    for geom, colval in shapes:
        # GeoJSON-like polygon: the first ring is the exterior and any further
        # rings are holes. Keep the holes, otherwise a polygon enclosing cells
        # of another value overlaps the polygons of those cells.
        exterior, *interiors = geom["coordinates"]
        geometries.append(shapely.geometry.Polygon(exterior, interiors))
        colvalues.append(colval)

    gdf = gpd.GeoDataFrame({"value": colvalues, "geometry": geometries})
    gdf.crs = da.attrs.get("crs")
    return gdf