Coverage for C:\src\imod-python\imod\mf6\utilities\regrid.py: 97%

179 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 11:25 +0200

1import abc 

2import copy 

3from collections import defaultdict 

4from typing import Any, Dict, Optional, Tuple, Union 

5 

6import xarray as xr 

7from fastcore.dispatch import typedispatch 

8from xarray.core.utils import is_scalar 

9from xugrid.regrid.regridder import BaseRegridder 

10 

11from imod.mf6.auxiliary_variables import ( 

12 expand_transient_auxiliary_variables, 

13 remove_expanded_auxiliary_variables_from_dataset, 

14) 

15from imod.mf6.interfaces.ilinedatapackage import ILineDataPackage 

16from imod.mf6.interfaces.imodel import IModel 

17from imod.mf6.interfaces.ipackage import IPackage 

18from imod.mf6.interfaces.ipointdatapackage import IPointDataPackage 

19from imod.mf6.interfaces.iregridpackage import IRegridPackage 

20from imod.mf6.interfaces.isimulation import ISimulation 

21from imod.mf6.statusinfo import NestedStatusInfo 

22from imod.mf6.utilities.clip import clip_by_grid 

23from imod.mf6.utilities.regridding_types import RegridderType 

24from imod.schemata import ValidationError 

25from imod.typing.grid import GridDataArray, get_grid_geometry_hash, ones_like 

26 

# Cache key for regridding weights: (source grid geometry hash, target grid
# geometry hash, regridder class). The third element is the regridder
# *class*, not an instance (see ``RegridderWeightsCache.get_regridder``).
HashRegridderMapping = Tuple[int, int, type[BaseRegridder]]

28 

29 

class RegridderWeightsCache:
    """
    Cache of regridding weights that map a source grid onto a target grid.

    Computing regridding weights is a costly affair. By caching them, a
    regridder can be re-created cheaply for different arrays defined on the
    same grids. Weights are stored per regridder type
    (`see these docs<https://deltares.github.io/xugrid/examples/regridder_overview.html>`_)
    and per source/target grid geometry.

    At most ``max_cache_size`` weight sets are kept. When the cache is full,
    the oldest entry (in insertion order) is evicted first. A
    ``max_cache_size`` of zero or less disables caching.
    """

    def __init__(
        self,
        max_cache_size: int = 6,
    ) -> None:
        # NOTE(review): not read anywhere in this module; kept in case
        # external code accesses it.
        self.regridder_instances: dict[
            tuple[type[BaseRegridder], Optional[str]], BaseRegridder
        ] = {}
        # (source hash, target hash, regridder class) -> regridding weights
        self.weights_cache: Dict[HashRegridderMapping, GridDataArray] = {}
        self.max_cache_size = max_cache_size

    def __get_regridder_class(
        self, regridder_type: Union[RegridderType, type[BaseRegridder]]
    ) -> type[BaseRegridder]:
        """
        Resolve ``regridder_type`` to a regridder class.

        Raises
        ------
        ValueError
            if a class is passed that does not derive from BaseRegridder, or
            if the argument is neither a class nor a RegridderType member.
        """
        if isinstance(regridder_type, type):
            if not issubclass(regridder_type, BaseRegridder):
                raise ValueError(
                    "only derived types of BaseRegridder can be instantiated"
                )
            return regridder_type
        if isinstance(regridder_type, RegridderType):
            return regridder_type.value

        raise ValueError("invalid type for regridder")

    def get_regridder(
        self,
        source_grid: GridDataArray,
        target_grid: GridDataArray,
        regridder_type: Union[RegridderType, type[BaseRegridder]],
        method: Optional[str] = None,
    ) -> BaseRegridder:
        """
        returns a regridder of the specified type and with the specified method.
        The desired type can be passed through the argument "regridder_type" as an enumerator or
        as a class.
        The following two are equivalent:
        instancesCollection.get_regridder(RegridderType.OVERLAP, "mean")
        instancesCollection.get_regridder(xu.OverlapRegridder, "mean")

        If weights for this combination of source grid, target grid and
        regridder class are cached, the regridder is rebuilt from those
        weights. Otherwise it is constructed from the grids and its weights
        are cached. The method is not part of the cache key, so cached
        weights are reused for every method of the same regridder class.

        Parameters
        ----------
        source_grid: xr.DataArray or xu.UgridDataArray
            grid the data is regridded from
        target_grid: xr.DataArray or xu.UgridDataArray
            grid the data is regridded to
        regridder_type: RegridderType or regridder class
            indicates the desired regridder type
        method: str or None
            indicates the method the regridder should apply

        Returns
        -------
        a regridder of the specified characteristics
        """
        regridder_class = self.__get_regridder_class(regridder_type)

        # Without a layer coordinate on the source, also ignore the target's
        # layer coordinate, both for hashing and for building the regridder.
        if "layer" not in source_grid.coords and "layer" in target_grid.coords:
            target_grid = target_grid.drop_vars("layer")

        source_hash = get_grid_geometry_hash(source_grid)
        target_hash = get_grid_geometry_hash(target_grid)
        key = (source_hash, target_hash, regridder_class)

        kwargs: Dict[str, Any] = {"target": target_grid}
        if method is not None:
            kwargs["method"] = method

        if key in self.weights_cache:
            return regridder_class.from_weights(
                weights=self.weights_cache[key], **kwargs
            )

        regridder = regridder_class(source=source_grid, **kwargs)
        # Only evict once the new regridder was built successfully, so a
        # failing construction does not throw away a valid cache entry.
        if self.max_cache_size > 0:
            while len(self.weights_cache) >= self.max_cache_size:
                self.remove_first_regridder()
            self.weights_cache[key] = regridder.weights
        return regridder

    def remove_first_regridder(self):
        """Evict the oldest cached weights (insertion order); no-op when empty."""
        if not self.weights_cache:
            return
        oldest_key = next(iter(self.weights_cache))
        del self.weights_cache[oldest_key]

