Coverage for C:\src\imod-python\imod\mf6\model.py: 91%
280 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 11:25 +0200
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 11:25 +0200
1from __future__ import annotations
3import abc
4import collections
5import inspect
6import pathlib
7from copy import deepcopy
8from pathlib import Path
9from typing import Any, Optional, Tuple, Union
11import cftime
12import jinja2
13import numpy as np
14import tomli
15import tomli_w
16import xarray as xr
17import xugrid as xu
18from jinja2 import Template
20import imod
21from imod.logging import standard_log_decorator
22from imod.mf6.interfaces.imodel import IModel
23from imod.mf6.package import Package
24from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
25from imod.mf6.utilities.mask import _mask_all_packages
26from imod.mf6.utilities.regrid import RegridderWeightsCache, _regrid_like
27from imod.mf6.validation import pkg_errors_to_status_info
28from imod.mf6.write_context import WriteContext
29from imod.schemata import ValidationError
30from imod.typing import GridDataArray
31from imod.typing.grid import is_spatial_grid
class Modflow6Model(collections.UserDict, IModel, abc.ABC):
    """
    Base class for MODFLOW 6 models. Behaves as a dictionary that maps
    package names (user chosen, max 16 characters) to package instances.
    """

    # Package ids that must be present before the model can be written;
    # filled in by concrete subclasses.
    _mandatory_packages: tuple[str, ...] = ()
    # Model type identifier (e.g. "gwf6"); set by concrete subclasses.
    _model_id: Optional[str] = None
    # Jinja2 template used to render the model namefile.
    _template: Template
39 @staticmethod
40 def _initialize_template(name: str) -> Template:
41 loader = jinja2.PackageLoader("imod", "templates/mf6")
42 env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
43 return env.get_template(name)
45 def __init__(self, **kwargs):
46 collections.UserDict.__init__(self)
47 for k, v in kwargs.items():
48 self[k] = v
50 self._options = {}
52 def __setitem__(self, key, value):
53 if len(key) > 16:
54 raise KeyError(
55 f"Received key with more than 16 characters: '{key}'"
56 "Modflow 6 has a character limit of 16."
57 )
59 super().__setitem__(key, value)
61 def update(self, *args, **kwargs):
62 for k, v in dict(*args, **kwargs).items():
63 self[k] = v
65 def _get_diskey(self):
66 dis_pkg_ids = ["dis", "disv", "disu"]
68 diskeys = [
69 self._get_pkgkey(pkg_id)
70 for pkg_id in dis_pkg_ids
71 if self._get_pkgkey(pkg_id) is not None
72 ]
74 if len(diskeys) > 1:
75 raise ValueError(f"Found multiple discretizations {diskeys}")
76 elif len(diskeys) == 0:
77 raise ValueError("No model discretization found")
78 else:
79 return diskeys[0]
81 def _get_pkgkey(self, pkg_id):
82 """
83 Get package key that belongs to a certain pkg_id, since the keys are
84 user specified.
85 """
86 key = [pkgname for pkgname, pkg in self.items() if pkg._pkg_id == pkg_id]
87 nkey = len(key)
88 if nkey > 1:
89 raise ValueError(f"Multiple instances of {key} detected")
90 elif nkey == 1:
91 return key[0]
92 else:
93 return None
95 def _check_for_required_packages(self, modelkey: str) -> None:
96 # Check for mandatory packages
97 pkg_ids = {pkg._pkg_id for pkg in self.values()}
98 dispresent = "dis" in pkg_ids or "disv" in pkg_ids or "disu" in pkg_ids
99 if not dispresent:
100 raise ValueError(f"No dis/disv/disu package found in model {modelkey}")
101 for required in self._mandatory_packages:
102 if required not in pkg_ids:
103 raise ValueError(f"No {required} package found in model {modelkey}")
104 return
106 def _use_cftime(self):
107 """
108 Also checks if datetime types are homogeneous across packages.
109 """
110 types = [
111 type(pkg.dataset["time"].values[0])
112 for pkg in self.values()
113 if "time" in pkg.dataset.coords
114 ]
115 set_of_types = set(types)
116 # Types will be empty if there's no time dependent input
117 if len(set_of_types) == 0:
118 return False
119 else: # there is time dependent input
120 if not len(set_of_types) == 1:
121 raise ValueError(
122 f"Multiple datetime types detected: {set_of_types}"
123 "Use either cftime or numpy.datetime64[ns]."
124 )
125 # Since we compare types and not instances, we use issubclass
126 if issubclass(types[0], cftime.datetime):
127 return True
128 elif issubclass(types[0], np.datetime64):
129 return False
130 else:
131 raise ValueError("Use either cftime or numpy.datetime64[ns].")
133 def _yield_times(self):
134 modeltimes = []
135 for pkg in self.values():
136 if "time" in pkg.dataset.coords:
137 modeltimes.append(pkg.dataset["time"].values)
138 repeat_stress = pkg.dataset.get("repeat_stress")
139 if repeat_stress is not None and repeat_stress.values[()] is not None:
140 modeltimes.append(repeat_stress.isel(repeat_items=0).values)
141 return modeltimes
143 def render(self, modelname: str, write_context: WriteContext):
144 dir_for_render = write_context.root_directory / modelname
146 d = {k: v for k, v in self._options.items() if not (v is None or v is False)}
147 packages = []
148 for pkgname, pkg in self.items():
149 # Add the six to the package id
150 pkg_id = pkg._pkg_id
151 key = f"{pkg_id}6"
152 path = dir_for_render / f"{pkgname}.{pkg_id}"
153 packages.append((key, path.as_posix(), pkgname))
154 d["packages"] = packages
155 return self._template.render(d)
    def _model_checks(self, modelkey: str):
        """
        Check model integrity (called before writing).

        Raises ValueError when a discretization or mandatory package is missing.
        """
        self._check_for_required_packages(modelkey)
