Coverage for C:\src\imod-python\imod\util\spatial.py: 54%

238 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-08 13:27 +0200

1""" 

2Utility functions for dealing with the spatial 

3location of rasters: :func:`imod.util.spatial.coord_reference`, 

4:func:`imod.util.spatial_reference` and :func:`imod.util.transform`. These are 

5used internally, but are not private since they may be useful to users as well. 

6""" 

7 

8import collections 

9import re 

10from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple, Union 

11 

12import affine 

13import numpy as np 

14import pandas as pd 

15import xarray as xr 

16import xugrid as xu 

17 

18from imod.typing import FloatArray, GridDataset, IntArray 

19from imod.util.imports import MissingOptionalModule 

20 

21# since rasterio, shapely, and geopandas are a big dependencies that are 

22# sometimes hard to install and not always required, we made this an optional 

23# dependency 

24try: 

25 import rasterio 

26except ImportError: 

27 rasterio = MissingOptionalModule("rasterio") 

28 

29try: 

30 import shapely 

31except ImportError: 

32 shapely = MissingOptionalModule("shapely") 

33 

34if TYPE_CHECKING: 

35 import geopandas as gpd 

36else: 

37 try: 

38 import geopandas as gpd 

39 except ImportError: 

40 gpd = MissingOptionalModule("geopandas") 

41 

42def _xycoords(bounds, cellsizes) -> Dict[str, Any]: 

43 """Based on bounds and cellsizes, construct coords with spatial information""" 

44 # unpack tuples 

45 xmin, xmax, ymin, ymax = bounds 

46 dx, dy = cellsizes 

47 coords = collections.OrderedDict() 

48 # from cell size to x and y coordinates 

49 if isinstance(dx, (int, float)): # equidistant 

50 coords["x"] = np.arange(xmin + dx / 2.0, xmax, dx) 

51 coords["y"] = np.arange(ymax + dy / 2.0, ymin, dy) 

52 coords["dx"] = float(dx) 

53 coords["dy"] = float(dy) 

54 else: # nonequidistant 

55 # even though IDF may store them as float32, we always convert them to float64 

56 dx = dx.astype(np.float64) 

57 dy = dy.astype(np.float64) 

58 coords["x"] = xmin + np.cumsum(dx) - 0.5 * dx 

59 coords["y"] = ymax + np.cumsum(dy) - 0.5 * dy 

60 if np.allclose(dx, dx[0]) and np.allclose(dy, dy[0]): 

61 coords["dx"] = float(dx[0]) 

62 coords["dy"] = float(dy[0]) 

63 else: 

64 coords["dx"] = ("x", dx) 

65 coords["dy"] = ("y", dy) 

66 return coords 

67 

68 

