Coverage for C:\src\imod-python\imod\schemata.py: 92%

262 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-13 11:15 +0200

1""" 

2Schemata to help validation of input. 

3 

4This code is based on: https://github.com/carbonplan/xarray-schema 

5 

6which has the following MIT license: 

7 

8 MIT License 

9 

10 Copyright (c) 2021 carbonplan 

11 

12 Permission is hereby granted, free of charge, to any person obtaining a copy 

13 of this software and associated documentation files (the "Software"), to deal 

14 in the Software without restriction, including without limitation the rights 

15 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

16 copies of the Software, and to permit persons to whom the Software is 

17 furnished to do so, subject to the following conditions: 

18 

19 The above copyright notice and this permission notice shall be included in all 

20 copies or substantial portions of the Software. 

21 

22 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

23 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

24 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

25 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

26 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

27 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

28 SOFTWARE. 

29 

30In the future, we may be able to replace this module by whatever the best 

31validation xarray library becomes. 

32""" 

33 

34import abc 

35import operator 

36from functools import partial 

37from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeAlias, Union 

38 

39import numpy as np 

40import scipy 

41import xarray as xr 

42import xugrid as xu 

43from numpy.typing import DTypeLike # noqa: F401 

44 

45from imod.typing import GridDataArray, ScalarAsDataArray 

46 

# Type aliases used in the schema signatures below.
DimsT = Union[str, None]
# A shape is a tuple of arbitrary length; ``None`` marks a wildcard axis.
# (``Tuple[X]`` without the ellipsis would describe a 1-tuple only.)
ShapeT = Tuple[Union[int, None], ...]
ChunksT = Union[bool, Dict[str, Union[int, None]]]

# Maps the textual comparison operators accepted by the schemata to the
# corresponding two-argument functions of the ``operator`` module.
OPERATORS = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}

59 

60 

def partial_operator(op, value):
    """
    Return a one-argument callable evaluating ``<arg> <op> value``.

    ``op`` is a key of ``OPERATORS``; it is looked up when the returned
    callable is invoked, not when it is created.
    """

    # functools.partial can only pre-fill leading positional arguments and the
    # operator functions do not take keywords, so the two operands are swapped
    # in a small wrapper: https://stackoverflow.com/a/37468215
    def _swapped(rhs, lhs):
        return OPERATORS[op](lhs, rhs)

    return partial(_swapped, value)

67 

68 

def scalar_None(obj):
    """
    Tell whether ``obj`` is a zero-dimensional DataArray holding only nulls.

    Such an array is the default value for optional variables. Anything that
    is not an ``xr.DataArray`` or ``xu.UgridDataArray`` yields ``False``.
    """
    if isinstance(obj, (xr.DataArray, xu.UgridDataArray)):
        is_scalar = len(obj.shape) == 0
        all_null = (~obj.notnull()).all()
        return is_scalar & all_null
    return False

78 

79 

