Coverage for C:\src\imod-python\imod\mf6\model.py: 91%

280 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 11:25 +0200

1from __future__ import annotations 

2 

3import abc 

4import collections 

5import inspect 

6import pathlib 

7from copy import deepcopy 

8from pathlib import Path 

9from typing import Any, Optional, Tuple, Union 

10 

11import cftime 

12import jinja2 

13import numpy as np 

14import tomli 

15import tomli_w 

16import xarray as xr 

17import xugrid as xu 

18from jinja2 import Template 

19 

20import imod 

21from imod.logging import standard_log_decorator 

22from imod.mf6.interfaces.imodel import IModel 

23from imod.mf6.package import Package 

24from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase 

25from imod.mf6.utilities.mask import _mask_all_packages 

26from imod.mf6.utilities.regrid import RegridderWeightsCache, _regrid_like 

27from imod.mf6.validation import pkg_errors_to_status_info 

28from imod.mf6.write_context import WriteContext 

29from imod.schemata import ValidationError 

30from imod.typing import GridDataArray 

31from imod.typing.grid import is_spatial_grid 

32 

33 

34class Modflow6Model(collections.UserDict, IModel, abc.ABC): 

35 _mandatory_packages: tuple[str, ...] = () 

36 _model_id: Optional[str] = None 

37 _template: Template 

38 

@staticmethod
def _initialize_template(name: str) -> Template:
    """
    Load the jinja2 template called ``name`` from the ``templates/mf6``
    directory of the ``imod`` package.

    Trailing newlines of the template are kept when rendering.
    """
    environment = jinja2.Environment(
        loader=jinja2.PackageLoader("imod", "templates/mf6"),
        keep_trailing_newline=True,
    )
    return environment.get_template(name)

44 

45 def __init__(self, **kwargs): 

46 collections.UserDict.__init__(self) 

47 for k, v in kwargs.items(): 

48 self[k] = v 

49 

50 self._options = {} 

51 

52 def __setitem__(self, key, value): 

53 if len(key) > 16: 

54 raise KeyError( 

55 f"Received key with more than 16 characters: '{key}'" 

56 "Modflow 6 has a character limit of 16." 

57 ) 

58 

59 super().__setitem__(key, value) 

60 

def update(self, *args, **kwargs):
    """
    Add entries like ``dict.update`` does, assigning each one through
    ``self[key] = value``.
    """
    incoming = dict(*args, **kwargs)
    for name in incoming:
        self[name] = incoming[name]

64 

65 def _get_diskey(self): 

66 dis_pkg_ids = ["dis", "disv", "disu"] 

67 

68 diskeys = [ 

69 self._get_pkgkey(pkg_id) 

70 for pkg_id in dis_pkg_ids 

71 if self._get_pkgkey(pkg_id) is not None 

72 ] 

73 

74 if len(diskeys) > 1: 

75 raise ValueError(f"Found multiple discretizations {diskeys}") 

76 elif len(diskeys) == 0: 

77 raise ValueError("No model discretization found") 

78 else: 

79 return diskeys[0] 

80 

81 def _get_pkgkey(self, pkg_id): 

82 """ 

83 Get package key that belongs to a certain pkg_id, since the keys are 

84 user specified. 

85 """ 

86 key = [pkgname for pkgname, pkg in self.items() if pkg._pkg_id == pkg_id] 

87 nkey = len(key) 

88 if nkey > 1: 

89 raise ValueError(f"Multiple instances of {key} detected") 

90 elif nkey == 1: 

91 return key[0] 

92 else: 

93 return None 

94 

95 def _check_for_required_packages(self, modelkey: str) -> None: 

96 # Check for mandatory packages 

97 pkg_ids = {pkg._pkg_id for pkg in self.values()} 

98 dispresent = "dis" in pkg_ids or "disv" in pkg_ids or "disu" in pkg_ids 

99 if not dispresent: 

100 raise ValueError(f"No dis/disv/disu package found in model {modelkey}") 

101 for required in self._mandatory_packages: 

102 if required not in pkg_ids: 

103 raise ValueError(f"No {required} package found in model {modelkey}") 