164 def __get_domain_geometry(
165 self,
166 ) -> tuple[
167 Union[xr.DataArray, xu.UgridDataArray],
168 Union[xr.DataArray, xu.UgridDataArray],
169 Union[xr.DataArray, xu.UgridDataArray],
170 ]:
171 discretization = self[self._get_diskey()]
172 if discretization is None:
173 raise ValueError("Discretization not found")
174 top = discretization["top"]
175 bottom = discretization["bottom"]
176 idomain = discretization["idomain"]
177 return top, bottom, idomain
179 def __get_k(self):
180 try:
181 npf = self[imod.mf6.NodePropertyFlow._pkg_id]
182 except RuntimeError:
183 raise ValidationError("expected one package of type ModePropertyFlow")
185 k = npf["k"]
186 return k
188 @standard_log_decorator()
189 def validate(self, model_name: str = "") -> StatusInfoBase:
190 try:
191 diskey = self._get_diskey()
192 except Exception as e:
193 status_info = StatusInfo(f"{model_name} model")
194 status_info.add_error(str(e))
195 return status_info
197 dis = self[diskey]
198 # We'll use the idomain for checking dims, shape, nodata.
199 idomain = dis["idomain"]
200 bottom = dis["bottom"]
202 model_status_info = NestedStatusInfo(f"{model_name} model")
203 for pkg_name, pkg in self.items():
204 # Check for all schemata when writing. Types and dimensions
205 # may have been changed after initialization...
207 if pkg_name in ["adv"]:
208 continue # some packages can be skipped
210 # Concatenate write and init schemata.
211 schemata = deepcopy(pkg._init_schemata)
212 for key, value in pkg._write_schemata.items():
213 if key not in schemata.keys():
214 schemata[key] = value
215 else:
216 schemata[key] += value
218 pkg_errors = pkg._validate(
219 schemata=schemata,
220 idomain=idomain,
221 bottom=bottom,
222 )
223 if len(pkg_errors) > 0:
224 model_status_info.add(pkg_errors_to_status_info(pkg_name, pkg_errors))
226 return model_status_info
228 @standard_log_decorator()
229 def write(
230 self, modelname, globaltimes, validate: bool, write_context: WriteContext
231 ) -> StatusInfoBase:
232 """
233 Write model namefile
234 Write packages
235 """
237 workdir = write_context.simulation_directory
238 modeldirectory = workdir / modelname
239 Path(modeldirectory).mkdir(exist_ok=True, parents=True)
240 if validate:
241 model_status_info = self.validate(modelname)
242 if model_status_info.has_errors():
243 return model_status_info
245 # write model namefile
246 namefile_content = self.render(modelname, write_context)
247 namefile_path = modeldirectory / f"{modelname}.nam"
248 with open(namefile_path, "w") as f:
249 f.write(namefile_content)
251 # write package contents
252 pkg_write_context = write_context.copy_with_new_write_directory(
253 new_write_directory=modeldirectory
254 )
255 for pkg_name, pkg in self.items():
256 try:
257 if isinstance(pkg, imod.mf6.Well):
258 top, bottom, idomain = self.__get_domain_geometry()
259 k = self.__get_k()
260 mf6_well_pkg = pkg.to_mf6_pkg(
261 idomain,
262 top,
263 bottom,
264 k,
265 validate,
266 pkg_write_context.is_partitioned,
267 )
269 mf6_well_pkg.write(
270 pkgname=pkg_name,
271 globaltimes=globaltimes,
272 write_context=pkg_write_context,
273 )
274 elif isinstance(pkg, imod.mf6.HorizontalFlowBarrierBase):
275 top, bottom, idomain = self.__get_domain_geometry()
276 k = self.__get_k()
277 mf6_hfb_pkg = pkg.to_mf6_pkg(idomain, top, bottom, k, validate)
278 mf6_hfb_pkg.write(
279 pkgname=pkg_name,
280 globaltimes=globaltimes,
281 write_context=pkg_write_context,
282 )
283 else:
284 pkg.write(
285 pkgname=pkg_name,
286 globaltimes=globaltimes,
287 write_context=pkg_write_context,
288 )
289 except Exception as e:
290 raise type(e)(f"{e}\nError occured while writing {pkg_name}")
292 return NestedStatusInfo(modelname)
    @standard_log_decorator()
    def dump(
        self,
        directory,
        modelname,
        validate: bool = True,
        mdal_compliant: bool = False,
        crs: Optional[Any] = None,
    ):
        """
        Dump simulation to files. Writes a model definition as .TOML file, which
        points to data for each package. Each package is stored as a separate
        NetCDF. Structured grids are saved as regular NetCDFs, unstructured
        grids are saved as UGRID NetCDF. Structured grids are always made GDAL
        compliant, unstructured grids can be made MDAL compliant optionally.

        Parameters
        ----------
        directory: str or Path
            directory to dump simulation into.
        modelname: str
            modelname, will be used to create a subdirectory.
        validate: bool, optional
            Whether to validate simulation data. Defaults to True.
        mdal_compliant: bool, optional
            Convert data with
            :func:`imod.prepare.spatial.mdal_compliant_ugrid2d` to MDAL
            compliant unstructured grids. Defaults to False.
        crs: Any, optional
            Anything accepted by rasterio.crs.CRS.from_user_input
            Requires ``rioxarray`` installed.

        Returns
        -------
        pathlib.Path
            Path to the written TOML model definition.
        """
        modeldirectory = pathlib.Path(directory) / modelname
        modeldirectory.mkdir(exist_ok=True, parents=True)
        if validate:
            statusinfo = self.validate()
            if statusinfo.has_errors():
                raise ValidationError(statusinfo.to_string())

        # TOML layout: {package class name: {package name: netcdf filename}}.
        toml_content: dict = collections.defaultdict(dict)
        for pkgname, pkg in self.items():
            pkg_path = f"{pkgname}.nc"
            toml_content[type(pkg).__name__][pkgname] = pkg_path
            dataset = pkg.dataset
            if isinstance(dataset, xu.UgridDataset):
                # Unstructured: UGRID NetCDF, optionally converted for MDAL.
                if mdal_compliant:
                    dataset = dataset.ugrid.to_dataset()
                    mdal_dataset = imod.util.spatial.mdal_compliant_ugrid2d(
                        dataset, crs=crs
                    )
                    mdal_dataset.to_netcdf(modeldirectory / pkg_path)
                else:
                    dataset.ugrid.to_netcdf(modeldirectory / pkg_path)
            else:
                # Structured: made GDAL compliant when the dataset is spatial.
                if is_spatial_grid(dataset):
                    dataset = imod.util.spatial.gdal_compliant_grid(dataset, crs=crs)
                dataset.to_netcdf(modeldirectory / pkg_path)

        toml_path = modeldirectory / f"{modelname}.toml"
        with open(toml_path, "wb") as f:
            tomli_w.dump(toml_content, f)

        return toml_path
