Coverage for C:\src\imod-python\imod\schemata.py: 91%

250 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-08 13:27 +0200

1""" 

2Schemata to help validation of input. 

3 

4This code is based on: https://github.com/carbonplan/xarray-schema 

5 

6which has the following MIT license: 

7 

8 MIT License 

9 

10 Copyright (c) 2021 carbonplan 

11 

12 Permission is hereby granted, free of charge, to any person obtaining a copy 

13 of this software and associated documentation files (the "Software"), to deal 

14 in the Software without restriction, including without limitation the rights 

15 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

16 copies of the Software, and to permit persons to whom the Software is 

17 furnished to do so, subject to the following conditions: 

18 

19 The above copyright notice and this permission notice shall be included in all 

20 copies or substantial portions of the Software. 

21 

22 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

23 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

24 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

25 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

26 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

27 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

28 SOFTWARE. 

29 

30In the future, we may be able to replace this module by whatever the best 

31validation xarray library becomes. 

32""" 

33 

34import abc 

35import operator 

36from functools import partial 

37from typing import Any, Callable, Dict, Optional, Tuple, TypeAlias, Union 

38 

39import numpy as np 

40import scipy 

41import xarray as xr 

42import xugrid as xu 

43from numpy.typing import DTypeLike # noqa: F401 

44 

45from imod.typing import GridDataArray, ScalarAsDataArray 

46 

DimsT = Union[str, None]
# A shape is a tuple of any length; ``None`` entries act as wildcards.
ShapeT = Tuple[Union[int, None], ...]
ChunksT = Union[bool, Dict[str, Union[int, None]]]

# Map the comparison strings accepted by the schemata onto the matching
# binary functions of the ``operator`` module.
OPERATORS = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}

59 

60 

def partial_operator(op, value):
    """
    Return a one-argument predicate ``f(a) -> OPERATORS[op](a, value)``.

    ``functools.partial`` binds positional arguments from the left, and the
    ``operator`` functions do not accept keyword arguments. A closure is
    therefore used to put the bound ``value`` on the right-hand side.
    https://stackoverflow.com/a/37468215

    Parameters
    ----------
    op : str
        One of the keys of ``OPERATORS``.
    value : Any
        Right-hand operand of the comparison.

    Raises
    ------
    KeyError
        If ``op`` is not a key of ``OPERATORS``. The lookup happens here,
        so an invalid operator is reported when the schema is built rather
        than later, during validation.
    """
    compare = OPERATORS[op]

    def predicate(a):
        return compare(a, value)

    return predicate

67 

68 

def scalar_None(obj):
    """
    Test if object is a scalar (zero-dimensional) DataArray holding only a
    null value, which is the default value for optional variables.

    Parameters
    ----------
    obj : Any
        Object to test. Anything that is not an ``xr.DataArray`` or
        ``xu.UgridDataArray`` yields ``False``.

    Returns
    -------
    bool
    """
    if not isinstance(obj, (xr.DataArray, xu.UgridDataArray)):
        return False
    # Check dimensionality first: the null reduction is only needed for
    # zero-dimensional arrays, so full grids are never scanned here.
    if len(obj.shape) != 0:
        return False
    return bool((~obj.notnull()).all())

78 

79 