104 return 

105 

106 def _use_cftime(self): 

107 """ 

108 Also checks if datetime types are homogeneous across packages. 

109 """ 

110 types = [ 

111 type(pkg.dataset["time"].values[0]) 

112 for pkg in self.values() 

113 if "time" in pkg.dataset.coords 

114 ] 

115 set_of_types = set(types) 

116 # Types will be empty if there's no time dependent input 

117 if len(set_of_types) == 0: 

118 return False 

119 else: # there is time dependent input 

120 if not len(set_of_types) == 1: 

121 raise ValueError( 

122 f"Multiple datetime types detected: {set_of_types}" 

123 "Use either cftime or numpy.datetime64[ns]." 

124 ) 

125 # Since we compare types and not instances, we use issubclass 

126 if issubclass(types[0], cftime.datetime): 

127 return True 

128 elif issubclass(types[0], np.datetime64): 

129 return False 

130 else: 

131 raise ValueError("Use either cftime or numpy.datetime64[ns].") 

132 

133 def _yield_times(self): 

134 modeltimes = [] 

135 for pkg in self.values(): 

136 if "time" in pkg.dataset.coords: 

137 modeltimes.append(pkg.dataset["time"].values) 

138 repeat_stress = pkg.dataset.get("repeat_stress") 

139 if repeat_stress is not None and repeat_stress.values[()] is not None: 

140 modeltimes.append(repeat_stress.isel(repeat_items=0).values) 

141 return modeltimes 

142 

def render(self, modelname: str, write_context: WriteContext):
    """
    Render the model template.

    The template receives every option that is neither None nor False, and
    a "packages" list with one ``(pkg_id + "6", posix path, package name)``
    tuple per package. Paths are located in
    ``write_context.root_directory / modelname``.
    """
    render_dir = write_context.root_directory / modelname

    context = {}
    for option, setting in self._options.items():
        if setting is None or setting is False:
            continue
        context[option] = setting

    context["packages"] = [
        (
            f"{pkg._pkg_id}6",  # add the six to the package id
            (render_dir / f"{pkgname}.{pkg._pkg_id}").as_posix(),
            pkgname,
        )
        for pkgname, pkg in self.items()
    ]
    return self._template.render(context)

156 

157 def _model_checks(self, modelkey: str): 

158 """ 

159 Check model integrity (called before writing) 

160 """ 

161 

162 self._check_for_required_packages(modelkey) 

163 

164 def __get_domain_geometry( 

165 self, 

166 ) -> tuple[ 

167 Union[xr.DataArray, xu.UgridDataArray], 

168 Union[xr.DataArray, xu.UgridDataArray], 

169 Union[xr.DataArray, xu.UgridDataArray], 

170 ]: 

171 discretization = self[self._get_diskey()] 

172 if discretization is None: 

173 raise ValueError("Discretization not found") 

174 top = discretization["top"] 

175 bottom = discretization["bottom"] 

176 idomain = discretization["idomain"] 

177 return top, bottom, idomain 

178 

def __get_k(self):
    """
    Return the "k" entry of the model's NodePropertyFlow package.

    The package is located by its pkg id, because package keys are user
    specified and need not equal the pkg id.

    Raises
    ------
    ValidationError
        If the model does not have exactly one NodePropertyFlow package.
    """
    message = "expected one package of type NodePropertyFlow"
    try:
        npf_key = self._get_pkgkey(imod.mf6.NodePropertyFlow._pkg_id)
    except ValueError as e:
        # _get_pkgkey raises ValueError when several packages share the id.
        raise ValidationError(message) from e
    if npf_key is None:
        raise ValidationError(message)

    return self[npf_key]["k"]

187 