def coord_reference(da_coord) -> Tuple[float, float, float]:
    """
    Extracts dx, xmin, xmax for a coordinate DataArray, where x is any coordinate.

    If the DataArray coordinates are nonequidistant, dx will be returned as
    1D ndarray instead of float.

    Cell sizes are taken from a coordinate named ``d{name}`` (e.g. ``dx``)
    when it is present. Otherwise they are derived from the spacing of the
    midpoint coordinate values, which must then be equidistant to within
    roughly 1e-4 times the cell size.

    Parameters
    ----------
    da_coord : xarray.DataArray of a coordinate

    Returns
    -------
    tuple
        (dx, xmin, xmax) for a coordinate x

    Raises
    ------
    ValueError
        If no ``d{name}`` coordinate is present and the coordinate either
        has size 1 or is not equidistant.
    """
    x = da_coord.values

    # Possibly non-equidistant: cell sizes provided as a coordinate.
    dx_string = f"d{da_coord.name}"
    if dx_string in da_coord.coords:
        dx = da_coord.coords[dx_string]
        if (dx.shape == x.shape) and (dx.size != 1):
            # choose correctly for decreasing coordinate: with negative cell
            # sizes the smallest coordinate value sits at the last position.
            if dx[0] < 0.0:
                end = 0
                start = -1
            else:
                start = 0
                end = -1
            dx = dx.values.astype(np.float64)
            xmin = float(x.min()) - 0.5 * abs(dx[start])
            xmax = float(x.max()) + 0.5 * abs(dx[end])
            # As a single value if equidistant
            if np.allclose(dx, dx[0]):
                dx = dx[0]
        else:
            dx = float(dx)
            xmin = float(x.min()) - 0.5 * abs(dx)
            xmax = float(x.max()) + 0.5 * abs(dx)
    elif x.size == 1:
        raise ValueError(
            f"DataArray has size 1 along {da_coord.name}, so cellsize must be provided"
            f" as a coordinate named d{da_coord.name}."
        )
    else:  # Equidistant
        # TODO: decide on decent criterium for what equidistant means
        # make use of floating point epsilon? E.g:
        # https://github.com/ioam/holoviews/issues/1869#issuecomment-353115449
        dxs = np.diff(x.astype(np.float64))
        dx = dxs[0]
        # Allowed absolute deviation between spacings: 1e-4 of the cell size.
        atolx = abs(1.0e-4 * dx)
        # Pass as the *absolute* tolerance: the third positional argument of
        # np.allclose is rtol, which would scale this tolerance by |dx| again.
        if not np.allclose(dxs, dx, atol=atolx):
            raise ValueError(
                f"DataArray has to be equidistant along {da_coord.name}, or cellsizes"
                f" must be provided as a coordinate named d{da_coord.name}."
            )

        # as xarray uses midpoint coordinates
        xmin = float(x.min()) - 0.5 * abs(dx)
        xmax = float(x.max()) + 0.5 * abs(dx)

    return dx, xmin, xmax

132 

133 

def spatial_reference(
    a: xr.DataArray,
) -> Tuple[float, float, float, float, float, float]:
    """
    Extracts spatial reference from DataArray.

    The "x" and "y" coordinates are each passed through ``coord_reference``
    and the results are concatenated. If the DataArray coordinates are
    nonequidistant, dx and dy will be returned as 1D ndarray instead of
    float.

    Parameters
    ----------
    a : xarray.DataArray

    Returns
    -------
    tuple
        (dx, xmin, xmax, dy, ymin, ymax)
    """
    x_reference = coord_reference(a["x"])
    y_reference = coord_reference(a["y"])
    return (*x_reference, *y_reference)

156 

157 

def transform(a: xr.DataArray) -> affine.Affine:
    """
    Extract the spatial reference information from the DataArray coordinates,
    into an affine.Affine object for writing to e.g. rasterio supported formats.

    Parameters
    ----------
    a : xarray.DataArray

    Returns
    -------
    affine.Affine

    Raises
    ------
    ValueError
        If the cell sizes are not equidistant along x or y, if dx is
        negative, or if dy is positive.
    """
    dx, xmin, _, dy, _, ymax = spatial_reference(a)

    def _single_cellsize(cellsize, dim):
        # Scalars pass through; arrays must hold a single unique value.
        if not isinstance(cellsize, np.ndarray):
            return cellsize
        if np.unique(cellsize).size != 1:
            raise ValueError(f"DataArray is not equidistant along {dim}")
        return cellsize[0]

    dx = _single_cellsize(dx, "x")
    dy = _single_cellsize(dy, "y")

    # Rasters are expected with x increasing and y decreasing (north-up).
    if dx < 0.0:
        raise ValueError("dx must be positive")
    if dy > 0.0:
        raise ValueError("dy must be negative")
    return affine.Affine(dx, 0.0, xmin, 0.0, dy, ymax)

191 

192 

