Coverage for C:\src\imod-python\imod\typing\grid.py: 91%

193 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 11:25 +0200

1import pickle 

2import textwrap 

3from functools import wraps 

4from typing import Callable, Mapping, Sequence 

5 

6import numpy as np 

7import xarray as xr 

8import xugrid as xu 

9from fastcore.dispatch import typedispatch 

10 

11from imod.typing import GridDataArray, GridDataset, structured 

12from imod.util.spatial import _polygonize 

13 

14 

15@typedispatch 

16def zeros_like(grid: xr.DataArray, *args, **kwargs): 

17 return xr.zeros_like(grid, *args, **kwargs) 

18 

19 

20@typedispatch # type: ignore[no-redef] 

21def zeros_like(grid: xu.UgridDataArray, *args, **kwargs): # noqa: F811 

22 return xu.zeros_like(grid, *args, **kwargs) 

23 

24 

25@typedispatch 

26def ones_like(grid: xr.DataArray, *args, **kwargs): 

27 return xr.ones_like(grid, *args, **kwargs) 

28 

29 

30@typedispatch # type: ignore[no-redef] 

31def ones_like(grid: xu.UgridDataArray, *args, **kwargs): # noqa: F811 

32 return xu.ones_like(grid, *args, **kwargs) 

33 

34 

35@typedispatch 

36def nan_like(grid: xr.DataArray, dtype=np.float32, *args, **kwargs): 

37 return xr.full_like(grid, fill_value=np.nan, dtype=dtype, *args, **kwargs) 

38 

39 

40@typedispatch # type: ignore[no-redef] 

41def nan_like(grid: xu.UgridDataArray, dtype=np.float32, *args, **kwargs): # noqa: F811 

42 return xu.full_like(grid, fill_value=np.nan, dtype=dtype, *args, **kwargs) 

43 

44 

45@typedispatch 

46def is_unstructured(grid: xu.UgridDataArray | xu.UgridDataset) -> bool: 

47 return True 

48 

49 

50@typedispatch # type: ignore[no-redef] 

51def is_unstructured(grid: xr.DataArray | xr.Dataset) -> bool: # noqa: F811 

52 return False 

53 

54 

55def _force_decreasing_y(structured_grid: xr.DataArray | xr.Dataset): 

56 flip = slice(None, None, -1) 

57 if structured_grid.indexes["y"].is_monotonic_increasing: 

58 structured_grid = structured_grid.isel(y=flip) 

59 elif not structured_grid.indexes["y"].is_monotonic_decreasing: 

60 raise RuntimeError( 

61 f"Non-monotonous y-coordinates for grid: {structured_grid.name}." 

62 ) 

63 return structured_grid 

64 

65 

66def _get_first_item(objects: Sequence): 

67 return next(iter(objects)) 

68 

69 

70# Typedispatching doesn't work based on types of list elements, therefore resort to 

71# isinstance testing 

72def _type_dispatch_functions_on_grid_sequence( 

73 objects: Sequence[GridDataArray | GridDataset], 

74 unstructured_func: Callable, 

75 structured_func: Callable, 

76 *args, 

77 **kwargs, 

78) -> GridDataArray | GridDataset: 

79 """ 

80 Type dispatch functions on sequence of grids. Functions like merging or concatenating. 

81 """ 

82 first_object = _get_first_item(objects) 

83 start_type = type(first_object) 

84 homogeneous = all(isinstance(o, start_type) for o in objects) 

85 if not homogeneous: 

86 unique_types = {type(o) for o in objects} 

87 raise TypeError( 

88 f"Only homogeneous sequences can be reduced, received sequence of {unique_types}" 

89 ) 

90 if isinstance(first_object, (xu.UgridDataArray, xu.UgridDataset)): 

91 return unstructured_func(objects, *args, **kwargs) 

92 elif isinstance(first_object, (xr.DataArray, xr.Dataset)): 

93 return _force_decreasing_y(structured_func(objects, *args, **kwargs)) 