@standard_log_decorator()
def validate(self, model_name: str = "") -> StatusInfoBase:
    """
    Validate all packages of the model against their init and write
    schemata.

    Parameters
    ----------
    model_name: str, optional
        Name used in the title of the returned status info.

    Returns
    -------
    StatusInfoBase
        A status info holding the errors per package. When the
        discretization cannot be determined, a status info with only that
        error is returned.
    """
    try:
        diskey = self._get_diskey()
    except Exception as e:
        failure = StatusInfo(f"{model_name} model")
        failure.add_error(str(e))
        return failure

    discretization = self[diskey]
    # The idomain and bottom of the discretization are handed to each
    # package validation.
    idomain = discretization["idomain"]
    bottom = discretization["bottom"]

    model_status_info = NestedStatusInfo(f"{model_name} model")
    for pkg_name, pkg in self.items():
        # Packages stored under the name "adv" are skipped.
        if pkg_name == "adv":
            continue

        # Check all schemata, as types and dimensions may have been changed
        # after initialization. Work on a copy of the init schemata and
        # merge the write schemata into it.
        schemata = deepcopy(pkg._init_schemata)
        for key, value in pkg._write_schemata.items():
            if key in schemata.keys():
                schemata[key] += value
            else:
                schemata[key] = value

        pkg_errors = pkg._validate(schemata=schemata, idomain=idomain, bottom=bottom)
        if len(pkg_errors) > 0:
            model_status_info.add(pkg_errors_to_status_info(pkg_name, pkg_errors))

    return model_status_info

227 

@standard_log_decorator()
def write(
    self, modelname, globaltimes, validate: bool, write_context: WriteContext
) -> StatusInfoBase:
    """
    Write the model namefile and the files of every package.

    Parameters
    ----------
    modelname: str
        Name of the model; used for the model directory and the namefile.
    globaltimes:
        Passed on to the ``write`` method of every package.
    validate: bool
        Whether to validate the model first. When validation finds errors,
        nothing is written and the validation status is returned.
    write_context: WriteContext
        Provides the simulation directory; packages are written with a copy
        that points to the model directory.

    Returns
    -------
    StatusInfoBase
        The validation status when it has errors, otherwise an empty
        NestedStatusInfo.

    Raises
    ------
    Exception
        Any error raised while writing a package is re-raised with the
        package name appended to the message. The original exception type
        is kept when it can be built from a single message; otherwise a
        RuntimeError is raised. The original error is chained as the cause.
    """
    workdir = write_context.simulation_directory
    modeldirectory = workdir / modelname
    Path(modeldirectory).mkdir(exist_ok=True, parents=True)
    if validate:
        model_status_info = self.validate(modelname)
        if model_status_info.has_errors():
            return model_status_info

    # write model namefile
    namefile_content = self.render(modelname, write_context)
    namefile_path = modeldirectory / f"{modelname}.nam"
    with open(namefile_path, "w") as f:
        f.write(namefile_content)

    # write package contents
    pkg_write_context = write_context.copy_with_new_write_directory(
        new_write_directory=modeldirectory
    )
    for pkg_name, pkg in self.items():
        try:
            # Wells and horizontal flow barriers are first converted to a
            # package that can be written, using the model geometry and k.
            if isinstance(pkg, imod.mf6.Well):
                top, bottom, idomain = self.__get_domain_geometry()
                k = self.__get_k()
                writable_pkg = pkg.to_mf6_pkg(
                    idomain,
                    top,
                    bottom,
                    k,
                    validate,
                    pkg_write_context.is_partitioned,
                )
            elif isinstance(pkg, imod.mf6.HorizontalFlowBarrierBase):
                top, bottom, idomain = self.__get_domain_geometry()
                k = self.__get_k()
                writable_pkg = pkg.to_mf6_pkg(idomain, top, bottom, k, validate)
            else:
                writable_pkg = pkg

            writable_pkg.write(
                pkgname=pkg_name,
                globaltimes=globaltimes,
                write_context=pkg_write_context,
            )
        except Exception as e:
            message = f"{e}\nError occured while writing {pkg_name}"
            try:
                wrapped = type(e)(message)
            except Exception:
                # Some exception types cannot be built from one string;
                # do not let that mask the actual error.
                wrapped = RuntimeError(message)
            raise wrapped from e

    return NestedStatusInfo(modelname)

293 