def ugrid2d_data(da: xr.DataArray, face_dim: str) -> xr.DataArray:
    """
    Reshape a structured ("y", "x") DataArray into unstructured (face) form.

    The two trailing spatial dimensions are flattened into one face
    dimension while any leading dimensions are kept:
    e.g. (time, layer, y, x) becomes (time, layer, face_dim).

    Parameters
    ----------
    da: xr.DataArray
        Structured DataArray with last two dimensions ("y", "x").
    face_dim: str
        Name of the flattened face dimension in the result.

    Returns
    -------
    Unstructured DataArray with dimensions ("y", "x") replaced by
    (face_dim,). Coordinates of the leading dimensions and the name of
    ``da`` are carried over.

    Raises
    ------
    ValueError
        If the last two dimensions of ``da`` are not ("y", "x").
    """
    if da.dims[-2:] != ("y", "x"):
        raise ValueError('Last two dimensions of da must be ("y", "x")')
    *leading_dims, _, _ = da.dims
    leading_coords = {dim: da.coords[dim] for dim in leading_dims}
    # Collapse (..., y, x) into (..., y * x) with the default (C) ordering.
    flat_data = da.data.reshape(*da.shape[:-2], -1)
    return xr.DataArray(
        flat_data,
        coords=leading_coords,
        dims=[*leading_dims, face_dim],
        name=da.name,
    )

218 

219 

def unstack_dim_into_variable(dataset: GridDataset, dim: str) -> GridDataset:
    """
    Unstack each variable containing ``dim`` into separate variables.

    Every data variable that has ``dim`` is replaced by one variable per
    label along ``dim``, named ``f"{variable}_{dim}_{label}"``. Variables
    without ``dim`` are left as they are. Afterwards, the ``dim`` coordinate
    is dropped if present.

    Parameters
    ----------
    dataset: GridDataset
    dim: str

    Returns
    -------
    GridDataset
        A copy of ``dataset`` with the unstacked variables.
    """
    result = dataset.copy()

    names_with_dim = [
        name for name in dataset.data_vars if dim in dataset[name].dims
    ]

    for name in names_with_dim:
        original = result[name]
        result = result.drop_vars(name)
        for label in original[dim].values:
            new_name = f"{name}_{dim}_{label}"
            result[new_name] = original.sel(indexers={dim: label}, drop=True)

    if dim in result.coords:
        result = result.drop_vars(dim)
    return result

240 

241 

def mdal_compliant_ugrid2d(dataset: xr.Dataset) -> xr.Dataset:
    """
    Ensures the xarray Dataset will be written to a UGRID netCDF that will be
    accepted by MDAL.

    * Unstacks variables with a layer dimension into separate variables.
    * Drops coordinates on the UGRID node, edge, and face dimensions.
    * Removes absent entries from the mesh topology attributes.
    * Sets encoding to float for datetime coordinates.

    Parameters
    ----------
    dataset: xarray.Dataset

    Returns
    -------
    unstacked: xr.Dataset
    """
    ds = unstack_dim_into_variable(dataset, "layer")

    # Find topology variables
    topology_names = [
        name
        for name in ds.data_vars
        if ds[name].attrs.get("cf_role") == "mesh_topology"
    ]

    for topology in topology_names:
        attrs = ds[topology].attrs
        # Possible attributes:
        #
        # "cf_role"
        # "long_name"
        # "topology_dimension"
        # "node_dimension": required
        # "node_coordinates": required
        # "edge_dimension": optional
        # "edge_node_connectivity": optional
        # "face_dimension": required
        # "face_node_connectivity": required
        # "max_face_nodes_dimension": required
        # "face_coordinates": optional

        ugrid_dims = (
            attrs.get("node_dimension"),
            attrs.get("edge_dimension"),
            attrs.get("face_dimension"),
        )

        # Drop the coordinates on the UGRID dimensions
        ds = ds.drop_vars(
            [dim for dim in ugrid_dims if dim is not None and dim in ds.coords]
        )

        # Remove attribute entries that refer to something no longer (or
        # never) present in the dataset. The attrs dict is modified in place.
        lookups = (
            ("edge_dimension", ds.dims),
            ("face_coordinates", ds.coords),
            ("edge_node_connectivity", ds),
        )
        for key, container in lookups:
            value = attrs.get(key)
            if value and value not in container:
                attrs.pop(key)

    # Make sure time is encoded as a float for MDAL
    # TODO: MDAL requires all data variables to be float (this excludes the UGRID topology data)
    datetime_coords = [
        name for name in ds.coords if np.issubdtype(ds[name].dtype, np.datetime64)
    ]
    for name in datetime_coords:
        ds[name].encoding["dtype"] = np.float64

    return ds