94 raise TypeError( 

95 f"'{unstructured_func.__name__}' not supported for type {type(objects[0])}" 

96 ) 

97 

98 

99# Typedispatching doesn't work based on types of dict elements, therefore resort 

100# to manual type testing 

101def _type_dispatch_functions_on_dict( 

102 dict_of_objects: Mapping[str, GridDataArray | float | bool | int], 

103 unstructured_func: Callable, 

104 structured_func: Callable, 

105 *args, 

106 **kwargs, 

107): 

108 """ 

109 Typedispatch function on grid and scalar variables provided in dictionary. 

110 Types do not need to be homogeneous as scalars and grids can be mixed. No 

111 mixing of structured and unstructured grids is allowed. Also allows running 

112 function on dictionary with purely scalars, in which case it will call to 

113 the xarray function. 

114 """ 

115 

116 error_msg = textwrap.dedent( 

117 """ 

118 Received both structured grid (xr.DataArray) and xu.UgridDataArray. This 

119 means structured grids as well as unstructured grids were provided. 

120 """ 

121 ) 

122 

123 if dict_of_objects is None: 

124 return xr.Dataset() 

125 

126 types = [type(arg) for arg in dict_of_objects.values()] 

127 has_unstructured = xu.UgridDataArray in types 

128 # Test structured if xr.DataArray and spatial. 

129 has_structured_grid = any( 

130 isinstance(arg, xr.DataArray) and is_spatial_grid(arg) 

131 for arg in dict_of_objects.values() 

132 ) 

133 if has_structured_grid and has_unstructured: 

134 raise TypeError(error_msg) 

135 if has_unstructured: 

136 return unstructured_func([dict_of_objects], *args, **kwargs) 

137 

138 return structured_func([dict_of_objects], *args, **kwargs) 

139 

140 

141def merge( 

142 objects: Sequence[GridDataArray | GridDataset], *args, **kwargs 

143) -> GridDataset: 

144 return _type_dispatch_functions_on_grid_sequence( 

145 objects, xu.merge, xr.merge, *args, **kwargs 

146 ) 

147 

148 

149def merge_partitions( 

150 objects: Sequence[GridDataArray | GridDataset], *args, **kwargs 

151) -> GridDataArray | GridDataset: 

152 return _type_dispatch_functions_on_grid_sequence( 

153 objects, xu.merge_partitions, structured.merge_partitions, *args, **kwargs 

154 ) 

155 

156 

157def concat( 

158 objects: Sequence[GridDataArray | GridDataset], *args, **kwargs 

159) -> GridDataArray | GridDataset: 

160 return _type_dispatch_functions_on_grid_sequence( 

161 objects, xu.concat, xr.concat, *args, **kwargs 

162 ) 

163 

164 

165def merge_unstructured_dataset(variables_to_merge: list[dict], *args, **kwargs): 

166 """ 

167 Work around xugrid issue https://github.com/Deltares/xugrid/issues/179 

168 

169 Expects only one dictionary in list. List is used to have same API as 

170 xr.merge(). 

171 

172 Merges unstructured grids first, then manually assigns scalar variables. 

173 """ 

174 if len(variables_to_merge) > 1: 

175 raise ValueError( 

176 f"Only one dict of variables expected, got {len(variables_to_merge)}" 

177 ) 

178 

179 variables_to_merge_dict = variables_to_merge[0] 

180 

181 if not isinstance(variables_to_merge_dict, dict): 

182 raise TypeError(f"Expected dict, got {type(variables_to_merge_dict)}") 

183 

184 # Separate variables into list of grids and dict of scalar variables 

185 grids_ls = [] 

186 scalar_dict = {} 

187 for name, variable in variables_to_merge_dict.items(): 

188 if isinstance(variable, xu.UgridDataArray): 

189 grids_ls.append(variable.rename(name)) 

190 else: 

191 scalar_dict[name] = variable 

192 

193 # Merge grids 

194 dataset = xu.merge(grids_ls, *args, **kwargs) 

195 

196 # Temporarily work around this xugrid issue, until fixed: 

197 # https://github.com/Deltares/xugrid/issues/206 