358 @classmethod
359 def from_file(cls, toml_path):
360 pkg_classes = {
361 name: pkg_cls
362 for name, pkg_cls in inspect.getmembers(imod.mf6, inspect.isclass)
363 if issubclass(pkg_cls, Package)
364 }
366 toml_path = pathlib.Path(toml_path)
367 with open(toml_path, "rb") as f:
368 toml_content = tomli.load(f)
370 parentdir = toml_path.parent
371 instance = cls()
372 for key, entry in toml_content.items():
373 for pkgname, path in entry.items():
374 pkg_cls = pkg_classes[key]
375 instance[pkgname] = pkg_cls.from_file(parentdir / path)
377 return instance
379 @property
380 def options(self) -> dict:
381 if self._options is None:
382 raise ValueError("Model id has not been set")
383 return self._options
    @property
    def model_id(self) -> str:
        """The model type identifier (e.g. "gwf6"); raises ValueError when unset."""
        if self._model_id is None:
            raise ValueError("Model id has not been set")
        return self._model_id
    def clip_box(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        state_for_boundary: Optional[GridDataArray] = None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        state_for_boundary: optional, GridDataArray
            Not used in this base implementation.

        Raises
        ------
        ValueError
            If any package in the model does not support clipping.
        """
        supported, error_with_object = self.is_clipping_supported()
        if not supported:
            raise ValueError(
                f"model cannot be clipped due to presence of package '{error_with_object}' in model"
            )

        clipped = self._clip_box_packages(
            time_min,
            time_max,
            layer_min,
            layer_max,
            x_min,
            x_max,
            y_min,
            y_max,
        )

        return clipped
446 def _clip_box_packages(
447 self,
448 time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
449 time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
450 layer_min: Optional[int] = None,
451 layer_max: Optional[int] = None,
452 x_min: Optional[float] = None,
453 x_max: Optional[float] = None,
454 y_min: Optional[float] = None,
455 y_max: Optional[float] = None,
456 ):
457 """
458 Clip a model by a bounding box (time, layer, y, x).
460 Slicing intervals may be half-bounded, by providing None:
462 * To select 500.0 <= x <= 1000.0:
463 ``clip_box(x_min=500.0, x_max=1000.0)``.
464 * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
465 or ``clip_box(x_max=1000.0)``.
466 * To select x >= 500.0: ``clip_box(x_min = 500.0, x_max=None.0)``
467 or ``clip_box(x_min=1000.0)``.
469 Parameters
470 ----------
471 time_min: optional
472 time_max: optional
473 layer_min: optional, int
474 layer_max: optional, int
475 x_min: optional, float
476 x_max: optional, float
477 y_min: optional, float
478 y_max: optional, float
480 Returns
481 -------
482 clipped : Modflow6Model
483 """
485 top, bottom, idomain = self.__get_domain_geometry()
487 clipped = type(self)(**self._options)
488 for key, pkg in self.items():
489 clipped[key] = pkg.clip_box(
490 time_min=time_min,
491 time_max=time_max,
492 layer_min=layer_min,
493 layer_max=layer_max,
494 x_min=x_min,
495 x_max=x_max,
496 y_min=y_min,
497 y_max=y_max,
498 top=top,
499 bottom=bottom,
500 )
502 return clipped
504 def regrid_like(
505 self,
506 target_grid: GridDataArray,
507 validate: bool = True,
508 regrid_context: Optional[RegridderWeightsCache] = None,
509 ) -> "Modflow6Model":
510 """
511 Creates a model by regridding the packages of this model to another discretization.
512 It regrids all the arrays in the package using the default regridding methods.
513 At the moment only regridding to a different planar grid is supported, meaning
514 ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d`` coords.
516 Parameters
517 ----------
518 target_grid: xr.DataArray or xu.UgridDataArray
519 a grid defined over the same discretization as the one we want to regrid the package to
520 validate: bool
521 set to true to validate the regridded packages
522 regrid_context: Optional RegridderWeightsCache
523 stores regridder weights for different regridders. Can be used to speed up regridding,
524 if the same regridders are used several times for regridding different arrays.
526 Returns
527 -------
528 a model with similar packages to the input model, and with all the data-arrays regridded to another discretization,
529 similar to the one used in input argument "target_grid"
530 """
531 return _regrid_like(self, target_grid, validate, regrid_context)
533 def mask_all_packages(
534 self,
535 mask: GridDataArray,
536 ):
537 """
538 This function applies a mask to all packages in a model. The mask must
539 be presented as an idomain-like integer array that has 0 (inactive) or
540 -1 (vertical passthrough) values in filtered cells and 1 in active
541 cells.
542 Masking will overwrite idomain with the mask where the mask is 0 or -1.
543 Where the mask is 1, the original value of idomain will be kept. Masking
544 will update the packages accordingly, blanking their input where needed,
545 and is therefore not a reversible operation.
547 Parameters
548 ----------
549 mask: xr.DataArray, xu.UgridDataArray of ints
550 idomain-like integer array. 1 sets cells to active, 0 sets cells to inactive,
551 -1 sets cells to vertical passthrough
552 """
554 _mask_all_packages(self, mask)
556 def purge_empty_packages(self, model_name: Optional[str] = "") -> None:
557 """
558 This function removes empty packages from the model.
559 """
560 empty_packages = [
561 package_name for package_name, package in self.items() if package.is_empty()
562 ]
563 for package_name in empty_packages:
564 self.pop(package_name)
566 @property
567 def domain(self):
568 dis = self._get_diskey()
569 return self[dis]["idomain"]
571 @property
572 def bottom(self):
573 dis = self._get_diskey()
574 return self[dis]["bottom"]
576 def __repr__(self) -> str:
577 INDENT = " "
578 typename = type(self).__name__
579 options = [
580 f"{INDENT}{key}={repr(value)}," for key, value in self._options.items()
581 ]
582 packages = [
583 f"{INDENT}{repr(key)}: {type(value).__name__},"
584 for key, value in self.items()
585 ]
586 # Place the emtpy dict on the same line. Looks silly otherwise.
587 if packages:
588 content = [f"{typename}("] + options + ["){"] + packages + ["}"]
589 else:
590 content = [f"{typename}("] + options + ["){}"]
591 return "\n".join(content)
    def is_use_newton(self):
        """Whether the Newton-Raphson formulation is enabled; always False here."""
        return False
596 def is_splitting_supported(self) -> Tuple[bool, str]:
597 """
598 Returns True if all the packages in the model supports splitting. If one
599 of the packages in the model does not support splitting, it returns the
600 name of the first one.
601 """
602 for package_name, package in self.items():
603 if not package.is_splitting_supported():
604 return False, package_name
605 return True, ""
607 def is_regridding_supported(self) -> Tuple[bool, str]:
608 """
609 Returns True if all the packages in the model supports regridding. If one
610 of the packages in the model does not support regridding, it returns the
611 name of the first one.
612 """
613 for package_name, package in self.items():
614 if not package.is_regridding_supported():
615 return False, package_name
616 return True, ""
618 def is_clipping_supported(self) -> Tuple[bool, str]:
619 """
620 Returns True if all the packages in the model supports clipping. If one
621 of the packages in the model does not support clipping, it returns the
622 name of the first one.
623 """
624 for package_name, package in self.items():
625 if not package.is_clipping_supported():
626 return False, package_name
627 return True, ""