309 

310 

def from_mdal_compliant_ugrid2d(dataset: xu.UgridDataset):
    """
    Undo some of the changes of ``mdal_compliant_ugrid2d``: re-stack the
    layers.

    Variables named ``{name}_layer_{n}`` (with integer ``n``) are combined
    into a single variable ``name`` with a ``layer`` dimension, ordered by
    layer number. Other variables are kept in their original order.

    Parameters
    ----------
    dataset: xugrid.UgridDataset

    Returns
    -------
    restacked: xugrid.UgridDataset
        ``dataset`` itself when there is nothing to re-stack.
    """
    ds = dataset.ugrid.obj
    # Match the full name: the base name may contain any characters (not only
    # \w), and "_layer_<int>" must be the suffix of the name.
    pattern = re.compile(r"(.+)_layer_(\d+)")

    matched = {}  # variable name -> (restacked name, layer number)
    for variable in ds.data_vars:
        if not isinstance(variable, str):
            continue
        match = pattern.fullmatch(variable)
        if match is not None:
            name, layer = match.groups()
            matched[variable] = (name, int(layer))
    if not matched:
        return dataset

    # First deal with the variables that may remain untouched. Preserve their
    # order; collecting them through a set would make it depend on hashing.
    other_vars = [variable for variable in ds.data_vars if variable not in matched]
    restacked = ds[other_vars]

    # Next group by name, which will be the output dataset variable name.
    grouped = collections.defaultdict(list)
    for variable, (name, layer) in matched.items():
        grouped[name].append((layer, ds[variable].assign_coords(layer=layer)))

    # Concatenate, and make sure the dimension order is natural.
    ugrid_dims = {dim for grid in dataset.ugrid.grids for dim in grid.dimensions}
    for variable, layered in grouped.items():
        ordered = [da for _, da in sorted(layered, key=lambda pair: pair[0])]
        da = xr.concat(ordered, dim="layer")
        newdims = list(da.dims)
        newdims.remove("layer")
        # If it's a spatial dataset, the layer should be second last.
        if ugrid_dims.intersection(newdims):
            newdims.insert(-1, "layer")
        # If not, the layer should be last.
        else:
            newdims.append("layer")
        if tuple(newdims) != da.dims:
            da = da.transpose(*newdims)

        restacked[variable] = da

    return xu.UgridDataset(restacked, grids=dataset.ugrid.grids)

361 

362 

def to_ugrid2d(data: Union[xr.DataArray, xr.Dataset]) -> xr.Dataset:
    """
    Convert a structured DataArray or Dataset into its UGRID-2D quadrilateral
    equivalent.

    See:
    https://ugrid-conventions.github.io/ugrid-conventions/#2d-flexible-mesh-mixed-triangles-quadrilaterals-etc-topology

    Parameters
    ----------
    data: Union[xr.DataArray, xr.Dataset]
        Dataset or DataArray with last two dimensions ("y", "x").
        In case of a Dataset, the 2D topology is defined once and variables are
        added one by one.
        In case of a DataArray, a name is required; a name can be set with:
        ``da.name = "..."``

    Returns
    -------
    ugrid2d_dataset: xr.Dataset
        The equivalent data, in UGRID-2D quadrilateral form, made MDAL
        compliant with ``mdal_compliant_ugrid2d``.

    Raises
    ------
    TypeError
        If ``data`` is neither a DataArray nor a Dataset.
    ValueError
        If ``data`` is a DataArray without a name.
    """
    if not isinstance(data, (xr.DataArray, xr.Dataset)):
        raise TypeError("data must be xarray.DataArray or xr.Dataset")
    # Validate before doing the (potentially expensive) grid construction.
    if isinstance(data, xr.DataArray) and data.name is None:
        raise ValueError(
            'A name is required for the DataArray. It can be set with ``da.name = "..."``'
        )

    grid = xu.Ugrid2d.from_structured(data)
    ds = grid.to_dataset()

    if isinstance(data, xr.Dataset):
        for variable in data.data_vars:
            ds[variable] = ugrid2d_data(data[variable], grid.face_dimension)
    else:
        ds[data.name] = ugrid2d_data(data, grid.face_dimension)
    return mdal_compliant_ugrid2d(ds)