198 grid_hashes = [hash(pickle.dumps(grid)) for grid in dataset.ugrid.grids] 

199 unique_grid_hashes = np.unique(grid_hashes) 

200 if unique_grid_hashes.size > 1: 

201 raise ValueError( 

202 "Multiple grids provided, please provide data on one unique grid" 

203 ) 

204 else: 

205 # Possibly won't work anymore if this ever gets implemented: 

206 # https://github.com/Deltares/xugrid/issues/195 

207 dataset._grids = [dataset.grids[0]] 

208 

209 # Assign scalar variables manually 

210 for name, variable in scalar_dict.items(): 

211 dataset[name] = variable 

212 

213 return dataset 

214 

215 

216def merge_with_dictionary( 

217 variables_to_merge: Mapping[str, GridDataArray | float | bool | int], 

218 *args, 

219 **kwargs, 

220): 

221 return _type_dispatch_functions_on_dict( 

222 variables_to_merge, merge_unstructured_dataset, xr.merge, *args, **kwargs 

223 ) 

224 

225 

226@typedispatch 

227def bounding_polygon(active: xr.DataArray): 

228 """Return bounding polygon of active cells""" 

229 to_polygonize = active.where(active, other=np.nan) 

230 polygons_gdf = _polygonize(to_polygonize) 

231 # Filter polygons with inactive values (NaN) 

232 is_active_polygon = polygons_gdf["value"] == 1.0 

233 return polygons_gdf.loc[is_active_polygon] 

234 

235 

236@typedispatch # type: ignore[no-redef] 

237def bounding_polygon(active: xu.UgridDataArray): # noqa: F811 

238 """Return bounding polygon of active cells""" 

239 active_indices = np.where(active > 0)[0] 

240 domain_slice = {f"{active.ugrid.grid.face_dimension}": active_indices} 

241 active_clipped = active.isel(domain_slice, missing_dims="ignore") 

242 

243 return active_clipped.ugrid.grid.bounding_polygon() 

244 

245 

246@typedispatch 

247def is_spatial_grid(array: xr.DataArray | xr.Dataset) -> bool: 

248 """Return True if the array contains data in at least 2 spatial dimensions""" 

249 coords = array.coords 

250 dims = array.dims 

251 has_spatial_coords = "x" in coords and "y" in coords 

252 has_spatial_dims = "x" in dims and "y" in dims 

253 return has_spatial_coords & has_spatial_dims 

254 

255 

256@typedispatch # type: ignore[no-redef] 

257def is_spatial_grid(array: xu.UgridDataArray | xu.UgridDataset) -> bool: # noqa: F811 

258 """Return True if the array contains data associated to cell faces""" 

259 face_dim = array.ugrid.grid.face_dimension 

260 dims = array.dims 

261 coords = array.coords 

262 has_spatial_coords = face_dim in coords 

263 has_spatial_dims = face_dim in dims 

264 return has_spatial_dims & has_spatial_coords 

265 

266 

267@typedispatch # type: ignore[no-redef] 

268def is_spatial_grid(_: object) -> bool: # noqa: F811 

269 return False 

270 

271 

272@typedispatch 

273def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool: 

274 return array1.equals(array2) and array1.ugrid.grid.equals(array2.ugrid.grid) 

275 

276 

277@typedispatch # type: ignore[no-redef] 

278def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool: # noqa: F811 

279 return array1.equals(array2) 

280 

281 

282@typedispatch # type: ignore[no-redef] 

283def is_equal(array1: object, array2: object) -> bool: # noqa: F811 

284 return False 

285 

286 

287@typedispatch 

288def is_same_domain(grid1: xu.UgridDataArray, grid2: xu.UgridDataArray) -> bool: 

289 return grid1.coords.equals(grid2.coords) and grid1.ugrid.grid.equals( 

290 grid2.ugrid.grid 

291 ) 

292 

293 

294@typedispatch # type: ignore[no-redef] 

295def is_same_domain(grid1: xr.DataArray, grid2: xr.DataArray) -> bool: # noqa: F811 