@standard_log_decorator()
def dump(
    self,
    directory,
    modelname,
    validate: bool = True,
    mdal_compliant: bool = False,
    crs: Optional[Any] = None,
):
    """
    Dump model to files. Writes a model definition as .TOML file, which
    points to data for each package. Each package is stored as a separate
    NetCDF. Structured grids are saved as regular NetCDFs, unstructured
    grids are saved as UGRID NetCDF. Spatial structured grids are made GDAL
    compliant, unstructured grids can be made MDAL compliant optionally.

    Parameters
    ----------
    directory: str or Path
        directory to dump model into.
    modelname: str
        modelname, will be used to create a subdirectory and to name the
        TOML file.
    validate: bool, optional
        Whether to validate model data. Defaults to True.
    mdal_compliant: bool, optional
        Convert data with
        :func:`imod.util.spatial.mdal_compliant_ugrid2d` to MDAL
        compliant unstructured grids. Defaults to False.
    crs: Any, optional
        Anything accepted by rasterio.crs.CRS.from_user_input
        Requires ``rioxarray`` installed.

    Returns
    -------
    pathlib.Path
        Path of the written TOML file.

    Raises
    ------
    ValidationError
        If ``validate`` is True and validation finds errors.
    """
    modeldirectory = pathlib.Path(directory) / modelname
    modeldirectory.mkdir(exist_ok=True, parents=True)
    if validate:
        # Pass the model name so validation errors are reported under it.
        statusinfo = self.validate(modelname)
        if statusinfo.has_errors():
            raise ValidationError(statusinfo.to_string())

    # Maps package class name -> {package name: NetCDF file name}.
    toml_content: dict = collections.defaultdict(dict)
    for pkgname, pkg in self.items():
        pkg_path = f"{pkgname}.nc"
        toml_content[type(pkg).__name__][pkgname] = pkg_path
        dataset = pkg.dataset
        if isinstance(dataset, xu.UgridDataset):
            if mdal_compliant:
                dataset = dataset.ugrid.to_dataset()
                mdal_dataset = imod.util.spatial.mdal_compliant_ugrid2d(
                    dataset, crs=crs
                )
                mdal_dataset.to_netcdf(modeldirectory / pkg_path)
            else:
                dataset.ugrid.to_netcdf(modeldirectory / pkg_path)
        else:
            if is_spatial_grid(dataset):
                dataset = imod.util.spatial.gdal_compliant_grid(dataset, crs=crs)
            dataset.to_netcdf(modeldirectory / pkg_path)

    toml_path = modeldirectory / f"{modelname}.toml"
    with open(toml_path, "wb") as f:
        tomli_w.dump(toml_content, f)

    return toml_path

357 

@classmethod
def from_file(cls, toml_path):
    """
    Create a model from a TOML file.

    The TOML content maps package class names to ``{package name: path}``
    tables. Class names are resolved among the ``Package`` subclasses found
    in ``imod.mf6``; paths are taken relative to the directory of the TOML
    file and loaded with the ``from_file`` of the package class.
    """
    known_classes = {}
    for class_name, candidate in inspect.getmembers(imod.mf6, inspect.isclass):
        if issubclass(candidate, Package):
            known_classes[class_name] = candidate

    toml_path = pathlib.Path(toml_path)
    with open(toml_path, "rb") as f:
        toml_content = tomli.load(f)

    base_dir = toml_path.parent
    model = cls()
    for class_name, entries in toml_content.items():
        for pkgname, relative_path in entries.items():
            model[pkgname] = known_classes[class_name].from_file(
                base_dir / relative_path
            )

    return model

378 

@property
def options(self) -> dict:
    """
    The model options.

    Raises
    ------
    ValueError
        If the options are None.
    """
    if self._options is None:
        # Name the options here; the model id has its own property.
        raise ValueError("Model options have not been set")
    return self._options

384 

@property
def model_id(self) -> str:
    """
    The model id.

    Raises
    ------
    ValueError
        If the model id is None.
    """
    identifier = self._model_id
    if identifier is not None:
        return identifier
    raise ValueError("Model id has not been set")

390 