401 

402 

def empty_2d(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
) -> xr.DataArray:
    """
    Create an empty 2D (y, x) DataArray.

    ``dx`` and ``dy`` may be provided as:

    * scalar: for equidistant spacing
    * array: for non-equidistant spacing

    The sign of the cell sizes is ignored: x is built with positive and y
    with negative cell sizes. Note that xarray (and netCDF4) uses midpoint
    coordinates. ``xmin`` and ``xmax`` are used to generate the appropriate
    midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    coords = _xycoords((xmin, xmax, ymin, ymax), (abs(dx), -abs(dy)))
    shape = (coords["y"].size, coords["x"].size)
    return xr.DataArray(np.full(shape, np.nan), coords=coords, dims=["y", "x"])

446 

447 

def empty_3d(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    layer: Union[int, Sequence[int], IntArray],
) -> xr.DataArray:
    """
    Create an empty 3D (layer, y, x) DataArray.

    ``dx`` and ``dy`` may be provided as:

    * scalar: for equidistant spacing
    * array: for non-equidistant spacing

    The sign of the cell sizes is ignored: x is built with positive and y
    with negative cell sizes. Note that xarray (and netCDF4) uses midpoint
    coordinates. ``xmin`` and ``xmax`` are used to generate the appropriate
    midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    layer: int, sequence of integers, 1d array of integers

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN, with dimensions ("layer", "y", "x").

    Raises
    ------
    ValueError
        If ``layer`` is not at most 1D.
    """
    bounds = (xmin, xmax, ymin, ymax)
    cellsizes = (abs(dx), -abs(dy))
    coords = _xycoords(bounds, cellsizes)
    nrow = coords["y"].size
    ncol = coords["x"].size
    layer = _layer(layer)
    coords["layer"] = layer

    return xr.DataArray(
        data=np.full((layer.size, nrow, ncol), np.nan),
        coords=coords,
        dims=["layer", "y", "x"],
    )

498 

499 

def empty_2d_transient(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    time: Any,
) -> xr.DataArray:
    """
    Create an empty transient 2D (time, y, x) DataArray.

    ``dx`` and ``dy`` may be provided as:

    * scalar: for equidistant spacing
    * array: for non-equidistant spacing

    The sign of the cell sizes is ignored: x is built with positive and y
    with negative cell sizes. Note that xarray (and netCDF4) uses midpoint
    coordinates. ``xmin`` and ``xmax`` are used to generate the appropriate
    midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    time: Any
        One or more of: str, numpy datetime64, pandas Timestamp

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    coords = _xycoords((xmin, xmax, ymin, ymax), (abs(dx), -abs(dy)))
    time_index = _time(time)
    coords["time"] = time_index
    shape = (time_index.size, coords["y"].size, coords["x"].size)
    return xr.DataArray(
        np.full(shape, np.nan),
        coords=coords,
        dims=["time", "y", "x"],
    )

550 

551 