116 

117 

def assign_coord_if_present(
    coordname: str, target_grid: GridDataArray, maybe_has_coords_attr: Any
):
    """
    Copy coordinate ``coordname`` from ``target_grid`` onto
    ``maybe_has_coords_attr`` when possible.

    The coordinate is only assigned when ``coordname`` is a coordinate of
    ``target_grid`` and ``maybe_has_coords_attr`` has a ``coords`` attribute.
    Objects without one (e.g. ``None``) are returned unchanged.

    Parameters
    ----------
    coordname: str
        name of the coordinate to copy, e.g. ``"dx"``
    target_grid: xr.DataArray or xu.UgridDataArray
        grid to take the coordinate from
    maybe_has_coords_attr: Any
        object to assign the coordinate to

    Returns
    -------
    ``maybe_has_coords_attr`` with the coordinate assigned, or the input
    unchanged if the conditions above are not met.
    """
    if coordname in target_grid.coords and hasattr(maybe_has_coords_attr, "coords"):
        # ``values[()]`` yields the raw value: a scalar for a 0-d coordinate,
        # the full array otherwise.
        maybe_has_coords_attr = maybe_has_coords_attr.assign_coords(
            {coordname: target_grid.coords[coordname].values[()]}
        )
    return maybe_has_coords_attr

131 

132 

def _regrid_array(
    package: IRegridPackage,
    varname: str,
    regridder_collection: RegridderWeightsCache,
    regridder_name: str,
    regridder_function: str,
    target_grid: GridDataArray,
) -> Optional[GridDataArray]:
    """
    Regrid variable ``varname`` of ``package.dataset`` onto ``target_grid``.

    The variable may hold:
    - no valid values (e.g. None): ``None`` is returned
    - a single scalar for the whole grid: the scalar is returned as-is
    - an xr.DataArray without x/y coordinates (e.g. one value per layer):
      returned as-is
    - an array with a value per grid cell: regridded

    Raises
    ------
    ValueError
        if an xr.DataArray with x and y coordinates lacks dx or dy.
    """
    source = package.dataset[varname]

    # Nothing to do for data the package considers invalid (such as "None").
    if not package._valid(source.values[()]):
        return None

    # A scalar holds for the whole grid, so it is independent of the grid.
    if is_scalar(source):
        return source.values[()]

    if isinstance(source, xr.DataArray):
        source_coords = source.coords
        # Without planar coordinates (e.g. layer-based data) there is nothing to regrid.
        if "x" not in source_coords or "y" not in source_coords:
            return source
        # Structured regridding needs dx and dy, which are otherwise optional.
        if "dx" not in source_coords or "dy" not in source_coords:
            raise ValueError(
                f"DataArray {varname} does not have both a dx and dy coordinates"
            )

    regridder = regridder_collection.get_regridder(
        source, target_grid, regridder_name, regridder_function
    )
    # Cast back to the source dtype, which regridding may not preserve.
    return regridder.regrid(source).astype(source.dtype)

