Coverage for C:\src\imod-python\imod\mf6\utilities\mask.py: 93%
73 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 11:25 +0200
1import numbers
3import numpy as np
4from fastcore.dispatch import typedispatch
5from xarray.core.utils import is_scalar
7from imod.mf6.auxiliary_variables import (
8 expand_transient_auxiliary_variables,
9 remove_expanded_auxiliary_variables_from_dataset,
10)
11from imod.mf6.interfaces.imaskingsettings import IMaskingSettings
12from imod.mf6.interfaces.imodel import IModel
13from imod.mf6.interfaces.ipackage import IPackage
14from imod.mf6.interfaces.isimulation import ISimulation
15from imod.typing.grid import GridDataArray, get_spatial_dimension_names, is_same_domain
def _mask_all_models(
    simulation: ISimulation,
    mask: GridDataArray,
):
    """
    Apply ``mask`` to every flow ("gwf6") and transport ("gwt6") model of the
    simulation, by calling ``mask_all_packages`` on each model.

    Parameters
    ----------
    simulation: ISimulation
        The simulation whose models are masked.
    mask: GridDataArray
        The masking grid. It may only carry spatial coordinates, and its
        domain must match the domain of every flow and transport model.

    Raises
    ------
    ValueError
        If the mask has a non-spatial coordinate, if the simulation has been
        split, or if any model's domain differs from the mask's domain. All
        checks happen before any model is modified.
    """
    spatial_dims = get_spatial_dimension_names(mask)
    if any(coord not in spatial_dims for coord in mask.coords):
        raise ValueError("unexpected coordinate dimension in masking domain")

    if simulation.is_split():
        raise ValueError(
            "masking can only be applied to simulations that have not been split. Apply masking before splitting."
        )

    flowmodels = list(simulation.get_models_of_type("gwf6").keys())
    transportmodels = list(simulation.get_models_of_type("gwt6").keys())
    modelnames = flowmodels + transportmodels

    # Validate all models up front: masking happens in place, so raising
    # halfway through the masking loop would leave earlier models masked and
    # later ones untouched.
    if not all(
        is_same_domain(simulation[name].domain, mask) for name in modelnames
    ):
        raise ValueError(
            "masking can only be applied to simulations when all the models in the simulation use the same grid."
        )

    for name in modelnames:
        simulation[name].mask_all_packages(mask)
def _mask_all_packages(
    model: IModel,
    mask: GridDataArray,
):
    """
    Replace each package of ``model`` with ``package.mask(mask)``, then call
    ``model.purge_empty_packages()``.

    Parameters
    ----------
    model: IModel
        The model whose packages are replaced in place.
    mask: GridDataArray
        The masking grid; it may only carry spatial coordinates.

    Raises
    ------
    ValueError
        If ``mask`` has a coordinate that is not a spatial dimension name.
    """
    allowed_names = get_spatial_dimension_names(mask)
    if not all(name in allowed_names for name in mask.coords):
        raise ValueError("unexpected coordinate dimension in masking domain")

    for package_name, package in model.items():
        masked_package = package.mask(mask)
        model[package_name] = masked_package

    model.purge_empty_packages()
def mask_package(package: IPackage, mask: GridDataArray) -> IPackage:
    """
    Return a new package of the same type as ``package`` with ``mask`` applied.

    Each data variable of ``package.dataset`` is either copied unchanged (when
    ``_skip_dataarray`` or ``_skip_variable`` says so) or masked with
    ``_mask_spatial_var``. The result is built by calling the package's own
    class with the collected variables as keyword arguments.

    Parameters
    ----------
    package: IPackage
        The package to mask. Its dataset is modified temporarily while
        masking; the auxiliary-variable state is restored before returning,
        also when masking raises.
    mask: GridDataArray
        The masking grid.

    Returns
    -------
    IPackage
        A new instance of ``type(package)``.
    """
    has_auxiliary = len(package.auxiliary_data_fields) > 0
    if has_auxiliary:
        # Modifies package.dataset in place; undone in the ``finally`` below.
        remove_expanded_auxiliary_variables_from_dataset(package)

    masked = {}
    try:
        for var in package.dataset.data_vars:
            da = package.dataset[var]
            if _skip_dataarray(da) or _skip_variable(package, var):
                masked[var] = da
            else:
                masked[var] = _mask_spatial_var(package, var, mask)
    finally:
        # Restore the caller's package even if masking a variable failed
        # (e.g. the TypeError raised for unsupported dtypes).
        if has_auxiliary:
            expand_transient_auxiliary_variables(package)

    return type(package)(**masked)
74def _skip_dataarray(da: GridDataArray) -> bool:
75 if len(da.dims) == 0 or set(da.coords).issubset(["layer"]):
76 return True
78 if is_scalar(da.values[()]):
79 return True
81 spatial_dims = ["x", "y", "mesh2d_nFaces", "layer"]
82 if not np.any([coord in spatial_dims for coord in da.coords]):
83 return True
85 return False
@typedispatch
def _skip_variable(package: IPackage, var: str) -> bool:
    """Fallback: a plain package never exempts a variable from masking."""
    return False


@typedispatch  # type: ignore [no-redef]
def _skip_variable(package: IMaskingSettings, var: str) -> bool:
    """
    Packages implementing IMaskingSettings exempt exactly the variables
    listed in their ``skip_variables`` attribute.
    """
    exempted = package.skip_variables
    return var in exempted
def _mask_spatial_var(self, var: str, mask: GridDataArray) -> GridDataArray:
    """
    Mask the variable ``var`` of the package ``self``.

    ``self`` is the package whose ``dataset`` holds the variable; it is passed
    positionally by ``mask_package``. The mask is first adapted to the
    variable's layering by ``_adjust_mask_for_unlayered_data``. Cells where
    that mask is not positive are then replaced:

    * integer "idomain": takes the mask's value;
    * other integer variables: set to 0;
    * real variables: the ``where`` default fill value (NaN for floats).

    Raises
    ------
    TypeError
        If the variable's dtype is neither integer nor real.
    """
    data = self.dataset[var]
    layer_mask = _adjust_mask_for_unlayered_data(data, mask)
    scalar_type = data.dtype.type

    if issubclass(scalar_type, numbers.Integral):
        # idomain copies the mask value itself outside active cells; other
        # integer variables are zeroed there.
        fill_value = layer_mask if var == "idomain" else 0
        return data.where(layer_mask > 0, other=fill_value)

    if not issubclass(scalar_type, numbers.Real):
        raise TypeError(
            f"Expected dtype float or integer. Received instead: {data.dtype}"
        )

    return data.where(layer_mask > 0)
115def _adjust_mask_for_unlayered_data(
116 da: GridDataArray, mask: GridDataArray
117) -> GridDataArray:
118 """
119 Some arrays are not layered while the mask is layered (for example the
120 top array in dis or disv packaged). In that case we use the top layer of
121 the mask to perform the masking. If layer is not a dataset dimension,
122 but still a dataset coordinate, we limit the mask to the relevant layer
123 coordinate(s).
124 """
125 array_mask = mask
126 if "layer" in da.coords and "layer" not in da.dims:
127 array_mask = mask.sel(layer=da.coords["layer"])
128 if "layer" not in da.coords and "layer" in array_mask.coords:
129 array_mask = mask.isel(layer=0)
131 return array_mask