def clip_box(
    self,
    time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
    time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
    layer_min: Optional[int] = None,
    layer_max: Optional[int] = None,
    x_min: Optional[float] = None,
    x_max: Optional[float] = None,
    y_min: Optional[float] = None,
    y_max: Optional[float] = None,
    state_for_boundary: Optional[GridDataArray] = None,
):
    """
    Clip a model by a bounding box (time, layer, y, x).

    Slicing intervals may be half-bounded, by providing None:

    * To select 500.0 <= x <= 1000.0:
      ``clip_box(x_min=500.0, x_max=1000.0)``.
    * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
      or ``clip_box(x_max=1000.0)``.
    * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
      or ``clip_box(x_min=500.0)``.

    Parameters
    ----------
    time_min: optional
    time_max: optional
    layer_min: optional, int
    layer_max: optional, int
    x_min: optional, float
    x_max: optional, float
    y_min: optional, float
    y_max: optional, float
    state_for_boundary: optional, GridDataArray
        Accepted, but not used by this implementation.

    Returns
    -------
    The clipped model, as returned by ``_clip_box_packages``.

    Raises
    ------
    ValueError
        If one of the packages in the model does not support clipping.
    """
    supported, offending_package = self.is_clipping_supported()
    if not supported:
        raise ValueError(
            f"model cannot be clipped due to presence of package '{offending_package}' in model"
        )

    return self._clip_box_packages(
        time_min,
        time_max,
        layer_min,
        layer_max,
        x_min,
        x_max,
        y_min,
        y_max,
    )

445 

446 def _clip_box_packages( 

447 self, 

448 time_min: Optional[cftime.datetime | np.datetime64 | str] = None, 

449 time_max: Optional[cftime.datetime | np.datetime64 | str] = None, 

450 layer_min: Optional[int] = None, 

451 layer_max: Optional[int] = None, 

452 x_min: Optional[float] = None, 

453 x_max: Optional[float] = None, 

454 y_min: Optional[float] = None, 

455 y_max: Optional[float] = None, 

456 ): 

457 """ 

458 Clip a model by a bounding box (time, layer, y, x). 

459 

460 Slicing intervals may be half-bounded, by providing None: 

461 

462 * To select 500.0 <= x <= 1000.0: 

463 ``clip_box(x_min=500.0, x_max=1000.0)``. 

464 * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)`` 

465 or ``clip_box(x_max=1000.0)``. 

466 * To select x >= 500.0: ``clip_box(x_min = 500.0, x_max=None.0)`` 

467 or ``clip_box(x_min=1000.0)``. 

468 

469 Parameters 

470 ---------- 

471 time_min: optional 

472 time_max: optional 

473 layer_min: optional, int 

474 layer_max: optional, int 

475 x_min: optional, float 

476 x_max: optional, float 

477 y_min: optional, float 

478 y_max: optional, float 

479 

480 Returns 

481 ------- 

482 clipped : Modflow6Model 

483 """ 

484 

485 top, bottom, idomain = self.__get_domain_geometry() 

486 

487 clipped = type(self)(**self._options) 

488 for key, pkg in self.items(): 

489 clipped[key] = pkg.clip_box( 

490 time_min=time_min, 

491 time_max=time_max, 

492 layer_min=layer_min, 

493 layer_max=layer_max, 

494 x_min=x_min, 

495 x_max=x_max, 

496 y_min=y_min, 

497 y_max=y_max, 

498 top=top, 

499 bottom=bottom, 

500 ) 

501 

502 return clipped 

503 

def regrid_like(
    self,
    target_grid: GridDataArray,
    validate: bool = True,
    regrid_context: Optional[RegridderWeightsCache] = None,
) -> "Modflow6Model":
    """
    Create a model by regridding the packages of this model to another
    discretization, using the default regridding methods for all arrays.
    At the moment only regridding to a different planar grid is supported,
    meaning ``target_grid`` has different ``"x"`` and ``"y"`` or different
    ``cell2d`` coords.

    Parameters
    ----------
    target_grid: xr.DataArray or xu.UgridDataArray
        a grid defined over the same discretization as the one we want to
        regrid the packages to
    validate: bool
        set to true to validate the regridded packages
    regrid_context: Optional RegridderWeightsCache
        stores regridder weights for different regridders. Can be used to
        speed up regridding, if the same regridders are used several times
        for regridding different arrays.

    Returns
    -------
    a model with similar packages to the input model, and with all the
    data-arrays regridded to another discretization, similar to the one
    used in input argument "target_grid"
    """
    regridded_model = _regrid_like(self, target_grid, validate, regrid_context)
    return regridded_model