186 

187 

def _get_unique_regridder_types(
    model: IModel,
) -> defaultdict[RegridderType, list[Optional[str]]]:
    """
    Collect, per regridder type, the unique regridding functions in use by the
    regriddable packages of ``model``.

    Only variables present in a package's ``dataset.data_vars`` are taken into
    account. A regrid method without a function name contributes ``None``.

    Returns
    -------
    defaultdict mapping each regridder type to its function names (or None),
    in order of first appearance.
    """
    methods: defaultdict = defaultdict(list)
    for pkg in model.values():
        if not isinstance(pkg, IRegridPackage):
            continue
        # Query once per package; packages without regrid methods are skipped.
        regrid_methods = pkg.get_regrid_methods()
        if regrid_methods is None:
            continue

        for variable, regrid_method in regrid_methods.items():
            if variable not in pkg.dataset.data_vars:
                continue
            regriddertype = regrid_method[0]
            functiontype = regrid_method[1] if len(regrid_method) > 1 else None
            if functiontype not in methods[regriddertype]:
                methods[regriddertype].append(functiontype)
    return methods

210 

211 

@typedispatch
def _regrid_like(
    package: IRegridPackage,
    target_grid: GridDataArray,
    regrid_context: RegridderWeightsCache,
    regridder_types: Optional[dict[str, tuple[RegridderType, str]]] = None,
) -> IPackage:
    """
    Creates a package of the same type as this package, based on another discretization.
    It regrids all the arrays in this package to the desired discretization, and leaves the options
    unmodified. At the moment only regridding to a different planar grid is supported, meaning
    ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d`` coords.

    The regridding methods can be specified in the _regrid_method attribute of the package. These are the defaults
    that specify how each array should be regridded. These defaults can be overridden using the input
    parameters of this function. Overrides only apply to this call; the package's defaults are not modified.

    Examples
    --------
    To regrid the npf package with a non-default method for the k-field, call regrid_like with these arguments:

    >>> new_npf = npf.regrid_like(like, {"k": (imod.RegridderType.OVERLAP, "mean")})


    Parameters
    ----------
    target_grid: xr.DataArray or xu.UgridDataArray
        a grid defined over the same discretization as the one we want to regrid the package to
    regridder_types: dict(str->(regridder type,str))
       dictionary mapping arraynames (str) to a tuple of regrid type (a specialization class of BaseRegridder) and function name (str)
        this dictionary can be used to override the default mapping method.
    regrid_context: RegridderWeightsCache
        stores regridder weights for different regridders. Can be used to speed up regridding,
        if the same regridders are used several times for regridding different arrays.

    Returns
    -------
    a package with the same options as this package, and with all the data-arrays regridded to another discretization,
    similar to the one used in input argument "target_grid"

    Raises
    ------
    NotImplementedError
        if the package has no ``_regrid_method`` attribute.
    """
    if not hasattr(package, "_regrid_method"):
        raise NotImplementedError(
            f"Package {type(package).__name__} does not support regridding"
        )

    has_auxiliary = hasattr(package, "auxiliary_data_fields")
    if has_auxiliary:
        remove_expanded_auxiliary_variables_from_dataset(package)
    try:
        # Work on a copy: get_regrid_methods() may return the package's own
        # default mapping, which per-call overrides must not overwrite.
        regridder_settings = dict(package.get_regrid_methods())
        if regridder_types is not None:
            regridder_settings.update(regridder_types)

        new_package_data = package.get_non_grid_data(list(regridder_settings.keys()))

        for varname, regridder_type_and_function in regridder_settings.items():
            # skip variables that are not in this dataset
            if varname not in package.dataset.keys():
                continue

            regridder_name = regridder_type_and_function[0]
            regridder_function = (
                regridder_type_and_function[1]
                if len(regridder_type_and_function) > 1
                else None
            )

            regridded = _regrid_array(
                package,
                varname,
                regrid_context,
                regridder_name,
                regridder_function,
                target_grid,
            )
            # set dx and dy if present in target_grid
            regridded = assign_coord_if_present("dx", target_grid, regridded)
            regridded = assign_coord_if_present("dy", target_grid, regridded)
            new_package_data[varname] = regridded
    finally:
        # Restore the input package, also when regridding raised.
        if has_auxiliary:
            expand_transient_auxiliary_variables(package)

    return package.__class__(**new_package_data)

