Coverage for C:\src\imod-python\imod\schemata.py: 92%
262 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-13 11:15 +0200
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-13 11:15 +0200
1"""
2Schemata to help validation of input.
4This code is based on: https://github.com/carbonplan/xarray-schema
6which has the following MIT license:
8 MIT License
10 Copyright (c) 2021 carbonplan
12 Permission is hereby granted, free of charge, to any person obtaining a copy
13 of this software and associated documentation files (the "Software"), to deal
14 in the Software without restriction, including without limitation the rights
15 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 copies of the Software, and to permit persons to whom the Software is
17 furnished to do so, subject to the following conditions:
19 The above copyright notice and this permission notice shall be included in all
20 copies or substantial portions of the Software.
22 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 SOFTWARE.
30In the future, we may be able to replace this module by whatever the best
31validation xarray library becomes.
32"""
34import abc
35import operator
36from functools import partial
37from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeAlias, Union
39import numpy as np
40import scipy
41import xarray as xr
42import xugrid as xu
43from numpy.typing import DTypeLike # noqa: F401
45from imod.typing import GridDataArray, ScalarAsDataArray
# A single dimension name, or ``None``.
DimsT = Union[str, None]
# Shape of a DataArray of arbitrary rank; ``None`` is a wildcard meaning "any
# size along this axis". The trailing ellipsis is required: ``Tuple[X]``
# without it describes a tuple of exactly one element.
ShapeT = Tuple[Union[int, None], ...]
# NOTE(review): not used in the visible part of this module; presumably a
# chunking specification (flag or per-dimension chunk size) -- confirm.
ChunksT = Union[bool, Dict[str, Union[int, None]]]
# Maps the textual comparison operators accepted by the schemata to the
# corresponding two-argument functions of the ``operator`` module.
OPERATORS: Dict[str, Callable[[Any, Any], Any]] = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}
def partial_operator(op, value):
    """
    Build a one-argument comparison against a fixed right-hand side.

    Parameters
    ----------
    op : str
        Key into ``OPERATORS``, e.g. ``">"``. It is looked up when the
        returned callable is invoked, not when it is created.
    value : Any
        Right-hand side of the comparison.

    Returns
    -------
    functools.partial
        Callable ``f(a)`` that evaluates ``a <op> value``.
    """

    # ``partial`` can only pre-fill leading positional arguments and the
    # operator functions take no keyword arguments, so the argument order is
    # swapped in a small wrapper: ``value`` is bound as ``rhs`` and the
    # argument given at call time becomes the left-hand side.
    # https://stackoverflow.com/a/37468215
    def _swapped(rhs, lhs):
        return OPERATORS[op](lhs, rhs)

    return partial(_swapped, value)
def scalar_None(obj):
    """
    Tell whether ``obj`` is a zero-dimensional DataArray holding only null.

    Such a scalar "None" DataArray is the default value for optional
    variables. Anything that is not an ``xr.DataArray`` or
    ``xu.UgridDataArray`` yields ``False``.
    """
    if isinstance(obj, (xr.DataArray, xu.UgridDataArray)):
        is_scalar = len(obj.shape) == 0
        all_null = (~obj.notnull()).all()
        # Element-wise ``&`` (not ``and``): ``all_null`` is an array-like.
        return is_scalar & all_null
    return False