296 return grid1.coords.equals(grid2.coords) 

297 

298 

299@typedispatch # type: ignore[no-redef] 

300def is_same_domain(grid1: object, grid2: object) -> bool: # noqa: F811 

301 return False 

302 

303 

304@typedispatch 

305def get_spatial_dimension_names(grid: xr.DataArray) -> list[str]: 

306 return ["x", "y", "layer", "dx", "dy"] 

307 

308 

309@typedispatch # type: ignore[no-redef] 

310def get_spatial_dimension_names(grid: xu.UgridDataArray) -> list[str]: # noqa: F811 

311 facedim = grid.ugrid.grid.face_dimension 

312 return [facedim, "layer"] 

313 

314 

315@typedispatch # type: ignore[no-redef] 

316def get_spatial_dimension_names(grid: object) -> list[str]: # noqa: F811 

317 return [] 

318 

319 

320@typedispatch 

321def get_grid_geometry_hash(grid: xr.DataArray) -> int: 

322 hash_x = hash(pickle.dumps(grid["x"].values)) 

323 hash_y = hash(pickle.dumps(grid["y"].values)) 

324 return (hash_x, hash_y) 

325 

326 

327@typedispatch # type: ignore[no-redef] 

328def get_grid_geometry_hash(grid: xu.UgridDataArray) -> int: # noqa: F811 

329 hash_x = hash(pickle.dumps(grid.ugrid.grid.node_x)) 

330 hash_y = hash(pickle.dumps(grid.ugrid.grid.node_y)) 

331 hash_connectivity = hash(pickle.dumps(grid.ugrid.grid.node_face_connectivity)) 

332 return (hash_x, hash_y, hash_connectivity) 

333 

334 

335@typedispatch # type: ignore[no-redef] 

336def get_grid_geometry_hash(grid: object) -> int: # noqa: F811 

337 raise ValueError("get_grid_geometry_hash not supported for this object.") 

338 

339 

340@typedispatch 

341def enforce_dim_order(grid: xr.DataArray) -> xr.DataArray: 

342 """Enforce dimension order to iMOD Python standard""" 

343 return grid.transpose("species", "time", "layer", "y", "x", missing_dims="ignore") 

344 

345 

346@typedispatch # type: ignore[no-redef] 

347def enforce_dim_order(grid: xu.UgridDataArray) -> xu.UgridDataArray: # noqa: F811 

348 """Enforce dimension order to iMOD Python standard""" 

349 face_dimension = grid.ugrid.grid.face_dimension 

350 return grid.transpose( 

351 "species", "time", "layer", face_dimension, missing_dims="ignore" 

352 ) 

353 

354 

355def _enforce_unstructured(obj: GridDataArray, ugrid2d=xu.Ugrid2d) -> xu.UgridDataArray: 

356 """Force obj to unstructured""" 

357 return xu.UgridDataArray(xr.DataArray(obj), ugrid2d) 

358 

359 

360def preserve_gridtype(func): 

361 """ 

362 Decorator to preserve gridtype, this is to work around the following xugrid 

363 behavior: 

364 

365 >>> UgridDataArray() * DataArray() -> UgridDataArray 

366 >>> DataArray() * UgridDataArray() -> DataArray 

367 

368 with this decorator: 

369 

370 >>> UgridDataArray() * DataArray() -> UgridDataArray 

371 >>> DataArray() * UgridDataArray() -> UgridDataArray 

372 """ 

373 

374 @wraps(func) 

375 def decorator(*args, **kwargs): 

376 unstructured = False 

377 grid = None 

378 for arg in args: 

379 if is_unstructured(arg): 

380 unstructured = True 

381 grid = arg.ugrid.grid 

382 

383 x = func(*args, **kwargs) 

384 

385 if unstructured: 

386 # Multiple grids returned 

387 if isinstance(x, tuple): 

388 return tuple(_enforce_unstructured(i, grid) for i in x) 

389 return _enforce_unstructured(x, grid) 

390 return x 

391 

392 return decorator