299 

300 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    model: IModel,
    target_grid: GridDataArray,
    validate: bool = True,
    regrid_context: Optional[RegridderWeightsCache] = None,
) -> IModel:
    """
    Build a new model whose packages are regridded versions of the packages in
    ``model``, using each package's default regridding methods. Only regridding
    to a different planar grid is supported: ``target_grid`` differs in its
    ``"x"`` and ``"y"`` or its ``cell2d`` coords.

    Parameters
    ----------
    target_grid: xr.DataArray or xu.UgridDataArray
        grid on the discretization the packages are regridded to
    validate: bool
        if True, validate the regridded model and raise on errors
    regrid_context: Optional RegridderWeightsCache
        cache of regridder weights shared between packages; a new one is
        created when None is given

    Returns
    -------
    the regridded model, masked to the cells that every regridder in use
    considers active, with empty packages purged

    Raises
    ------
    ValueError
        if the model contains a package that prevents regridding
    NotImplementedError
        if a package is not of a regriddable or clippable type
    ValidationError
        if ``validate`` is set and the regridded model fails validation
    """
    supported, error_with_object_name = model.is_regridding_supported()
    if not supported:
        raise ValueError(
            f"regridding this model cannot be done due to the presence of package {error_with_object_name}"
        )

    regridded_model = model.__class__()
    regrid_context = RegridderWeightsCache() if regrid_context is None else regrid_context

    handled_types = (IRegridPackage, ILineDataPackage, IPointDataPackage)
    for pkg_name, pkg in model.items():
        if not isinstance(pkg, handled_types):
            raise NotImplementedError(
                f"regridding is not implemented for package {pkg_name} of type {type(pkg)}"
            )
        regridded_model[pkg_name] = pkg.regrid_like(target_grid, regrid_context)

    # Restrict the result to the cells that every regridder in use keeps active.
    used_methods = _get_unique_regridder_types(model)
    active_domain = _get_regridding_domain(
        model, target_grid, regrid_context, used_methods
    )
    regridded_model.mask_all_packages(active_domain)
    regridded_model.purge_empty_packages()

    if not validate:
        return regridded_model

    status_info = NestedStatusInfo("Model validation status")
    status_info.add(regridded_model.validate("Regridded model"))
    if status_info.has_errors():
        raise ValidationError("\n" + status_info.to_string())
    return regridded_model

355 