532 

def mask_all_packages(
    self,
    mask: GridDataArray,
):
    """
    Apply a mask to all packages in the model.

    The mask must be an idomain-like integer array that has 0 (inactive) or
    -1 (vertical passthrough) values in filtered cells and 1 in active
    cells. Where the mask is 0 or -1, idomain is overwritten with the mask;
    where the mask is 1, the original value of idomain is kept. The
    packages are updated accordingly, blanking their input where needed, so
    masking is not a reversible operation.

    Parameters
    ----------
    mask: xr.DataArray, xu.UgridDataArray of ints
        idomain-like integer array. 1 sets cells to active, 0 sets cells to
        inactive, -1 sets cells to vertical passthrough
    """
    _mask_all_packages(self, mask)

555 

def purge_empty_packages(self, model_name: Optional[str] = "") -> None:
    """
    Remove every package for which ``is_empty()`` is true from the model.

    ``model_name`` is accepted but not used here.
    """
    # Collect the names first; the model cannot change size while iterating.
    names_to_remove = []
    for name, pkg in self.items():
        if pkg.is_empty():
            names_to_remove.append(name)

    for name in names_to_remove:
        self.pop(name)

565 

@property
def domain(self):
    """The "idomain" entry of the model's discretization package."""
    discretization = self[self._get_diskey()]
    return discretization["idomain"]

570 

@property
def bottom(self):
    """The "bottom" entry of the model's discretization package."""
    discretization = self[self._get_diskey()]
    return discretization["bottom"]

575 

576 def __repr__(self) -> str: 

577 INDENT = " " 

578 typename = type(self).__name__ 

579 options = [ 

580 f"{INDENT}{key}={repr(value)}," for key, value in self._options.items() 

581 ] 

582 packages = [ 

583 f"{INDENT}{repr(key)}: {type(value).__name__}," 

584 for key, value in self.items() 

585 ] 

586 # Place the emtpy dict on the same line. Looks silly otherwise. 

587 if packages: 

588 content = [f"{typename}("] + options + ["){"] + packages + ["}"] 

589 else: 

590 content = [f"{typename}("] + options + ["){}"] 

591 return "\n".join(content) 

592 

def is_use_newton(self):
    """Return False: this implementation never reports use of Newton."""
    uses_newton = False
    return uses_newton

595 

def is_splitting_supported(self) -> Tuple[bool, str]:
    """
    Tell whether every package in the model supports splitting.

    Returns ``(True, "")`` when all packages support it, otherwise
    ``(False, name)`` with the name of the first package that does not.
    """
    unsupported = next(
        (
            name
            for name, pkg in self.items()
            if not pkg.is_splitting_supported()
        ),
        None,
    )
    if unsupported is None:
        return True, ""
    return False, unsupported

606 

def is_regridding_supported(self) -> Tuple[bool, str]:
    """
    Tell whether every package in the model supports regridding.

    Returns ``(True, "")`` when all packages support it, otherwise
    ``(False, name)`` with the name of the first package that does not.
    """
    unsupported = next(
        (
            name
            for name, pkg in self.items()
            if not pkg.is_regridding_supported()
        ),
        None,
    )
    if unsupported is None:
        return True, ""
    return False, unsupported

617 

def is_clipping_supported(self) -> Tuple[bool, str]:
    """
    Tell whether every package in the model supports clipping.

    Returns ``(True, "")`` when all packages support it, otherwise
    ``(False, name)`` with the name of the first package that does not.
    """
    unsupported = next(
        (
            name
            for name, pkg in self.items()
            if not pkg.is_clipping_supported()
        ),
        None,
    )
    if unsupported is None:
        return True, ""
    return False, unsupported