def empty_3d_transient(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    layer: Union[int, Sequence[int], IntArray],
    time: Any,
) -> xr.DataArray:
    """
    Create an empty transient 3D (time, layer, y, x) DataArray.

    ``dx`` and ``dy`` may be provided as:

    * scalar: for equidistant spacing
    * array: for non-equidistant spacing

    The sign of the cell sizes is ignored: x is built with positive and y
    with negative cell sizes. Note that xarray (and netCDF4) uses midpoint
    coordinates. ``xmin`` and ``xmax`` are used to generate the appropriate
    midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    layer: int, sequence of integers, 1d array of integers
    time: Any
        One or more of: str, numpy datetime64, pandas Timestamp

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    coords = _xycoords((xmin, xmax, ymin, ymax), (abs(dx), -abs(dy)))
    layer_index = _layer(layer)
    coords["layer"] = layer_index
    time_index = _time(time)
    coords["time"] = time_index
    shape = (
        time_index.size,
        layer_index.size,
        coords["y"].size,
        coords["x"].size,
    )
    return xr.DataArray(
        np.full(shape, np.nan),
        coords=coords,
        dims=["time", "layer", "y", "x"],
    )

606 

607 

def _layer(layer: Union[int, Sequence[int], IntArray]) -> IntArray:
    """Return ``layer`` as a 1D array; a scalar becomes a length-1 array."""
    layer_array = np.atleast_1d(layer)
    # atleast_1d guarantees ndim >= 1, so anything other than 1 is too deep.
    if layer_array.ndim != 1:
        raise ValueError("layer must be 1d")
    return layer_array

613 

614 

615def _time(time: Any) -> Any: 

616 time = np.atleast_1d(time) 

617 if time.ndim > 1: 

618 raise ValueError("time must be 1d") 

619 return pd.to_datetime(time) 

620 

621 

def is_divisor(numerator: FloatArray, denominator: float) -> bool:
    """
    Check whether ``denominator`` divides every value of ``numerator``,
    allowing for floating point round-off. Signs are ignored.

    Parameters
    ----------
    numerator: np.array of floats
    denominator: float

    Returns
    -------
    is_divisor: bool
        True if every remainder is close to 0 or close to ``denominator``.
    """
    denominator = abs(denominator)
    remainder = np.abs(numerator) % denominator
    # Floating point modulo can yield a value just below the denominator
    # instead of 0 (e.g. 0.3 % 0.1), so both ends count as "divisible".
    divisible = np.isclose(remainder, 0.0) | np.isclose(remainder, denominator)
    # Return a Python bool (as annotated) rather than a numpy bool_.
    return bool(np.all(divisible))

636 

637 

def _polygonize(da: xr.DataArray) -> gpd.GeoDataFrame:
    """
    Polygonize a 2D-DataArray into a GeoDataFrame of polygons.

    Private method located in util.spatial to work around circular imports.

    Parameters
    ----------
    da: xr.DataArray
        With dimensions exactly ("y", "x") and equidistant coordinates (see
        ``transform``).

    Returns
    -------
    gpd.GeoDataFrame
        Columns "value" and "geometry", one row per polygon returned by
        ``rasterio.features.shapes``; the CRS is taken from
        ``da.attrs["crs"]`` if present.

    Raises
    ------
    ValueError
        If the dimensions of ``da`` are not ("y", "x").
    """

    if da.dims != ("y", "x"):
        raise ValueError('Dimensions must be ("y", "x")')

    values = da.values
    # rasterio.features.shapes supports float32 but not float64 input.
    if values.dtype == np.float64:
        values = values.astype(np.float32)

    affine_transform = transform(da)
    shapes = rasterio.features.shapes(values, transform=affine_transform)

    geometries = []
    colvalues = []
    for geom, colval in shapes:
        # GeoJSON-like polygon: the first ring is the exterior, any remaining
        # rings are holes. Keeping the holes prevents an enclosing polygon
        # from overlapping the polygons inside it.
        exterior, *interiors = geom["coordinates"]
        geometries.append(shapely.geometry.Polygon(exterior, interiors))
        colvalues.append(colval)

    gdf = gpd.GeoDataFrame({"value": colvalues, "geometry": geometries})
    gdf.crs = da.attrs.get("crs")
    return gdf