356 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    simulation: ISimulation,
    regridded_simulation_name: str,
    target_grid: GridDataArray,
    validate: bool = True,
) -> ISimulation:
    """
    Build a new simulation whose models are regridded versions of the models
    in ``simulation``. Packages that are not regriddable (such as time
    discretization and solver settings) are deep-copied unchanged.

    Parameters
    ----------
    regridded_simulation_name: str
        name given to the output simulation
    target_grid: xr.DataArray or xu.UgridDataArray
        discretization onto which the models in this simulation will be regridded
    validate: bool
        if True, validate the regridded models

    Returns
    -------
    a new simulation object with regridded models

    Raises
    ------
    RuntimeError
        if the simulation has been split
    ValueError
        if the simulation does not have exactly one flow model, or a model
        contains a package that prevents regridding
    NotImplementedError
        if the simulation contains an item that cannot be handled
    """
    if simulation.is_split():
        raise RuntimeError(
            "Unable to regrid simulation. Regridding can only be done on simulations that haven't been split."
            " Therefore regridding should be done before splitting the simulation."
        )
    if not simulation.has_one_flow_model():
        raise ValueError(
            "Unable to regrid simulation. Regridding can only be done on simulations that have a single flow model."
        )

    # One weights cache shared by all models in the simulation.
    regrid_context = RegridderWeightsCache()

    # Check all models up front, before any regridding work is done.
    for model_name, model in simulation.get_models().items():
        supported, error_with_object_name = model.is_regridding_supported()
        if not supported:
            raise ValueError(
                f"Unable to regrid simulation, due to the presence of package '{error_with_object_name}' in model {model_name} "
            )

    result = simulation.__class__(regridded_simulation_name)
    for key, item in simulation.items():
        if isinstance(item, IModel):
            result[key] = item.regrid_like(target_grid, validate, regrid_context)
            continue
        if key == "gwtgwf_exchanges":
            # Not carried over into the regridded simulation.
            continue
        if isinstance(item, IPackage) and not isinstance(item, IRegridPackage):
            result[key] = copy.deepcopy(item)
            continue
        raise NotImplementedError(f"regridding not supported for {key}")

    return result

415 

416 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: ILineDataPackage, target_grid: GridDataArray, *_
) -> ILineDataPackage:
    """
    Line data packages are grid-agnostic, so there is nothing to regrid.
    Instead, the package is clipped to the exterior of ``target_grid``.
    Any further positional arguments are accepted and ignored.
    """
    clipped_package = clip_by_grid(package, target_grid)
    return clipped_package

427 

428 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: IPointDataPackage, target_grid: GridDataArray, *_
) -> IPointDataPackage:
    """
    The regrid_like method is irrelevant for this package as it is
    grid-agnostic, instead this method clips the package based on the grid
    exterior. Any further positional arguments are accepted and ignored.
    """
    # Clip against a 2D grid: take the first layer and drop the layer
    # coordinate; grids without a layer dimension pass through unchanged.
    target_grid_2d = target_grid.isel(layer=0, drop=True, missing_dims="ignore")
    return clip_by_grid(package, target_grid_2d)

440 

441 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(package: object, target_grid: GridDataArray, *_) -> None:
    """
    Fallback for objects without a dedicated regridding implementation.

    Raises
    ------
    TypeError
        always
    """
    message = "this object cannot be regridded"
    raise TypeError(message)

445 

446 

def _get_regridding_domain(
    model: IModel,
    target_grid: GridDataArray,
    regrid_context: RegridderWeightsCache,
    methods: defaultdict[RegridderType, list[str]],
) -> GridDataArray:
    """
    Compute the output idomain of a regridding operation.

    The model's domain is regridded with every regridder in ``methods``.

    Per regridder, cells that come out as NaN become inactive (0), and cells
    with a value <= 0 take over that value. The result therefore only keeps
    active (1) the cells that all regridders consider active. With no
    regridders in use, every cell of ``target_grid`` is active.

    Returns
    -------
    integer idomain on ``target_grid``
    """
    domain = model.domain
    active_everywhere = ones_like(target_grid)

    regridders = [
        regrid_context.get_regridder(domain, target_grid, regridder_type, function_name)
        for regridder_type, function_names in methods.items()
        for function_name in function_names
    ]

    for regridder in regridders:
        regridded = regridder.regrid(domain)
        # NaN where this regridder could not produce a value.
        active_everywhere = active_everywhere.where(regridded.notnull())
        # Inactive / vertical-passthrough values (<= 0) take precedence.
        active_everywhere = regridded.where(regridded <= 0, other=active_everywhere)

    return active_everywhere.where(active_everywhere.notnull(), other=0).astype(int)