Coverage for C:\src\imod-python\imod\mf6\model.py: 88%

272 statements  

coverage.py v7.4.4, created at 2024-04-08 13:27 +0200

from __future__ import annotations

import abc
import collections
import inspect
import pathlib
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union

import cftime
import jinja2
import numpy as np
import tomli
import tomli_w
import xarray as xr
import xugrid as xu
from jinja2 import Template

import imod
from imod.logging import standard_log_decorator
from imod.mf6.interfaces.imodel import IModel
from imod.mf6.package import Package
from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
from imod.mf6.utilities.mask import _mask_all_packages
from imod.mf6.utilities.regrid import RegridderWeightsCache, _regrid_like
from imod.mf6.validation import pkg_errors_to_status_info
from imod.mf6.write_context import WriteContext
from imod.schemata import ValidationError
from imod.typing import GridDataArray

class Modflow6Model(collections.UserDict, IModel, abc.ABC):
    _mandatory_packages: tuple[str, ...] = ()
    _model_id: Optional[str] = None
    _template: Template

    @staticmethod
    def _initialize_template(name: str) -> Template:
        loader = jinja2.PackageLoader("imod", "templates/mf6")
        env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
        return env.get_template(name)

    def __init__(self, **kwargs):
        collections.UserDict.__init__(self)
        for k, v in kwargs.items():
            self[k] = v

        self._options = {}

    def __setitem__(self, key, value):
        if len(key) > 16:
            raise KeyError(
                f"Received key with more than 16 characters: '{key}'. "
                "Modflow 6 has a character limit of 16."
            )

        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def _get_diskey(self):
        dis_pkg_ids = ["dis", "disv", "disu"]

        diskeys = [
            self._get_pkgkey(pkg_id)
            for pkg_id in dis_pkg_ids
            if self._get_pkgkey(pkg_id) is not None
        ]

        if len(diskeys) > 1:
            raise ValueError(f"Found multiple discretizations {diskeys}")
        elif len(diskeys) == 0:
            raise ValueError("No model discretization found")
        else:
            return diskeys[0]

    def _get_pkgkey(self, pkg_id):
        """
        Get package key that belongs to a certain pkg_id, since the keys are
        user specified.
        """
        key = [pkgname for pkgname, pkg in self.items() if pkg._pkg_id == pkg_id]
        nkey = len(key)
        if nkey > 1:
            raise ValueError(f"Multiple instances of {key} detected")
        elif nkey == 1:
            return key[0]
        else:
            return None

    def _check_for_required_packages(self, modelkey: str) -> None:
        # Check for mandatory packages
        pkg_ids = set([pkg._pkg_id for pkg in self.values()])
        dispresent = "dis" in pkg_ids or "disv" in pkg_ids or "disu" in pkg_ids
        if not dispresent:
            raise ValueError(f"No dis/disv/disu package found in model {modelkey}")
        for required in self._mandatory_packages:
            if required not in pkg_ids:
                raise ValueError(f"No {required} package found in model {modelkey}")
        return

    def _use_cftime(self):
        """
        Determine whether time dependent input uses cftime datetimes.
        Also checks if datetime types are homogeneous across packages.
        """
        types = [
            type(pkg.dataset["time"].values[0])
            for pkg in self.values()
            if "time" in pkg.dataset.coords
        ]
        set_of_types = set(types)
        # Types will be empty if there's no time dependent input
        if len(set_of_types) == 0:
            return False
        else:  # there is time dependent input
            if not len(set_of_types) == 1:
                raise ValueError(
                    f"Multiple datetime types detected: {set_of_types}. "
                    "Use either cftime or numpy.datetime64[ns]."
                )
            # Since we compare types and not instances, we use issubclass
            if issubclass(types[0], cftime.datetime):
                return True
            elif issubclass(types[0], np.datetime64):
                return False
            else:
                raise ValueError("Use either cftime or numpy.datetime64[ns].")

    def _yield_times(self):
        modeltimes = []
        for pkg in self.values():
            if "time" in pkg.dataset.coords:
                modeltimes.append(pkg.dataset["time"].values)
            repeat_stress = pkg.dataset.get("repeat_stress")
            if repeat_stress is not None and repeat_stress.values[()] is not None:
                modeltimes.append(repeat_stress.isel(repeat_items=0).values)
        return modeltimes

    def render(self, modelname: str, write_context: WriteContext):
        dir_for_render = write_context.root_directory / modelname

        d = {k: v for k, v in self._options.items() if not (v is None or v is False)}
        packages = []
        for pkgname, pkg in self.items():
            # Add the "6" suffix to the package id
            pkg_id = pkg._pkg_id
            key = f"{pkg_id}6"
            path = dir_for_render / f"{pkgname}.{pkg_id}"
            packages.append((key, path.as_posix(), pkgname))
        d["packages"] = packages
        return self._template.render(d)

    def _model_checks(self, modelkey: str):
        """
        Check model integrity (called before writing)
        """

        self._check_for_required_packages(modelkey)

    def __get_domain_geometry(
        self,
    ) -> tuple[
        Union[xr.DataArray, xu.UgridDataArray],
        Union[xr.DataArray, xu.UgridDataArray],
        Union[xr.DataArray, xu.UgridDataArray],
    ]:
        discretization = self[self._get_diskey()]
        if discretization is None:
            raise ValueError("Discretization not found")
        top = discretization["top"]
        bottom = discretization["bottom"]
        idomain = discretization["idomain"]
        return top, bottom, idomain

    def __get_k(self):
        try:
            npf = self[imod.mf6.NodePropertyFlow._pkg_id]
        except KeyError:
            raise ValidationError("expected one package of type NodePropertyFlow")

        k = npf["k"]
        return k

    @standard_log_decorator()
    def validate(self, model_name: str = "") -> StatusInfoBase:
        try:
            diskey = self._get_diskey()
        except Exception as e:
            status_info = StatusInfo(f"{model_name} model")
            status_info.add_error(str(e))
            return status_info

        dis = self[diskey]
        # We'll use the idomain for checking dims, shape, nodata.
        idomain = dis["idomain"]
        bottom = dis["bottom"]

        model_status_info = NestedStatusInfo(f"{model_name} model")
        for pkg_name, pkg in self.items():
            # Check for all schemata when writing. Types and dimensions
            # may have been changed after initialization...

            if pkg_name in ["adv"]:
                continue  # some packages can be skipped

            # Concatenate write and init schemata.
            schemata = deepcopy(pkg._init_schemata)
            for key, value in pkg._write_schemata.items():
                if key not in schemata.keys():
                    schemata[key] = value
                else:
                    schemata[key] += value

            pkg_errors = pkg._validate(
                schemata=schemata,
                idomain=idomain,
                bottom=bottom,
            )
            if len(pkg_errors) > 0:
                model_status_info.add(pkg_errors_to_status_info(pkg_name, pkg_errors))

        return model_status_info

    @standard_log_decorator()
    def write(
        self, modelname, globaltimes, validate: bool, write_context: WriteContext
    ) -> StatusInfoBase:
        """
        Write the model namefile and the package input files.
        """

        workdir = write_context.simulation_directory
        modeldirectory = workdir / modelname
        Path(modeldirectory).mkdir(exist_ok=True, parents=True)
        if validate:
            model_status_info = self.validate(modelname)
            if model_status_info.has_errors():
                return model_status_info

        # write model namefile
        namefile_content = self.render(modelname, write_context)
        namefile_path = modeldirectory / f"{modelname}.nam"
        with open(namefile_path, "w") as f:
            f.write(namefile_content)

        # write package contents
        pkg_write_context = write_context.copy_with_new_write_directory(
            new_write_directory=modeldirectory
        )
        for pkg_name, pkg in self.items():
            try:
                if isinstance(pkg, imod.mf6.Well):
                    top, bottom, idomain = self.__get_domain_geometry()
                    k = self.__get_k()
                    mf6_well_pkg = pkg.to_mf6_pkg(
                        idomain,
                        top,
                        bottom,
                        k,
                        validate,
                        pkg_write_context.is_partitioned,
                    )

                    mf6_well_pkg.write(
                        pkgname=pkg_name,
                        globaltimes=globaltimes,
                        write_context=pkg_write_context,
                    )
                elif isinstance(pkg, imod.mf6.HorizontalFlowBarrierBase):
                    top, bottom, idomain = self.__get_domain_geometry()
                    k = self.__get_k()
                    mf6_hfb_pkg = pkg.to_mf6_pkg(idomain, top, bottom, k, validate)
                    mf6_hfb_pkg.write(
                        pkgname=pkg_name,
                        globaltimes=globaltimes,
                        write_context=pkg_write_context,
                    )
                else:
                    pkg.write(
                        pkgname=pkg_name,
                        globaltimes=globaltimes,
                        write_context=pkg_write_context,
                    )
            except Exception as e:
                raise type(e)(f"{e}\nError occurred while writing {pkg_name}")

        return NestedStatusInfo(modelname)

    @standard_log_decorator()
    def dump(
        self, directory, modelname, validate: bool = True, mdal_compliant: bool = False
    ):
        modeldirectory = pathlib.Path(directory) / modelname
        modeldirectory.mkdir(exist_ok=True, parents=True)
        if validate:
            statusinfo = self.validate()
            if statusinfo.has_errors():
                raise ValidationError(statusinfo.to_string())

        # Write one NetCDF file per package and register it in the TOML content.
        toml_content: dict = collections.defaultdict(dict)
        for pkgname, pkg in self.items():
            pkg_path = f"{pkgname}.nc"
            toml_content[type(pkg).__name__][pkgname] = pkg_path
            dataset = pkg.dataset
            if isinstance(dataset, xu.UgridDataset):
                if mdal_compliant:
                    dataset = pkg.dataset.ugrid.to_dataset()
                    mdal_dataset = imod.util.spatial.mdal_compliant_ugrid2d(dataset)
                    mdal_dataset.to_netcdf(modeldirectory / pkg_path)
                else:
                    pkg.dataset.ugrid.to_netcdf(modeldirectory / pkg_path)
            else:
                pkg.to_netcdf(modeldirectory / pkg_path)

        toml_path = modeldirectory / f"{modelname}.toml"
        with open(toml_path, "wb") as f:
            tomli_w.dump(toml_content, f)

        return toml_path

    @classmethod
    def from_file(cls, toml_path):
        # Collect the public package classes by name, so TOML section headers
        # can be mapped back to their classes.
        pkg_classes = {
            name: pkg_cls
            for name, pkg_cls in inspect.getmembers(imod.mf6, inspect.isclass)
            if issubclass(pkg_cls, Package)
        }

        toml_path = pathlib.Path(toml_path)
        with open(toml_path, "rb") as f:
            toml_content = tomli.load(f)

        parentdir = toml_path.parent
        instance = cls()
        for key, entry in toml_content.items():
            for pkgname, path in entry.items():
                pkg_cls = pkg_classes[key]
                instance[pkgname] = pkg_cls.from_file(parentdir / path)

        return instance

    @classmethod
    def model_id(cls) -> str:
        if cls._model_id is None:
            raise ValueError("Model id has not been set")
        return cls._model_id

    def clip_box(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        state_for_boundary: Optional[GridDataArray] = None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        state_for_boundary :
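
        Examples
        --------
        A minimal sketch; ``gwf_model`` is assumed to be an existing
        Modflow6Model with a structured discretization:

        >>> clipped = gwf_model.clip_box(
        ...     x_min=500.0, x_max=1000.0, layer_min=1, layer_max=3
        ... )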

387 """ 

388 supported, error_with_object = self.is_clipping_supported() 

389 if not supported: 

390 raise ValueError( 

391 f"model cannot be clipped due to presence of package '{error_with_object}' in model" 

392 ) 

393 

394 clipped = self._clip_box_packages( 

395 time_min, 

396 time_max, 

397 layer_min, 

398 layer_max, 

399 x_min, 

400 x_max, 

401 y_min, 

402 y_max, 

403 state_for_boundary, 

404 ) 

405 

406 return clipped 

407 

    def _clip_box_packages(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        state_for_boundary: Optional[GridDataArray] = None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float

        Returns
        -------
        clipped : Modflow6Model
        """

        top, bottom, idomain = self.__get_domain_geometry()

        clipped = type(self)(**self._options)
        for key, pkg in self.items():
            clipped[key] = pkg.clip_box(
                time_min=time_min,
                time_max=time_max,
                layer_min=layer_min,
                layer_max=layer_max,
                x_min=x_min,
                x_max=x_max,
                y_min=y_min,
                y_max=y_max,
                top=top,
                bottom=bottom,
                state_for_boundary=state_for_boundary,
            )

        return clipped

    def regrid_like(
        self,
        target_grid: GridDataArray,
        validate: bool = True,
        regrid_context: Optional[RegridderWeightsCache] = None,
    ) -> "Modflow6Model":
        """
        Creates a model by regridding the packages of this model to another
        discretization. It regrids all the arrays in the packages using the
        default regridding methods. At the moment only regridding to a different
        planar grid is supported, meaning ``target_grid`` has different ``"x"``
        and ``"y"`` or different ``cell2d`` coords.

        Parameters
        ----------
        target_grid: xr.DataArray or xu.UgridDataArray
            a grid defined over the discretization to which the packages should be regridded
        validate: bool
            set to true to validate the regridded packages
        regrid_context: RegridderWeightsCache, optional
            stores regridder weights for different regridders. Can be used to speed up
            regridding, if the same regridders are used several times for regridding
            different arrays.

        Returns
        -------
        a model with the same packages as the input model, with all data-arrays
        regridded to the discretization of ``target_grid``
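
        Examples
        --------
        A minimal sketch; ``gwf_model`` and ``target_grid`` (an ``xr.DataArray``
        or ``xu.UgridDataArray`` defining the new planar grid) are assumed to
        exist already:

        >>> regridded_model = gwf_model.regrid_like(target_grid)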

494 """ 

495 return _regrid_like(self, target_grid, validate, regrid_context) 

496 

    def mask_all_packages(
        self,
        mask: GridDataArray,
    ):
        """
        This function applies a mask to all packages in a model. The mask must
        be presented as an idomain-like integer array that has 0 (inactive) or
        -1 (vertical passthrough) values in filtered cells and 1 in active
        cells.
        Masking will overwrite idomain with the mask where the mask is 0 or -1.
        Where the mask is 1, the original value of idomain will be kept. Masking
        will update the packages accordingly, blanking their input where needed,
        and is therefore not a reversible operation.

        Parameters
        ----------
        mask: xr.DataArray, xu.UgridDataArray of ints
            idomain-like integer array. 1 sets cells to active, 0 sets cells to inactive,
            -1 sets cells to vertical passthrough
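
        Examples
        --------
        A minimal sketch; ``gwf_model`` is assumed to be an existing
        Modflow6Model with a structured discretization, and the top layer is
        deactivated purely for illustration:

        >>> mask = gwf_model.domain.copy()
        >>> mask.loc[{"layer": 1}] = 0
        >>> gwf_model.mask_all_packages(mask)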

516 """ 

517 

518 _mask_all_packages(self, mask) 

519 

520 def purge_empty_packages(self, model_name: Optional[str] = "") -> None: 

521 """ 

522 This function removes empty packages from the model. 

523 """ 

524 empty_packages = [ 

525 package_name for package_name, package in self.items() if package.is_empty() 

526 ] 

527 for package_name in empty_packages: 

528 self.pop(package_name) 

529 

    @property
    def domain(self):
        dis = self._get_diskey()
        return self[dis]["idomain"]

    @property
    def bottom(self):
        dis = self._get_diskey()
        return self[dis]["bottom"]

    def __repr__(self) -> str:
        INDENT = "    "
        typename = type(self).__name__
        options = [
            f"{INDENT}{key}={repr(value)}," for key, value in self._options.items()
        ]
        packages = [
            f"{INDENT}{repr(key)}: {type(value).__name__},"
            for key, value in self.items()
        ]
        # Place the empty dict on the same line. Looks silly otherwise.
        if packages:
            content = [f"{typename}("] + options + ["){"] + packages + ["}"]
        else:
            content = [f"{typename}("] + options + ["){}"]
        return "\n".join(content)

    def is_use_newton(self):
        return False

    def is_splitting_supported(self) -> Tuple[bool, str]:
        """
        Returns True and an empty string if all the packages in the model
        support splitting. If one of the packages does not support splitting,
        returns False and the name of the first such package.
        """
        for package_name, package in self.items():
            if not package.is_splitting_supported():
                return False, package_name
        return True, ""

    def is_regridding_supported(self) -> Tuple[bool, str]:
        """
        Returns True and an empty string if all the packages in the model
        support regridding. If one of the packages does not support regridding,
        returns False and the name of the first such package.
        """
        for package_name, package in self.items():
            if not package.is_regridding_supported():
                return False, package_name
        return True, ""

    def is_clipping_supported(self) -> Tuple[bool, str]:
        """
        Returns True and an empty string if all the packages in the model
        support clipping. If one of the packages does not support clipping,
        returns False and the name of the first such package.
        """
        for package_name, package in self.items():
            if not package.is_clipping_supported():
                return False, package_name
        return True, ""