Coverage for C:\src\imod-python\imod\mf6\utilities\mask.py: 92%

65 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-08 13:27 +0200

1import numbers 

2 

3import numpy as np 

4from xarray.core.utils import is_scalar 

5 

6from imod.mf6.auxiliary_variables import ( 

7 expand_transient_auxiliary_variables, 

8 remove_expanded_auxiliary_variables_from_dataset, 

9) 

10from imod.mf6.interfaces.imodel import IModel 

11from imod.mf6.interfaces.ipackage import IPackage 

12from imod.mf6.interfaces.isimulation import ISimulation 

13from imod.typing.grid import GridDataArray, get_spatial_dimension_names, is_same_domain 

14 

15 

def _mask_all_models(
    simulation: ISimulation,
    mask: GridDataArray,
):
    """
    Apply ``mask`` to every flow ("gwf6") and transport ("gwt6") model of
    ``simulation``.

    All models are checked against the mask before any of them is modified,
    so a domain mismatch in one model cannot leave the simulation partially
    masked.

    Parameters
    ----------
    simulation: ISimulation
        Simulation whose models are masked in place, through each model's
        ``mask_all_packages`` method.
    mask: GridDataArray
        Masking grid. Its coordinates must all be among the names returned
        by ``get_spatial_dimension_names(mask)``.

    Raises
    ------
    ValueError
        If the mask carries a coordinate that is not a spatial dimension
        name, if the simulation has been split, or if any model's domain
        is not the same domain as the mask.
    """
    spatial_dims = get_spatial_dimension_names(mask)
    if any(coord not in spatial_dims for coord in mask.coords):
        raise ValueError("unexpected coordinate dimension in masking domain")

    if simulation.is_split():
        raise ValueError(
            "masking can only be applied to simulations that have not been split. Apply masking before splitting."
        )

    flowmodels = list(simulation.get_models_of_type("gwf6").keys())
    transportmodels = list(simulation.get_models_of_type("gwt6").keys())
    modelnames = flowmodels + transportmodels

    # Validate first, mutate afterwards: raising halfway through the masking
    # loop would leave the earlier models masked and the later ones untouched.
    if not all(is_same_domain(simulation[name].domain, mask) for name in modelnames):
        raise ValueError(
            "masking can only be applied to simulations when all the models in the simulation use the same grid."
        )

    for name in modelnames:
        simulation[name].mask_all_packages(mask)

40 

41 

def _mask_all_packages(
    model: IModel,
    mask: GridDataArray,
):
    """
    Replace every package of ``model`` by its masked counterpart and then
    call ``model.purge_empty_packages()``.

    Parameters
    ----------
    model: IModel
        Model whose packages are replaced in place, key by key.
    mask: GridDataArray
        Masking grid, handed unchanged to each package's ``mask`` method.

    Raises
    ------
    ValueError
        If the mask carries a coordinate that is not among the names
        returned by ``get_spatial_dimension_names(mask)``.
    """
    allowed_names = get_spatial_dimension_names(mask)
    has_unexpected_coord = any(coord not in allowed_names for coord in mask.coords)
    if has_unexpected_coord:
        raise ValueError("unexpected coordinate dimension in masking domain")

    # Only existing keys are reassigned, one per iteration step.
    for pkgname, pkg in model.items():
        masked_pkg = pkg.mask(mask)
        model[pkgname] = masked_pkg

    model.purge_empty_packages()

53 

54 