def align_other_obj_with_coords(
    obj: GridDataArray, other_obj: GridDataArray
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Align ``other_obj`` to ``obj`` with a left join.

    Any coordinate of ``obj`` that is a dimension of ``other_obj`` but not a
    dimension of ``obj`` is first promoted to a dimension of ``obj`` with
    ``expand_dims``. This avoids issues like:
    https://github.com/Deltares/imod-python/issues/830

    Note that ``xr.align`` turns ``xu.UgridDataArray`` into ``xr.DataArray``;
    keep that in mind when processing the results.
    """
    for name in obj.coords.keys():
        needs_expansion = name in other_obj.dims and name not in obj.dims
        if needs_expansion:
            obj = obj.expand_dims(name)
    aligned = xr.align(obj, other_obj, join="left")
    return aligned

96 

97 

class ValidationError(Exception):
    """Raised when an object does not satisfy a schema."""

100 

101 

class BaseSchema(abc.ABC):
    """
    Abstract base class of all schemata.

    Subclasses implement ``validate``, which raises ``ValidationError`` when
    ``obj`` does not comply with the schema.
    """

    @abc.abstractmethod
    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Validate ``obj``; keyword arguments may carry related variables."""

    def __or__(self, other):
        """
        Combine two schemata into a SchemaUnion, so that one can write:

        DimsSchema("layer", "y", "x") | DimsSchema("layer")
        """
        union = SchemaUnion(self, other)
        return union

116 

117 

# Type alias used to annotate schema instances; it is simply ``BaseSchema``.
SchemaType: TypeAlias = BaseSchema

120 

121 

class SchemaUnion:
    """
    Union of schemata: successful validation requires only a single success.

    Used to validate multiple options, e.g.
    ``DimsSchema("layer", "y", "x") | DimsSchema("layer")``.

    Parameters
    ----------
    *args
        Schemata, all of the same type.

    Raises
    ------
    TypeError
        If the schemata are not all of the same type.
    """

    def __init__(self, *args):
        ntypes = len({type(arg) for arg in args})
        if ntypes > 1:
            raise TypeError("schemata in a union should have the same type")
        self.schemata = tuple(args)

    def validate(self, obj: Any, **kwargs):
        """
        Validate ``obj`` against each schema in turn, stopping at the first
        schema that succeeds.

        Raises
        ------
        ValidationError
            If every schema fails; the message lists all individual errors.
        """
        errors = []
        for schema in self.schemata:
            try:
                schema.validate(obj, **kwargs)
            except ValidationError as e:
                errors.append(e)
            else:
                # One success is enough; the remaining schemata are skipped.
                return

        message = "\n\t" + "\n\t".join(str(error) for error in errors)
        raise ValidationError(f"No option succeeded:{message}")

    def __or__(self, other):
        return SchemaUnion(*self.schemata, other)

149 

150 

class DTypeSchema(BaseSchema):
    """
    Validate the dtype of a DataArray.

    Parameters
    ----------
    dtype : DTypeLike
        Required dtype. The abstract numpy scalar types ``np.floating``,
        ``np.integer``, ``np.signedinteger``, ``np.unsignedinteger`` and
        ``np.generic`` are stored as-is, so any concrete sub-dtype passes.
        Anything else is converted with ``np.dtype``.
    """

    # Abstract numpy scalar types that are stored without conversion.
    _ABSTRACT_DTYPES = (
        np.floating,
        np.integer,
        np.signedinteger,
        np.unsignedinteger,
        np.generic,
    )

    def __init__(self, dtype: DTypeLike) -> None:
        # Compare by identity. ``dtype in [...]`` would call
        # ``np.dtype.__eq__`` for dtype instances, which may trigger NumPy's
        # deprecated conversion of abstract scalar types to concrete dtypes.
        if any(dtype is abstract for abstract in self._ABSTRACT_DTYPES):
            self.dtype = dtype
        else:
            self.dtype = np.dtype(dtype)

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Validate dtype.

        Parameters
        ----------
        obj : GridDataArray
            Object to validate. A scalar None DataArray always passes.

        Raises
        ------
        ValidationError
            If ``obj.dtype`` is not a sub-dtype of the required dtype.
        """
        if scalar_None(obj):
            return

        if not np.issubdtype(obj.dtype, self.dtype):
            raise ValidationError(f"dtype {obj.dtype} != {self.dtype}")

178 

179 

class DimsSchema(BaseSchema):
    """
    Validate the dimensions of a DataArray.

    Parameters
    ----------
    *dims : DimsT
        Expected dimension names, in order. ``None`` may be used as a
        wildcard that matches any name. For a ``xu.UgridDataArray``, the
        placeholders ``"{face_dim}"`` and ``"{edge_dim}"`` are replaced by the
        grid's face and edge dimension names.
    """

    def __init__(self, *dims: DimsT) -> None:
        self.dims = dims

    def _fill_in_face_dim(self, obj: Union[xr.DataArray, xu.UgridDataArray]):
        """
        Return dims with the ``"{face_dim}"`` and ``"{edge_dim}"``
        placeholders filled in when ``obj`` is a ``xu.UgridDataArray``.
        Otherwise return the dims unchanged.
        """
        if not isinstance(obj, xu.UgridDataArray):
            return self.dims
        grid = obj.ugrid.grid
        filled = []
        for dim in self.dims:
            # Only look up the grid attribute for a placeholder that is
            # actually present.
            if dim == "{face_dim}":
                dim = grid.face_dimension
            elif dim == "{edge_dim}":
                dim = grid.edge_dimension
            filled.append(dim)
        return tuple(filled)

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Validate dimensions.

        Raises
        ------
        ValidationError
            If the number of dimensions differs, or a non-``None`` expected
            name differs from the actual name at the same position.
        """
        # Force to tuple for error message print.
        expected = tuple(self._fill_in_face_dim(obj))
        actual = tuple(obj.dims)
        matches = len(actual) == len(expected) and all(
            exp is None or act == exp for act, exp in zip(actual, expected)
        )
        if not matches:
            raise ValidationError(f"dim mismatch: expected {expected}, got {actual}")

218 

219 

class EmptyIndexesSchema(BaseSchema):
    """
    Verify indexes: fail if a dimension has a zero-length index. Dimensions of
    an unstructured grid are skipped.
    """

    def __init__(self) -> None:
        pass

    def get_dims_to_validate(self, obj: Union[xr.DataArray, xu.UgridDataArray]):
        """
        Return the dimensions of ``obj`` to check. For a ``xu.UgridDataArray``,
        the grid dimensions are left out because they have no ``indexes``
        entry.
        """
        if not isinstance(obj, xu.UgridDataArray):
            return list(obj.dims)
        skipped = obj.ugrid.grid.dimensions
        return [name for name in obj.dims if name not in skipped]

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """Raise ValidationError at the first dimension with an empty index."""
        for name in self.get_dims_to_validate(obj):
            index = obj.indexes[name]
            if len(index) == 0:
                raise ValidationError(f"provided dimension {name} with size 0")

247 

248 

class IndexesSchema(EmptyIndexesSchema):
    """
    Verify indexes: fail if a dimension has size zero or if an index is not
    monotonic. The "y" index must be monotonically decreasing; all other
    indexes monotonically increasing. Unstructured grid dimensions are
    skipped.
    """

    def __init__(self) -> None:
        pass

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        # Reject empty indexes first.
        super().validate(obj)

        for name in self.get_dims_to_validate(obj):
            index = obj.indexes[name]
            if name == "y":
                monotonic, direction = index.is_monotonic_decreasing, "decreasing"
            else:
                monotonic, direction = index.is_monotonic_increasing, "increasing"
            if not monotonic:
                raise ValidationError(
                    f"coord {name} which is not monotonically {direction}"
                )

276 

277 

class ShapeSchema(BaseSchema):
    def __init__(self, shape: ShapeT) -> None:
        """
        Validate shape.

        Parameters
        ----------
        shape : ShapeT
            Shape of the DataArray. `None` may be used as a wildcard value.
        """
        self.shape = shape

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If the number of dimensions differs, or a non-``None`` expected
            length differs from the actual length along an axis.
        """
        if len(self.shape) != len(obj.shape):
            raise ValidationError(
                f"number of dimensions in shape ({len(self.shape)}) != da.ndim ({len(obj.shape)})"
            )

        for i, (actual, expected) in enumerate(zip(obj.shape, self.shape)):
            if expected is not None and actual != expected:
                raise ValidationError(
                    f"shape mismatch in axis {i}: {actual} != {expected}"
                )

301 

302 

class CompatibleSettingsSchema(BaseSchema):
    """
    Validate that a setting is compatible with another setting.

    Validation fails when ``obj`` is truthy while the variable named by
    ``other`` is not equal to ``other_value`` everywhere.
    """

    def __init__(self, other: str, other_value: bool) -> None:
        """
        Validate if settings are compatible.

        Parameters
        ----------
        other : str
            Name of the other variable; looked up in the keyword arguments
            passed to ``validate``.
        other_value : bool
            Value ``other`` must have when ``obj`` is truthy.
        """
        self.other = other
        self.other_value = other_value

    def validate(self, obj: ScalarAsDataArray, **kwargs) -> None:
        other_obj = kwargs[self.other]
        # Settings left at the scalar None default are not checked.
        if scalar_None(obj) or scalar_None(other_obj):
            return
        expected = np.all(other_obj == self.other_value)

        if obj and not expected:
            raise ValidationError(
                f"Incompatible setting: {self.other} should be {self.other_value}"
            )

321 

322 

class CoordsSchema(BaseSchema):
    """
    Validate presence of coords.

    Parameters
    ----------
    coords : tuple of str
        Names of the coordinates to check for.
    require_all_keys : bool, default True
        If True, every name in ``coords`` must be a coordinate of the object.
    allow_extra_keys : bool, default True
        If False, the object may not have coordinates outside ``coords``.
    """

    def __init__(
        self,
        coords: Tuple[str, ...],
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ) -> None:
        self.coords = coords
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If required coordinates are missing (``require_all_keys``), or if
            unexpected coordinates are present (``allow_extra_keys=False``).
        """
        coords = set(obj.coords.keys())

        # Missing keys are only an error when explicitly required; an
        # unconditional per-key presence check would make
        # ``require_all_keys=False`` ineffective.
        if self.require_all_keys:
            missing_keys = set(self.coords) - coords
            if missing_keys:
                raise ValidationError(f"coords has missing keys: {missing_keys}")

        if not self.allow_extra_keys:
            extra_keys = coords - set(self.coords)
            if extra_keys:
                raise ValidationError(f"coords has extra keys: {extra_keys}")

359 

360 

class OtherCoordsSchema(BaseSchema):
    """
    Validate whether coordinates match those of another variable. That
    variable is passed to ``validate`` as the keyword argument named by
    ``other``.
    """

    def __init__(
        self,
        other: str,
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ):
        self.other = other
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        reference = kwargs[self.other]
        reference_coords = list(reference.coords.keys())
        coords_schema = CoordsSchema(
            reference_coords,
            self.require_all_keys,
            self.allow_extra_keys,
        )
        return coords_schema.validate(obj)

384 

385 

class ValueSchema(BaseSchema, abc.ABC):
    """
    Base class for AllValueSchema or AnyValueSchema.

    Parameters
    ----------
    operator : str
        Comparison operator; a key of ``OPERATORS``.
    other : Any
        Right-hand operand. Subclasses look it up in the keyword arguments of
        ``validate`` when it is a string.
    ignore : tuple of (str, str, Any), optional
        ``(varname, operator, value)``. Cells where
        ``kwargs[varname] <operator> value`` holds are explicitly ignored.
    """

    def __init__(
        self,
        operator: str,
        other: Any,
        ignore: Optional[Tuple[str, str, Any]] = None,
    ):
        self.operator = OPERATORS[operator]
        self.operator_str = operator
        self.other = other

        if ignore:
            self.ignore_varname = ignore[0]
            self.to_ignore = partial_operator(ignore[1], ignore[2])
        else:
            self.ignore_varname = None
            self.to_ignore = None

    def get_explicitly_ignored(self, kwargs: Dict) -> Any:
        """
        Get cells that should be explicitly ignored by the schema, or
        ``False`` when no ``ignore`` was configured.
        """
        if not self.to_ignore:
            return False
        ignore_obj = kwargs[self.ignore_varname]
        return self.to_ignore(ignore_obj)

416 

417 

class AllValueSchema(ValueSchema):
    """
    Validate whether all values pass a condition.

    E.g. if operator is ">":

        assert (values > threshold).all()

    Cells that are NaN in ``obj`` or ``other``, and explicitly ignored cells,
    always pass.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        other_obj = kwargs[self.other] if isinstance(self.other, str) else self.other

        if scalar_None(obj) or scalar_None(other_obj):
            return

        explicitly_ignored = self.get_explicitly_ignored(kwargs)
        # Skipped cells are set to True, so they cannot make ``all`` fail.
        skipped = np.isnan(obj) | np.isnan(other_obj) | explicitly_ignored

        passes = self.operator(obj, other_obj) | skipped
        if not passes.all():
            raise ValidationError(
                f"not all values comply with criterion: {self.operator_str} {self.other}"
            )

448 

449 

class AnyValueSchema(ValueSchema):
    """
    Validate whether any value passes a condition.

    E.g. if operator is ">":

        assert (values > threshold).any()

    Cells that are NaN in ``obj`` or ``other``, and explicitly ignored cells,
    never count as passing.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        if isinstance(self.other, str):
            other_obj = kwargs[self.other]
        else:
            other_obj = self.other

        if scalar_None(obj) or scalar_None(other_obj):
            return

        explicitly_ignored = self.get_explicitly_ignored(kwargs)

        # Cells that must not count as a success.
        ignore = np.isnan(obj) | np.isnan(other_obj) | explicitly_ignored

        condition = self.operator(obj, other_obj)
        # Ignore by forcing to False: an OR with "not NaN" would make every
        # valid cell pass regardless of the comparison.
        condition = condition & ~ignore
        if not condition.any():
            raise ValidationError(
                f"not a single value complies with criterion: {self.operator_str} {self.other}"
            )

480 

481 

482def _notnull(obj): 

483 """ 

484 Helper function; does the same as xr.DataArray.notnull. This function is to 

485 avoid an issue where xr.DataArray.notnull() returns ordinary numpy arrays 

486 for instances of xu.UgridDataArray. 

487 """ 

488 

489 return ~np.isnan(obj) 

490 

491 

class NoDataSchema(BaseSchema):
    """
    Base class for AllNoDataSchema and AnyNoDataSchema.

    Parameters
    ----------
    is_notnull : callable or tuple of (str, Any), default _notnull
        Function returning True for valid (not nodata) cells. An
        ``(operator, value)`` tuple is turned into such a function with
        ``partial_operator``.
    """

    def __init__(
        self,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        if not isinstance(is_notnull, tuple):
            self.is_notnull = is_notnull
            return
        op, value = is_notnull
        self.is_notnull = partial_operator(op, value)

502 

503 

class AllNoDataSchema(NoDataSchema):
    """
    Fails when all data is NoData.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If ``is_notnull(obj)`` is False everywhere.
        """
        valid = self.is_notnull(obj)
        # Logical ``not`` rather than bitwise ``~``: ``~`` on a plain Python
        # bool yields -1 or -2, both truthy.
        if not valid.any():
            raise ValidationError("all nodata")

513 

514 

class AnyNoDataSchema(NoDataSchema):
    """
    Fails when any data is NoData.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If ``is_notnull(obj)`` is False anywhere.
        """
        valid = self.is_notnull(obj)
        # Logical ``not`` rather than bitwise ``~``: ``~`` on a plain Python
        # bool yields -1 or -2, both truthy.
        if not valid.all():
            raise ValidationError("found a nodata value")

524 

525 

class NoDataComparisonSchema(BaseSchema):
    """
    Base class for IdentityNoDataSchema and AllInsideNoDataSchema.

    Parameters
    ----------
    other : str
        Name of the variable to compare with; looked up in the keyword
        arguments of ``validate``.
    is_notnull, is_other_notnull : callable or tuple of (str, Any)
        Validity tests for ``obj`` and ``other``, respectively. An
        ``(operator, value)`` tuple is turned into a predicate with
        ``partial_operator``.
    """

    def __init__(
        self,
        other: str,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
        is_other_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        self.other = other
        self.is_notnull = self._as_predicate(is_notnull)
        self.is_other_notnull = self._as_predicate(is_other_notnull)

    @staticmethod
    def _as_predicate(spec):
        """Turn an ``(operator, value)`` tuple into a predicate; pass callables through."""
        if isinstance(spec, tuple):
            op, value = spec
            return partial_operator(op, value)
        return spec

549 

550 

class IdentityNoDataSchema(NoDataComparisonSchema):
    """
    Checks that the NoData values are located at exactly the same locations.

    Tests only if all dimensions of the other object are present in the
    object. So it tests "stage" with `{time, layer, y, x}` compared to
    "idomain" `{layer, y, x}`, but doesn't test "k" with `{layer}` compared
    to "idomain" `{layer, y, x}`.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        other_obj = kwargs[self.other]

        # Skip the test when obj lacks any of other's dimensions.
        if not set(other_obj.dims).issubset(obj.dims):
            return

        valid = self.is_notnull(obj)
        other_valid = self.is_other_notnull(other_obj)
        mismatch = valid ^ other_valid
        if mismatch.any():
            raise ValidationError(f"nodata is not aligned with {self.other}")

572 

573 

class AllInsideNoDataSchema(NoDataComparisonSchema):
    """
    Checks that all notnull values occur within the notnull values of other.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If a cell is valid in ``obj`` but not valid in ``other`` according
            to ``is_other_notnull``.
        """
        other_obj = kwargs[self.other]
        valid = self.is_notnull(obj)
        other_valid = self.is_other_notnull(other_obj)

        # Align the boolean masks, not the raw ``other_obj``: ``~`` below must
        # act as a logical NOT on ``other_valid`` so that ``is_other_notnull``
        # is honoured (bitwise ``~`` on raw integers gives e.g. ~(-1) == 0).
        valid, other_valid = align_other_obj_with_coords(valid, other_valid)

        if (valid & ~other_valid).any():
            raise ValidationError(f"data values found at nodata values of {self.other}")

588 

589 

class ActiveCellsConnectedSchema(BaseSchema):
    """
    Check if active cells are connected, to avoid isolated islands. Such
    islands can cause convergence issues if they have a specified flux but
    no head boundary condition.

    Parameters
    ----------
    is_notnull : callable or tuple of (str, Any), default _notnull
        Predicate marking active cells. An ``(operator, value)`` tuple is
        turned into one with ``partial_operator``.

    Note
    ----
    This schema only works for structured grids.
    """

    def __init__(
        self,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        if isinstance(is_notnull, tuple):
            op, value = is_notnull
            self.is_notnull = partial_operator(op, value)
        else:
            self.is_notnull = is_notnull

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        NotImplementedError
            If ``obj`` is a ``xu.UgridDataArray``.
        ValidationError
            If ``scipy.ndimage.label`` (default structuring element) finds
            more than one connected region of active cells.
        """
        if isinstance(obj, xu.UgridDataArray):
            # TODO: https://deltares.github.io/xugrid/api/xugrid.UgridDataArrayAccessor.connected_components.html
            # Instances have no ``__name__``; take it from the class so the
            # intended NotImplementedError is raised, not an AttributeError.
            raise NotImplementedError(
                f"Schema {type(self).__name__} only works for structured grids, received xu.UgridDataArray."
            )

        active = self.is_notnull(obj)

        _, nlabels = scipy.ndimage.label(active)
        if nlabels > 1:
            raise ValidationError(
                f"{nlabels} disconnected areas detected in model domain"
            )