def align_other_obj_with_coords(
    obj: GridDataArray, other_obj: GridDataArray
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Left-align ``other_obj`` with ``obj``.

    Every coordinate of ``obj`` that is not one of its dimensions, but is a
    dimension of ``other_obj``, is first promoted to a dimension of ``obj``.
    This avoids issues like:
    https://github.com/Deltares/imod-python/issues/830

    Returns
    -------
    tuple
        The result of ``xr.align(obj, other_obj, join="left")``.
    """
    promote = [
        name
        for name in obj.coords.keys()
        if name in other_obj.dims and name not in obj.dims
    ]
    for name in promote:
        obj = obj.expand_dims(name)
    # Note:
    # xr.align forces xu.UgridDataArray to xr.DataArray. Keep that in mind
    # in further data processing.
    aligned = xr.align(obj, other_obj, join="left")
    return aligned
class ValidationError(Exception):
    """Raised by a schema when the validated object does not comply."""
class BaseSchema(abc.ABC):
    """
    Abstract base of all schemata: subclasses implement ``validate``.
    """

    @abc.abstractmethod
    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Check ``obj``; implemented by the concrete schemata."""

    def __or__(self, other):
        """
        Combine two schemata into a union with the ``|`` operator, e.g.:

            DimsSchema("layer", "y", "x") | DimsSchema("layer")

        The result is a SchemaUnion of ``self`` and ``other``.
        """
        union = SchemaUnion(self, other)
        return union
# Alias used to annotate "any schema". A bound TypeVar was used previously and
# is kept here for reference:
# SchemaType = TypeVar("SchemaType", bound=BaseSchema)
SchemaType: TypeAlias = BaseSchema
class SchemaUnion:
    """
    Successful validation only requires a single success.

    Used to validate multiple options.
    """

    def __init__(self, *args):
        """
        Store the given schemata.

        Raises
        ------
        TypeError
            If the schemata are not all of the same type.
        """
        distinct_types = set(map(type, args))
        if len(distinct_types) > 1:
            raise TypeError("schemata in a union should have the same type")
        self.schemata = tuple(args)

    def validate(self, obj: Any, **kwargs):
        """
        Run every schema on ``obj`` and collect the failures.

        Raises
        ------
        ValidationError
            If every schema failed; the message lists each individual error.
        """
        failures = []
        for candidate in self.schemata:
            try:
                candidate.validate(obj, **kwargs)
            except ValidationError as exc:
                failures.append(exc)

        if len(failures) != len(self.schemata):
            # At least one option succeeded.
            return

        message = "\n\t" + "\n\t".join(map(str, failures))
        raise ValidationError(f"No option succeeded:{message}")

    def __or__(self, other):
        """Return a new union extended with ``other``."""
        return SchemaUnion(*self.schemata, other)
class OptionSchema(BaseSchema):
    """
    Check whether the value is one of given valid options.
    """

    def __init__(self, options: Sequence[Any]):
        self.options = options

    def validate(self, obj: ScalarAsDataArray, **kwargs) -> None:
        """
        Raise ValidationError if the scalar in ``obj`` is not a valid option.

        A scalar None ``obj`` is skipped. String values are lowercased
        before the lookup.
        """
        if scalar_None(obj):
            return

        raw = obj.item()
        # MODFLOW 6 is not case sensitive for string options.
        value = raw.lower() if isinstance(raw, str) else raw

        if value in self.options:
            return

        valid_values = ", ".join(str(option) for option in self.options)
        raise ValidationError(
            f"Invalid option: {value}. Valid options are: {valid_values}"
        )
class DTypeSchema(BaseSchema):
    """
    Check that the dtype of an object is a sub-dtype of the expected one.
    """

    def __init__(self, dtype: DTypeLike) -> None:
        # The abstract numpy scalar types below are stored as given; any
        # other specification is normalised through ``np.dtype``.
        abstract_types = (
            np.floating,
            np.integer,
            np.signedinteger,
            np.unsignedinteger,
            np.generic,
        )
        self.dtype = dtype if dtype in abstract_types else np.dtype(dtype)

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Validate dtype

        Parameters
        ----------
        obj : GridDataArray
            Object whose ``dtype`` is tested with ``np.issubdtype``. A scalar
            None object is skipped.

        Raises
        ------
        ValidationError
            If ``obj.dtype`` is not a sub-dtype of the expected dtype.
        """
        if scalar_None(obj):
            return

        if np.issubdtype(obj.dtype, self.dtype):
            return
        raise ValidationError(f"dtype {obj.dtype} != {self.dtype}")
class DimsSchema(BaseSchema):
    """
    Validate the dimension names of an object, in order.

    Parameters
    ----------
    *dims : str or None
        Expected dimension names. ``None`` may be used as a wildcard that
        matches any name at that position. The placeholders ``"{face_dim}"``
        and ``"{edge_dim}"`` are replaced by the face and edge dimension of
        the grid when validating a ``xu.UgridDataArray``.
    """

    def __init__(self, *dims: DimsT) -> None:
        self.dims = dims

    def _fill_in_face_dim(self, obj: Union[xr.DataArray, xu.UgridDataArray]):
        """
        Return dims with the face and edge dim placeholders filled in.

        For objects that are not a ``xu.UgridDataArray`` the dims are returned
        unchanged. Both placeholders are substituted in a single pass, so a
        schema may contain either or both of them.
        """
        if not isinstance(obj, xu.UgridDataArray):
            return self.dims

        filled = []
        for dim in self.dims:
            if dim == "{face_dim}":
                filled.append(obj.ugrid.grid.face_dimension)
            elif dim == "{edge_dim}":
                filled.append(obj.ugrid.grid.edge_dimension)
            else:
                filled.append(dim)
        return tuple(filled)

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Validate dimensions

        Parameters
        ----------
        obj : GridDataArray
            Object whose ``dims`` are compared with the expected dims.

        Raises
        ------
        ValidationError
            If the number of dimensions differs, or if a dimension name
            differs from a non-``None`` expected name.
        """
        # Force to tuple for error message print
        expected = tuple(self._fill_in_face_dim(obj))
        actual = tuple(obj.dims)

        # ``None`` in the expected dims is a wildcard for any name.
        matches = len(actual) == len(expected) and all(
            want is None or got == want for got, want in zip(actual, expected)
        )
        if not matches:
            raise ValidationError(f"dim mismatch: expected {expected}, got {actual}")
class EmptyIndexesSchema(BaseSchema):
    """
    Verify indexes, check if no dims with zero size are included. Skips
    unstructured grid dimensions.
    """

    def __init__(self) -> None:
        """Takes no configuration."""

    def get_dims_to_validate(self, obj: Union[xr.DataArray, xu.UgridDataArray]):
        """
        Return the dimensions of ``obj`` whose indexes are to be checked.

        For a ``xu.UgridDataArray`` the dimensions of the grid are left out.
        """
        if not isinstance(obj, xu.UgridDataArray):
            return list(obj.dims)

        # Remove face dim from list to validate, as it has no ``indexes``
        # attribute.
        ugrid_dims = obj.ugrid.grid.dimensions
        return [dim for dim in obj.dims if dim not in ugrid_dims]

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError for the first checked dimension of size 0.
        """
        for dim in self.get_dims_to_validate(obj):
            index_size = len(obj.indexes[dim])
            if index_size == 0:
                raise ValidationError(f"provided dimension {dim} with size 0")
class IndexesSchema(EmptyIndexesSchema):
    """
    Verify indexes, check if no dims with zero size are included and that
    indexes are monotonic. Skips unstructured grid dimensions.
    """

    def __init__(self) -> None:
        """Takes no configuration."""

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError on an empty or non-monotonic index.

        The index of ``"y"`` has to be monotonically decreasing, every other
        checked index monotonically increasing.
        """
        # First the check of the parent class: no dimension may be empty.
        super().validate(obj)

        for dim in self.get_dims_to_validate(obj):
            index = obj.indexes[dim]
            if dim == "y":
                if not index.is_monotonic_decreasing:
                    raise ValidationError(
                        f"coord {dim} which is not monotonically decreasing"
                    )
            elif not index.is_monotonic_increasing:
                raise ValidationError(
                    f"coord {dim} which is not monotonically increasing"
                )
class ShapeSchema(BaseSchema):
    """
    Validate the shape of an object, axis by axis.
    """

    def __init__(self, shape: ShapeT) -> None:
        """
        Validate shape.

        Parameters
        ----------
        shape : ShapeT
            Shape of the DataArray. `None` may be used as a wildcard value.
        """
        self.shape = shape

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Compare ``obj.shape`` with the expected shape.

        Raises
        ------
        ValidationError
            If the number of dimensions differs, or if the size along an
            axis differs from a non-``None`` expected size.
        """
        expected_ndim = len(self.shape)
        actual_ndim = len(obj.shape)
        if expected_ndim != actual_ndim:
            # Report the schema rank as "shape" and the object rank as
            # "da.ndim"; the two values used to be swapped.
            raise ValidationError(
                f"number of dimensions in shape ({expected_ndim}) != da.ndim ({actual_ndim})"
            )

        for i, (actual, expected) in enumerate(zip(obj.shape, self.shape)):
            if expected is not None and actual != expected:
                raise ValidationError(
                    f"shape mismatch in axis {i}: {actual} != {expected}"
                )
class CompatibleSettingsSchema(BaseSchema):
    """
    Validate that a setting is only switched on if another setting has a
    required value.
    """

    def __init__(self, other: ScalarAsDataArray, other_value: bool) -> None:
        """
        Validate if settings are compatible

        Parameters
        ----------
        other
            Key under which the other setting is looked up in the keyword
            arguments passed to ``validate``.
        other_value : bool
            Value the other setting has to be equal to.
        """
        self.other_value = other_value
        self.other = other

    def validate(self, obj: ScalarAsDataArray, **kwargs) -> None:
        """
        Raise ValidationError if ``obj`` is truthy while the other setting
        is not equal to the required value everywhere.

        Nothing is checked when either object is a scalar None.
        """
        other_obj = kwargs[self.other]
        if scalar_None(obj) or scalar_None(other_obj):
            return

        other_matches = np.all(other_obj == self.other_value)

        if not obj:
            return
        if not other_matches:
            raise ValidationError(
                f"Incompatible setting: {self.other} should be {self.other_value}"
            )
class CoordsSchema(BaseSchema):
    """
    Validate presence of coords.

    Parameters
    ----------
    coords : tuple of str
        Names of the coordinates to compare the object's coordinates with.
    require_all_keys : bool, default True
        If True, every name in ``coords`` has to be a coordinate of the
        validated object.
    allow_extra_keys : bool, default True
        If False, the validated object may not have coordinates other than
        those in ``coords``.
    """

    def __init__(
        self,
        coords: Tuple[str, ...],
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ) -> None:
        self.coords = coords
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Compare the coordinate names of ``obj`` with the expected names.

        Raises
        ------
        ValidationError
            If names are missing while ``require_all_keys`` is set, or if
            additional names are present while ``allow_extra_keys`` is not.
        """
        present = set(obj.coords.keys())
        expected = set(self.coords)

        # Missing names are only an error when all keys are required. An
        # unconditional per-key check used to follow these two tests, which
        # made ``require_all_keys=False`` ineffective.
        if self.require_all_keys:
            missing_keys = expected - present
            if missing_keys:
                raise ValidationError(f"coords has missing keys: {missing_keys}")

        if not self.allow_extra_keys:
            extra_keys = present - expected
            if extra_keys:
                raise ValidationError(f"coords has extra keys: {extra_keys}")
class OtherCoordsSchema(BaseSchema):
    """
    Validate whether coordinates match those of other.
    """

    def __init__(
        self,
        other: str,
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ):
        self.other = other
        self.allow_extra_keys = allow_extra_keys
        self.require_all_keys = require_all_keys

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Check the coords of ``obj`` against those of ``kwargs[self.other]``.

        The check itself is delegated to a CoordsSchema built from the
        coordinate names of the other object and the two flags of this schema.
        """
        reference = kwargs[self.other]
        delegate = CoordsSchema(
            list(reference.coords.keys()),
            self.require_all_keys,
            self.allow_extra_keys,
        )
        return delegate.validate(obj)
class ValueSchema(BaseSchema, abc.ABC):
    """
    Base class for AllValueSchema or AnyValueSchema.

    Parameters
    ----------
    operator : str
        Key into ``OPERATORS``, e.g. ``">"``.
    other : Any
        What the validated object is compared with.
    ignore : tuple of (varname, operator, value), optional
        Cells to exclude from the check: the variable ``varname`` is taken
        from the keyword arguments of ``validate`` and compared with ``value``
        using ``operator``.
    """

    def __init__(
        self,
        operator: str,
        other: Any,
        ignore: Optional[Tuple[str, str, Any]] = None,
    ):
        self.operator_str = operator
        self.operator = OPERATORS[operator]
        self.other = other

        # Defaults for "nothing is explicitly ignored".
        self.ignore_varname = None
        self.to_ignore = None
        if ignore:
            self.ignore_varname = ignore[0]
            self.to_ignore = partial_operator(ignore[1], ignore[2])

    def get_explicitly_ignored(self, kwargs: Dict) -> Any:
        """
        Get cells that should be explicitly ignored by the schema

        Returns ``False`` when no ``ignore`` was configured.
        """
        if not self.to_ignore:
            return False
        ignore_obj = kwargs[self.ignore_varname]
        return self.to_ignore(ignore_obj)
class AllValueSchema(ValueSchema):
    """
    Validate whether all values pass a condition.

    E.g. if operator is ">":

        assert (values > threshold).all()
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError unless every value complies.

        If ``other`` is a string it names the keyword argument to compare
        with, otherwise it is used directly. Nothing is checked when either
        object is a scalar None. NaN and explicitly ignored cells count as
        compliant.
        """
        other_obj = kwargs[self.other] if isinstance(self.other, str) else self.other

        if scalar_None(obj) or scalar_None(other_obj):
            return

        explicitly_ignored = self.get_explicitly_ignored(kwargs)

        # ignore nan by setting to True
        has_nan = np.isnan(obj) | np.isnan(other_obj)
        ignore = has_nan | explicitly_ignored

        complies = self.operator(obj, other_obj) | ignore
        if complies.all():
            return
        raise ValidationError(
            f"not all values comply with criterion: {self.operator_str} {self.other}"
        )
class AnyValueSchema(ValueSchema):
    """
    Validate whether any value passes a condition.

    E.g. if operator is ">":

        assert (values > threshold).any()
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError unless at least one value complies.

        If ``other`` is a string it names the keyword argument to compare
        with, otherwise it is used directly. Nothing is checked when either
        object is a scalar None. NaN and explicitly ignored cells never count
        as compliant.

        Raises
        ------
        ValidationError
            If no value outside the ignored cells complies.
        """
        if isinstance(self.other, str):
            other_obj = kwargs[self.other]
        else:
            other_obj = self.other

        if scalar_None(obj) or scalar_None(other_obj):
            return

        explicitly_ignored = self.get_explicitly_ignored(kwargs)

        # Cells that must not contribute to the "any" test.
        ignore = np.isnan(obj) | np.isnan(other_obj) | explicitly_ignored

        # Ignored cells are forced to False. OR-ing a mask of the non-NaN
        # cells into the condition, as was done before, marked every regular
        # value as compliant, so the check could hardly ever fail.
        condition = self.operator(obj, other_obj) & ~ignore
        if not condition.any():
            raise ValidationError(
                f"not a single value complies with criterion: {self.operator_str} {self.other}"
            )
506def _notnull(obj):
507 """
508 Helper function; does the same as xr.DataArray.notnull. This function is to
509 avoid an issue where xr.DataArray.notnull() returns ordinary numpy arrays
510 for instances of xu.UgridDataArray.
511 """
513 return ~np.isnan(obj)
class NoDataSchema(BaseSchema):
    """
    Base class for AllNoDataSchema and AnyNoDataSchema.

    Parameters
    ----------
    is_notnull : callable or tuple of (operator, value)
        Determines which values count as data. A tuple is converted into a
        comparison against ``value`` with ``partial_operator``. Defaults to
        ``_notnull``.
    """

    def __init__(
        self,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        if isinstance(is_notnull, tuple):
            comparison, threshold = is_notnull
            is_notnull = partial_operator(comparison, threshold)
        self.is_notnull = is_notnull
class AllNoDataSchema(NoDataSchema):
    """
    Fails when all data is NoData.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Raise ValidationError if no value of ``obj`` is valid data."""
        has_data = self.is_notnull(obj).any()
        nothing_valid = ~has_data
        if nothing_valid:
            raise ValidationError("all nodata")
class AnyNoDataSchema(NoDataSchema):
    """
    Fails when any data is NoData.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Raise ValidationError if any value of ``obj`` is not valid data."""
        everything_valid = self.is_notnull(obj).all()
        some_nodata = ~everything_valid
        if some_nodata:
            raise ValidationError("found a nodata value")
class NoDataComparisonSchema(BaseSchema):
    """
    Base class for IdentityNoDataSchema and AllInsideNoDataSchema.

    Parameters
    ----------
    other : str
        Name of the keyword argument holding the object to compare with.
    is_notnull : callable or tuple of (operator, value)
        Determines which values of the validated object count as data.
    is_other_notnull : callable or tuple of (operator, value)
        Determines which values of the other object count as data.

    A tuple is converted into a comparison against ``value`` with
    ``partial_operator``; both default to ``_notnull``.
    """

    def __init__(
        self,
        other: str,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
        is_other_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        self.other = other

        if isinstance(is_notnull, tuple):
            comparison, threshold = is_notnull
            is_notnull = partial_operator(comparison, threshold)
        self.is_notnull = is_notnull

        if isinstance(is_other_notnull, tuple):
            comparison, threshold = is_other_notnull
            is_other_notnull = partial_operator(comparison, threshold)
        self.is_other_notnull = is_other_notnull
class IdentityNoDataSchema(NoDataComparisonSchema):
    """
    Checks that the NoData values are located at exactly the same locations.

    Tests only if all dimensions of the other object are present in the
    object. So tests if "stage" with `{time, layer, y, x}` is compared to
    "idomain" `{layer, y, x}` but doesn't test if "k" with `{layer}` is
    compared to "idomain" `{layer, y, x}`
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError if the data masks of both objects differ.
        """
        other_obj = kwargs[self.other]

        # Only test if object has all dimensions in other object.
        if not set(other_obj.dims).issubset(obj.dims):
            return

        valid = self.is_notnull(obj)
        other_valid = self.is_other_notnull(other_obj)
        # XOR: True wherever exactly one of the two has data.
        mismatch = valid ^ other_valid
        if mismatch.any():
            raise ValidationError(f"nodata is not aligned with {self.other}")
class AllInsideNoDataSchema(NoDataComparisonSchema):
    """
    Checks that all notnull values all occur within the notnull values of other.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError if ``obj`` has data where the other has none.

        Parameters
        ----------
        obj : GridDataArray
            Object to validate.
        **kwargs
            Has to contain the other object under the key ``self.other``.

        Raises
        ------
        ValidationError
            If a valid value of ``obj`` is located at a nodata value of the
            other object.
        """
        other_obj = kwargs[self.other]
        valid = self.is_notnull(obj)
        other_valid = self.is_other_notnull(other_obj)

        # Align the two boolean masks. The raw ``other_obj`` used to be passed
        # here, which discarded ``other_valid`` and applied ``~`` to the data
        # values instead of to the mask.
        valid, other_valid = align_other_obj_with_coords(valid, other_valid)

        if (valid & ~other_valid).any():
            raise ValidationError(f"data values found at nodata values of {self.other}")
class ActiveCellsConnectedSchema(BaseSchema):
    """
    Check if active cells are connected, to avoid isolated islands which can
    cause convergence issues, if they don't have a head boundary condition, but
    do have a specified flux.

    Parameters
    ----------
    is_notnull : callable or tuple of (operator, value)
        Determines which cells count as active. A tuple is converted into a
        comparison against ``value`` with ``partial_operator``. Defaults to
        ``_notnull``.

    Note
    ----
    This schema only works for structured grids.
    """

    def __init__(
        self,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        if isinstance(is_notnull, tuple):
            op, value = is_notnull
            self.is_notnull = partial_operator(op, value)
        else:
            self.is_notnull = is_notnull

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError if the active cells form more than one group.

        Raises
        ------
        NotImplementedError
            If ``obj`` is a ``xu.UgridDataArray``.
        ValidationError
            If more than one connected group of active cells is found.
        """
        if isinstance(obj, xu.UgridDataArray):
            # TODO: https://deltares.github.io/xugrid/api/xugrid.UgridDataArrayAccessor.connected_components.html
            # ``__name__`` is an attribute of the class, not of the instance;
            # ``self.__name__`` would raise AttributeError here.
            raise NotImplementedError(
                f"Schema {type(self).__name__} only works for structured grids, received xu.UgridDataArray."
            )

        active = self.is_notnull(obj)

        # Only the number of labelled groups is needed, not the labels.
        _, nlabels = scipy.ndimage.label(active)
        if nlabels > 1:
            raise ValidationError(
                f"{nlabels} disconnected areas detected in model domain"
            )