def _mask(package: IPackage, mask: GridDataArray) -> IPackage:
    """
    Build a new package of the same type as ``package`` from its masked
    variables.

    Variables for which ``_skip_masking_variable`` returns True are copied
    over unchanged; all others go through ``_mask_spatial_var``.

    If the package has auxiliary data fields, the expanded auxiliary
    variables are removed from its dataset for the duration of the masking
    and expanded again afterwards. The expansion runs in a ``finally``
    clause, so the input package is restored even when masking a variable
    raises (for example the ``TypeError`` from ``_mask_spatial_var``).

    Parameters
    ----------
    package: IPackage
        Package to mask. It is modified temporarily (see above).
    mask: GridDataArray
        Masking grid.

    Returns
    -------
    IPackage
        New instance of ``type(package)`` constructed from the masked
        variables.
    """
    # Evaluated once so that removal and restoration are always paired.
    has_auxiliary_fields = len(package.auxiliary_data_fields) > 0
    if has_auxiliary_fields:
        remove_expanded_auxiliary_variables_from_dataset(package)

    try:
        masked = {}
        for var in package.dataset.data_vars.keys():
            if _skip_masking_variable(package, var, package.dataset[var]):
                masked[var] = package.dataset[var]
            else:
                masked[var] = _mask_spatial_var(package, var, mask)
    finally:
        # Put the caller's package back the way it was, also on failure.
        if has_auxiliary_fields:
            expand_transient_auxiliary_variables(package)

    return type(package)(**masked)

67 

68 

def _skip_masking_variable(package: IPackage, var: str, da: GridDataArray) -> bool:
    """
    Tell whether variable ``var`` of ``package`` should be left unmasked.

    Returns True when any of the following holds, checked in this order:

    * ``package._skip_masking_dataarray(var)`` is truthy;
    * ``da`` has no dimensions;
    * ``da`` has no coordinates other than "layer";
    * the value of ``da`` is a scalar;
    * none of the coordinates of ``da`` is one of "x", "y",
      "mesh2d_nFaces" or "layer".

    Otherwise returns False, meaning the variable should be masked.
    """
    if package._skip_masking_dataarray(var):
        return True
    if len(da.dims) == 0:
        return True
    if set(da.coords) <= {"layer"}:
        return True
    # NOTE(review): dimensionless arrays were already handled above, so this
    # check looks redundant here -- confirm before removing.
    if is_scalar(da.values[()]):
        return True

    spatial_dims = ("x", "y", "mesh2d_nFaces", "layer")
    has_spatial_coord = any(coord in spatial_dims for coord in da.coords)
    return not has_spatial_coord

82 

83 

def _mask_spatial_var(self, var: str, mask: GridDataArray) -> GridDataArray:
    """
    Return variable ``var`` of ``self.dataset`` with ``mask`` applied.

    Cells where the (layer-adjusted) mask is larger than zero keep their
    value. The replacement elsewhere depends on the dtype of the variable:

    * integer "idomain": the value of the mask itself;
    * any other integer variable: 0;
    * floating point: the default fill of ``where`` (no ``other`` given).

    Raises
    ------
    TypeError
        If the dtype of the variable is neither integer nor real.
    """
    da = self.dataset[var]
    array_mask = _adjust_mask_for_unlayered_data(da, mask)
    scalar_type = da.dtype.type

    # Integral is tested first: every Integral is also a Real.
    if issubclass(scalar_type, numbers.Integral):
        fill_value = array_mask if var == "idomain" else 0
        return da.where(array_mask > 0, other=fill_value)

    if issubclass(scalar_type, numbers.Real):
        return da.where(array_mask > 0)

    raise TypeError(f"Expected dtype float or integer. Received instead: {da.dtype}")

99 

100 

def _adjust_mask_for_unlayered_data(
    da: GridDataArray, mask: GridDataArray
) -> GridDataArray:
    """
    Return the mask in a form that matches the layering of ``da``.

    Some arrays are not layered while the mask is (for example the top
    array in the dis or disv package). Three cases are distinguished:

    * "layer" is a coordinate of ``da`` but not one of its dimensions: the
      mask is limited to the layer coordinate(s) of ``da``.
    * ``da`` has no "layer" coordinate while the mask has one: the first
      layer of the mask (position 0) is used.
    * otherwise the mask is returned unchanged.
    """
    data_has_layer_coord = "layer" in da.coords

    if data_has_layer_coord and "layer" not in da.dims:
        return mask.sel(layer=da.coords["layer"])

    if not data_has_layer_coord and "layer" in mask.coords:
        return mask.isel(layer=0)

    return mask