def align_other_obj_with_coords(
    obj: GridDataArray, other_obj: GridDataArray
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Left-align ``other_obj`` with ``obj``.

    Every coordinate of ``obj`` that is not one of its dimensions, but is a
    dimension of ``other_obj``, is first promoted to a dimension of ``obj``.
    This avoids issues like:
    https://github.com/Deltares/imod-python/issues/830
    """
    other_dims = other_obj.dims
    for name in obj.coords.keys():
        not_a_dim_yet = name not in obj.dims
        if not_a_dim_yet and name in other_dims:
            obj = obj.expand_dims(name)
    # Note (from the original authors): xr.align forces xu.UgridDataArray to
    # xr.DataArray. Keep that in mind in further data processing.
    return xr.align(obj, other_obj, join="left")

96 

97 

class ValidationError(Exception):
    """Raised by a schema when the validated object does not comply."""

100 

101 

class BaseSchema(abc.ABC):
    """Abstract base of all schemata; subclasses implement ``validate``."""

    @abc.abstractmethod
    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Check ``obj``; implementations in this module raise ValidationError on failure."""

    def __or__(self, other):
        """
        Combine two schemata into a SchemaUnion, so one can write:

            DimsSchema("layer", "y", "x") | DimsSchema("layer")
        """
        return SchemaUnion(self, other)

116 

117 

# SchemaType is a plain alias of BaseSchema. The TypeVar-based definition is
# kept here, commented out, as the alternative:
# SchemaType = TypeVar("SchemaType", bound=BaseSchema)
SchemaType: TypeAlias = BaseSchema

120 

121 

class SchemaUnion:
    """
    A set of alternative schemata: validation succeeds as soon as at least
    one of them accepts the object.

    All members must be of the same type; a TypeError is raised otherwise.
    """

    def __init__(self, *args):
        distinct_types = {type(schema) for schema in args}
        if len(distinct_types) > 1:
            raise TypeError("schemata in a union should have the same type")
        self.schemata = tuple(args)

    def validate(self, obj: Any, **kwargs):
        """Run every member schema; raise ValidationError only if all fail."""
        failures = []
        for candidate in self.schemata:
            try:
                candidate.validate(obj, **kwargs)
            except ValidationError as failure:
                failures.append(failure)

        if len(failures) != len(self.schemata):
            # At least one option passed.
            return

        details = "\n\t".join(map(str, failures))
        raise ValidationError(f"No option succeeded:\n\t{details}")

    def __or__(self, other):
        """Return a new union extended with ``other``."""
        return SchemaUnion(*self.schemata, other)

149 

150 

class OptionSchema(BaseSchema):
    """
    Check that a scalar value is one of a fixed collection of options.
    """

    def __init__(self, options: Sequence[Any]):
        self.options = options

    def validate(self, obj: ScalarAsDataArray, **kwargs) -> None:
        """Raise ValidationError if the value is not among ``self.options``."""
        if scalar_None(obj):
            return

        # MODFLOW 6 is not case sensitive for string options, so strings are
        # lower-cased before the membership test.
        raw = obj.item()
        value = raw.lower() if isinstance(raw, str) else raw

        if value in self.options:
            return

        valid_values = ", ".join(str(option) for option in self.options)
        raise ValidationError(
            f"Invalid option: {value}. Valid options are: {valid_values}"
        )

173 

174 

class DTypeSchema(BaseSchema):
    """Check that the dtype of an object is a sub-dtype of the expected one."""

    def __init__(self, dtype: DTypeLike) -> None:
        # These abstract numpy scalar types are stored as given; anything else
        # is normalized through np.dtype.
        abstract_types = [
            np.floating,
            np.integer,
            np.signedinteger,
            np.unsignedinteger,
            np.generic,
        ]
        self.dtype = dtype if dtype in abstract_types else np.dtype(dtype)

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Validate dtype

        Parameters
        ----------
        obj : GridDataArray
            Object whose ``dtype`` is tested with ``np.issubdtype`` against
            the dtype given at construction. Scalar None objects are skipped.
        """
        if scalar_None(obj):
            return

        if np.issubdtype(obj.dtype, self.dtype):
            return

        raise ValidationError(f"dtype {obj.dtype} != {self.dtype}")

202 

203 

class DimsSchema(BaseSchema):
    """
    Check that the dimensions of an object equal the expected dimensions.

    For ``xu.UgridDataArray`` objects the placeholders ``"{face_dim}"`` and
    ``"{edge_dim}"`` are replaced by the face and edge dimension names of the
    object's grid before comparing.
    """

    # Placeholder -> name of the grid attribute holding the actual dim name.
    _PLACEHOLDERS = {
        "{face_dim}": "face_dimension",
        "{edge_dim}": "edge_dimension",
    }

    def __init__(self, *dims: DimsT) -> None:
        self.dims = dims

    def _fill_in_face_dim(self, obj: Union[xr.DataArray, xu.UgridDataArray]):
        """
        Return dims with the face and/or edge placeholders filled in if
        necessary. Both placeholders are substituted in a single pass, so a
        schema that contains both is handled as well.
        """
        if not isinstance(obj, xu.UgridDataArray):
            return self.dims
        if not any(dim in self._PLACEHOLDERS for dim in self.dims):
            return self.dims

        grid = obj.ugrid.grid
        return tuple(
            getattr(grid, self._PLACEHOLDERS[dim]) if dim in self._PLACEHOLDERS else dim
            for dim in self.dims
        )

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Validate dimensions

        The dimensions of ``obj`` must equal the expected dimensions exactly,
        in the same order.

        Parameters
        ----------
        obj : GridDataArray
            Object of which ``obj.dims`` is compared with the schema's dims.

        Raises
        ------
        ValidationError
            If the dimensions differ.
        """
        dims = self._fill_in_face_dim(obj)
        # Force to tuple for error message print
        expected = tuple(dims)
        actual = tuple(obj.dims)
        if actual != expected:
            raise ValidationError(f"dim mismatch: expected {expected}, got {actual}")

242 

243 

class EmptyIndexesSchema(BaseSchema):
    """
    Verify indexes: no dimension may have an index of length zero.
    Dimensions belonging to the unstructured grid are skipped.
    """

    def __init__(self) -> None:
        pass

    def get_dims_to_validate(self, obj: Union[xr.DataArray, xu.UgridDataArray]):
        """Return the list of dimensions of ``obj`` whose index is checked."""
        if not isinstance(obj, xu.UgridDataArray):
            return list(obj.dims)

        # Grid dimensions (e.g. the face dim) are left out, as they have no
        # ``indexes`` entry.
        grid_dims = obj.ugrid.grid.dimensions
        return [dim for dim in obj.dims if dim not in grid_dims]

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Raise ValidationError for the first checked dimension of size 0."""
        for dim in self.get_dims_to_validate(obj):
            index = obj.indexes[dim]
            if len(index) == 0:
                raise ValidationError(f"provided dimension {dim} with size 0")

271 

272 

class IndexesSchema(EmptyIndexesSchema):
    """
    Verify indexes: on top of the empty-dimension check of the parent class,
    require monotonic indexes. The "y" index must be monotonically
    decreasing, every other index monotonically increasing. Unstructured
    grid dimensions are skipped.
    """

    def __init__(self) -> None:
        pass

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Raise ValidationError on an empty or non-monotonic index."""
        # First make sure none of the indexes is empty.
        super().validate(obj)

        for dim in self.get_dims_to_validate(obj):
            index = obj.indexes[dim]
            if dim == "y":
                if index.is_monotonic_decreasing:
                    continue
                raise ValidationError(
                    f"coord {dim} which is not monotonically decreasing"
                )

            if index.is_monotonic_increasing:
                continue
            raise ValidationError(
                f"coord {dim} which is not monotonically increasing"
            )

300 

301 

class ShapeSchema(BaseSchema):
    """Check the shape of an object, axis by axis."""

    def __init__(self, shape: ShapeT) -> None:
        """
        Validate shape.

        Parameters
        ----------
        shape : ShapeT
            Shape of the DataArray. `None` may be used as a wildcard value.
        """
        self.shape = shape

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError if the number of axes differs from the expected
        shape, or if any axis whose expected size is not ``None`` has a
        different size.
        """
        if len(self.shape) != len(obj.shape):
            # Report the schema's length as "shape" and the object's as
            # "da.ndim"; the original message had the two values swapped.
            raise ValidationError(
                f"number of dimensions in shape ({len(self.shape)}) != da.ndim ({len(obj.shape)})"
            )

        for i, (actual, expected) in enumerate(zip(obj.shape, self.shape)):
            if expected is not None and actual != expected:
                raise ValidationError(
                    f"shape mismatch in axis {i}: {actual} != {expected}"
                )

325 

326 

class CompatibleSettingsSchema(BaseSchema):
    """
    Check that a setting is only switched on when another variable has a
    required value.
    """

    def __init__(self, other: str, other_value: bool) -> None:
        """
        Validate if settings are compatible

        Parameters
        ----------
        other : str
            Name of the other variable; it is looked up in the keyword
            arguments passed to ``validate``.
        other_value : bool
            Value the other variable must have when the validated setting is
            truthy.
        """
        self.other = other
        self.other_value = other_value

    def validate(self, obj: ScalarAsDataArray, **kwargs) -> None:
        """
        Raise ValidationError if ``obj`` is truthy while the other variable
        is not equal to ``other_value`` everywhere. Nothing is checked when
        either object is a scalar None.
        """
        other_obj = kwargs[self.other]
        if scalar_None(obj) or scalar_None(other_obj):
            return
        expected = np.all(other_obj == self.other_value)

        if obj and not expected:
            raise ValidationError(
                f"Incompatible setting: {self.other} should be {self.other_value}"
            )

345 

346 

class CoordsSchema(BaseSchema):
    """
    Validate presence of coords.

    Parameters
    ----------
    coords : sequence of str
        Names of the coordinates expected on the DataArray.
    require_all_keys : bool, default True
        When True, fail if any name in ``coords`` is missing on the object.
    allow_extra_keys : bool, default True
        When False, fail if the object carries coordinates that are not
        listed in ``coords``.
    """

    def __init__(
        self,
        coords: Sequence[str],
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ) -> None:
        self.coords = coords
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError for missing and/or extra coordinates, depending
        on the two flags given at construction.
        """
        present = set(obj.coords.keys())
        expected = set(self.coords)

        if self.require_all_keys:
            missing_keys = expected - present
            if missing_keys:
                raise ValidationError(f"coords has missing keys: {missing_keys}")

        if not self.allow_extra_keys:
            extra_keys = present - expected
            if extra_keys:
                raise ValidationError(f"coords has extra keys: {extra_keys}")

        # Deliberately no unconditional per-key presence check here: it was
        # redundant with require_all_keys=True and made require_all_keys=False
        # ineffective, as missing keys would still raise.

383 

384 

class OtherCoordsSchema(BaseSchema):
    """
    Validate that the coordinates of an object match those of another
    variable, which is passed by name through the keyword arguments.
    """

    def __init__(
        self,
        other: str,
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ):
        self.other = other
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Delegate to a CoordsSchema built from the other object's coords."""
        reference = kwargs[self.other]
        delegate = CoordsSchema(
            list(reference.coords.keys()),
            self.require_all_keys,
            self.allow_extra_keys,
        )
        return delegate.validate(obj)

408 

409 

class ValueSchema(BaseSchema, abc.ABC):
    """
    Shared base of AllValueSchema and AnyValueSchema.

    Stores the comparison (``operator`` and ``other``) and, optionally, a
    rule ``(variable name, operator, value)`` marking cells to be ignored.
    """

    def __init__(
        self,
        operator: str,
        other: Any,
        ignore: Optional[Tuple[str, str, Any]] = None,
    ):
        self.operator = OPERATORS[operator]
        self.operator_str = operator
        self.other = other

        if ignore:
            self.ignore_varname = ignore[0]
            self.to_ignore = partial_operator(ignore[1], ignore[2])
        else:
            self.ignore_varname = None
            self.to_ignore = None

    def get_explicitly_ignored(self, kwargs: Dict) -> Any:
        """
        Return the cells the schema should explicitly ignore: the ignore rule
        applied to the named variable in ``kwargs``, or ``False`` when no
        rule was configured.
        """
        if not self.to_ignore:
            return False
        return self.to_ignore(kwargs[self.ignore_varname])

440 

441 

class AllValueSchema(ValueSchema):
    """
    Validate that every value passes a condition.

    E.g. if operator is ">":

        assert (values > threshold).all()

    Cells that are NaN in either operand, or that are explicitly ignored,
    count as passing.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        # A string refers to another variable passed by keyword.
        other_obj = kwargs[self.other] if isinstance(self.other, str) else self.other

        if scalar_None(obj) or scalar_None(other_obj):
            return

        explicitly_ignored = self.get_explicitly_ignored(kwargs)
        # Ignored cells are forced to True so they cannot fail the check.
        skip = np.isnan(obj) | np.isnan(other_obj) | explicitly_ignored

        passed = self.operator(obj, other_obj) | skip
        if passed.all():
            return

        raise ValidationError(
            f"not all values comply with criterion: {self.operator_str} {self.other}"
        )

472 

473 

class AnyValueSchema(ValueSchema):
    """
    Validate whether any value passes a condition.

    E.g. if operator is ">":

        assert (values > threshold).any()

    Cells that are NaN in either operand, or that are explicitly ignored,
    never count as passing.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError if no non-ignored value satisfies the
        comparison. Nothing is checked when either operand is a scalar None.
        """
        # A string refers to another variable passed by keyword.
        if isinstance(self.other, str):
            other_obj = kwargs[self.other]
        else:
            other_obj = self.other

        if scalar_None(obj) or scalar_None(other_obj):
            return

        explicitly_ignored = self.get_explicitly_ignored(kwargs)

        # Cells to leave out of the test. For an "any" test they must be
        # forced to False. The previous implementation OR-ed the *inverted*
        # NaN mask into the condition, which made every non-NaN cell pass
        # regardless of the comparison.
        ignore = np.isnan(obj) | np.isnan(other_obj) | explicitly_ignored

        condition = self.operator(obj, other_obj)
        condition = condition & ~ignore
        if not condition.any():
            raise ValidationError(
                f"not a single value complies with criterion: {self.operator_str} {self.other}"
            )

504 

505 

def _notnull(obj):
    """
    Return a boolean mask that is True wherever ``obj`` is not NaN.

    Stand-in for ``xr.DataArray.notnull``: per the original authors' note it
    exists because ``notnull()`` returned ordinary numpy arrays for
    ``xu.UgridDataArray`` instances.
    """
    is_nan = np.isnan(obj)
    return ~is_nan

514 

515 

class NoDataSchema(BaseSchema):
    """
    Base for schemata that test where data is present.

    ``is_notnull`` is either a callable returning a boolean mask, or an
    ``(operator, value)`` tuple that is converted to such a callable with
    ``partial_operator``.
    """

    def __init__(
        self,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        if not isinstance(is_notnull, tuple):
            self.is_notnull = is_notnull
            return

        op, value = is_notnull
        self.is_notnull = partial_operator(op, value)

526 

527 

class AllNoDataSchema(NoDataSchema):
    """
    Fails when all data is NoData.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Raise ValidationError if ``is_notnull`` is False for every value."""
        valid = self.is_notnull(obj)
        # Logical ``not`` rather than bitwise ``~``: the latter yields a
        # truthy integer when ``any()`` returns a plain Python bool.
        if not valid.any():
            raise ValidationError("all nodata")

537 

538 

class AnyNoDataSchema(NoDataSchema):
    """
    Fails when any data is NoData.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Raise ValidationError if ``is_notnull`` is False for any value."""
        valid = self.is_notnull(obj)
        # Logical ``not`` rather than bitwise ``~``: the latter yields a
        # truthy integer when ``all()`` returns a plain Python bool.
        if not valid.all():
            raise ValidationError("found a nodata value")

548 

549 

class NoDataComparisonSchema(BaseSchema):
    """
    Base class for IdentityNoDataSchema and AllInsideNoDataSchema.

    Holds the name of the other variable and one "not null" test for each of
    the two objects. Each test is either a callable or an
    ``(operator, value)`` tuple converted with ``partial_operator``.
    """

    def __init__(
        self,
        other: str,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
        is_other_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        self.other = other
        self.is_notnull = self._as_callable(is_notnull)
        self.is_other_notnull = self._as_callable(is_other_notnull)

    @staticmethod
    def _as_callable(spec):
        """Turn an ``(operator, value)`` tuple into a callable; pass others through."""
        if isinstance(spec, tuple):
            op, value = spec
            return partial_operator(op, value)
        return spec

573 

574 

class IdentityNoDataSchema(NoDataComparisonSchema):
    """
    Checks that the NoData values are located at exactly the same locations.

    The test only runs if all dimensions of the other object are present in
    the object. So "stage" with `{time, layer, y, x}` is tested against
    "idomain" `{layer, y, x}`, but "k" with `{layer}` is not compared to
    "idomain" `{layer, y, x}`.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        other_obj = kwargs[self.other]

        # Skip the comparison when the other object has dimensions that the
        # object itself lacks.
        if set(other_obj.dims) - set(obj.dims):
            return

        valid = self.is_notnull(obj)
        other_valid = self.is_other_notnull(other_obj)
        mismatch = valid ^ other_valid
        if mismatch.any():
            raise ValidationError(f"nodata is not aligned with {self.other}")

596 

597 

class AllInsideNoDataSchema(NoDataComparisonSchema):
    """
    Checks that all notnull values all occur within the notnull values of other.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError if ``obj`` has data in a cell where the other
        variable (looked up by name in ``kwargs``) has none.
        """
        other_obj = kwargs[self.other]
        valid = self.is_notnull(obj)
        other_valid = self.is_other_notnull(other_obj)

        # Align the two boolean masks. Previously the raw ``other_obj`` was
        # aligned here, which overwrote ``other_valid`` with the data values
        # and discarded the result of ``is_other_notnull``.
        valid, other_valid = align_other_obj_with_coords(valid, other_valid)

        if (valid & ~other_valid).any():
            raise ValidationError(f"data values found at nodata values of {self.other}")

612 

613 

class ActiveCellsConnectedSchema(BaseSchema):
    """
    Check if active cells are connected, to avoid isolated islands which can
    cause convergence issues, if they don't have a head boundary condition, but
    do have a specified flux.

    Note
    ----
    This schema only works for structured grids.
    """

    def __init__(
        self,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        # An (operator, value) tuple is converted into a callable mask test.
        if isinstance(is_notnull, tuple):
            op, value = is_notnull
            self.is_notnull = partial_operator(op, value)
        else:
            self.is_notnull = is_notnull

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raise ValidationError if the active cells form more than one
        connected region, as labelled by ``scipy.ndimage.label``.

        Raises
        ------
        NotImplementedError
            If ``obj`` is an ``xu.UgridDataArray``.
        ValidationError
            If more than one connected region is found.
        """
        if isinstance(obj, xu.UgridDataArray):
            # TODO: https://deltares.github.io/xugrid/api/xugrid.UgridDataArrayAccessor.connected_components.html
            # ``__name__`` lives on the class, not on the instance; using
            # ``self.__name__`` raised AttributeError instead of this error.
            raise NotImplementedError(
                f"Schema {type(self).__name__} only works for structured grids, received xu.UgridDataArray."
            )

        active = self.is_notnull(obj)

        _, nlabels = scipy.ndimage.label(active)
        if nlabels > 1:
            raise ValidationError(
                f"{nlabels} disconnected areas detected in model domain"
            )

648 )