Coverage for muutils\json_serialize\serializable_dataclass.py: 55%

242 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-09 01:48 -0600

1"""save and load objects to and from json or compatible formats in a recoverable way 

2 

3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, 

4you will get an error when you call `json.dumps(d)`. This module provides a way around that. 

5 

6Instead, you define your class: 

7 

8```python 

9@serializable_dataclass 

10class MyClass(SerializableDataclass): 

11 a: int 

12 b: str 

13``` 

14 

15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do: 

16 

17 >>> my_obj = MyClass(a=1, b="q") 

18 >>> s = json.dumps(my_obj.serialize()) 

19 >>> s 

20 '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' 

21 >>> read_obj = MyClass.load(json.loads(s)) 

22 >>> read_obj == my_obj 

23 True 

24 

This isn't too impressive on its own, but it gets more useful when you have nested classes,

26or fields that are not json-serializable by default: 

27 

28```python 

29@serializable_dataclass 

30class NestedClass(SerializableDataclass): 

31 x: str 

32 y: MyClass 

33 act_fun: torch.nn.Module = serializable_field( 

34 default=torch.nn.ReLU(), 

        serialization_fn=lambda x: type(x).__name__,

36 deserialize_fn=lambda x: getattr(torch.nn, x)(), 

37 ) 

38``` 

39 

40which gives us: 

41 

42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) 

43 >>> s = json.dumps(nc.serialize()) 

44 >>> s 

45 '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' 

46 >>> read_nc = NestedClass.load(json.loads(s)) 

47 >>> read_nc == nc 

48 True 

49 

50""" 

51 

52from __future__ import annotations 

53 

54import abc 

55import dataclasses 

56import functools 

57import json 

58import sys 

59import typing 

60import warnings 

61from typing import Any, Optional, Type, TypeVar 

62 

63from muutils.errormode import ErrorMode 

64from muutils.validate_type import validate_type 

65from muutils.json_serialize.serializable_field import ( 

66 SerializableField, 

67 serializable_field, 

68) 

69from muutils.json_serialize.util import array_safe_eq, dc_eq 

70 

71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access 

72 

73 

def _dataclass_transform_mock(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
    **kwargs: Any,
) -> typing.Callable:
    """mock `typing.dataclass_transform` for python <3.11

    returns a decorator which stores the given options on the decorated object
    as a `__dataclass_transform__` dict and hands the object back unchanged
    """

    def decorator(cls_or_fn):
        # build a fresh dict for every decorated object
        cls_or_fn.__dataclass_transform__ = dict(
            eq_default=eq_default,
            order_default=order_default,
            kw_only_default=kw_only_default,
            frozen_default=frozen_default,
            field_specifiers=field_specifiers,
            kwargs=kwargs,
        )
        return cls_or_fn

    return decorator


# `typing.dataclass_transform` only exists from python 3.11 onward; use the mock before that
dataclass_transform: typing.Callable
if sys.version_info >= (3, 11):
    dataclass_transform = typing.dataclass_transform
else:
    dataclass_transform = _dataclass_transform_mock


T = TypeVar("T")

107 

108 

class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""

113 

114 

class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"""

119 

120 

_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ: `True` until the first successful import, then `False`"


def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
        the serializable dataclass to register. the registered loader claims json dicts whose
        `"__format__"` is a string starting with `f"{cls.__name__}(SerializableDataclass)"`,
        and loads them via `cls.load`

    # Returns:
     - the `LoaderHandler` that was registered, or `None` if ZANJ could not be imported
       (a `ZanjMissingWarning` is emitted in that case)

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    # NOTE: always import. The imported names are local to this function, so skipping
    # the import based on the flag would leave `LoaderHandler` undefined (NameError).
    # After the first success the module is cached in `sys.modules`, so repeating it is cheap.
    try:
        from zanj.loading import (  # type: ignore[import]
            LoaderHandler,
            register_loader_handler,
        )
    except ImportError:
        warnings.warn(
            "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            ZanjMissingWarning,
        )
        return None
    _zanj_loading_needs_import = False

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        # only claim dicts whose `__format__` is a string matching this class;
        # a non-string `__format__` would otherwise raise `AttributeError` from `.startswith`
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and isinstance(json_item.get("__format__"), str)
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh

164 

165 

# default action when type checking itself throws an exception
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
# default action when a field's value does not match its type hint
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN

168 

169 

class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning emitted when a field is skipped during type validation because it is not `init` or not `serialize`"""

172 

173 

def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
        (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is not `except`.
        fields with `assert_type` disabled, or which are not `init` or not `serialize`, are not checked and give `True`

    # Raises:
     - `AssertionError` : if the field is not a `SerializableField`
     - `TypeError` : (via `on_typecheck_error.process`) if the type hint for the field cannot be obtained and `on_typecheck_error` is `except`
     - `ValueError` : (via `on_typecheck_error.process`) if the type validator itself raises and `on_typecheck_error` is `except`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # check the field kind before reading `SerializableField`-only attributes (like `assert_type`),
    # so a plain `dataclasses.Field` produces this message instead of an `AttributeError`
    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    # get field type hints. `get_cls_type_hints` raises `TypeError` when the hints are
    # unavailable, so route that through `on_typecheck_error` just like a missing key
    cls_type_hints: Optional[dict[str, Any]] = None
    try:
        cls_type_hints = get_cls_type_hints(self.__class__)
        field_type_hint: Any = cls_type_hints[_field.name]
    except (KeyError, TypeError) as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{cls_type_hints = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + " - use python 3.9.x or higher\n"
                + " - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        if _field.custom_typecheck_fn is None:
            # validate the type with the default type validator
            type_is_valid = validate_type(value, field_type_hint)
        else:
            # NOTE: the custom validator is passed the type hint, not the value
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

264 

265 

def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """check each field of a `SerializableDataclass` against its type hint

    every field is passed through `self.validate_field_type` (which wraps
    `SerializableDataclass__validate_field_type`). A field whose check raises is
    recorded as `False`. All such exceptions are then reported together through
    `on_typecheck_error`.

    # Returns:
     - `dict[str, bool]` : maps each field name to whether its type is valid
    """
    error_mode: ErrorMode = ErrorMode.from_any(on_typecheck_error)

    validity: dict[str, bool] = {}
    caught: dict[str, Exception] = {}

    all_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for fld in all_fields:
        try:
            validity[fld.name] = self.validate_field_type(fld, error_mode)
        except Exception as err:
            validity[fld.name] = False
            caught[fld.name] = err

    # nothing went wrong while checking -- just hand back the per-field results
    if not caught:
        return validity

    failure_lines: str = "\n\t".join([f"{name}:\t{exc}" for name, exc in caught.items()])
    error_mode.process(
        f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in all_fields]}"
        + "\n\t"
        + failure_lines,
        except_cls=ValueError,
        # HACK: ExceptionGroup is not available in py < 3.11, so chain from the first captured exception
        except_from=next(iter(caught.values())),
    )

    return validity

301 

302 

def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check every field of a `SerializableDataclass` against its type hint

    delegates to `SerializableDataclass__validate_fields_types__dict`, which calls
    `SerializableDataclass__validate_field_type` once per field

    # Returns:
     - `bool` : `True` only if every field validated
    """
    per_field_results: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field_results.values())

313 

314 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: type(x).__name__,
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        "equality via `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
            other instance to compare against
         - `of_serialized : bool`
            if true, compare serialized data and not raw values. fields which are not
            present in the serialized output (e.g. `serialize=False`) fall back to
            comparing raw values
            (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
            maps differing field names to `{"self": ..., "other": ...}` (raw values),
            or to a nested diff dict for nested serializable dataclasses. fields with
            `compare=False` are skipped

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # serialized representations, only computed when requested
        ser_self: dict = self.serialize() if of_serialized else {}
        ser_other: dict = other.serialize() if of_serialized else {}

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # compare serialized values when requested and available; a field that
                # is not serialized has no serialized form, so compare its raw values
                if of_serialized and field_name in ser_self and field_name in ser_other:
                    self_value_s = ser_self[field_name]
                    other_value_s = ser_other[field_name]
                else:
                    self_value_s, other_value_s = self_value, other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields holding a `SerializableDataclass` are updated recursively when the new
        value is a mapping; any other new value simply replaces the field

        # Parameters:
         - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            if field_name not in nested_dict:
                continue

            self_value = getattr(self, field_name)
            new_value = nested_dict[field_name]

            if isinstance(self_value, SerializableDataclass) and isinstance(
                new_value, typing.Mapping
            ):
                self_value.update_from_nested_dict(new_value)  # type: ignore[arg-type]
            else:
                setattr(self, field_name, new_value)

    def __copy__(self) -> "SerializableDataclass":
        "copy by serializing and loading the instance to json -- note this produces a deep copy, not a shallow one"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json (`memo` is not used)"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

509 

510 

# cache this so we don't have to keep getting it
# classes are hashable (by identity), so they are valid `lru_cache` keys
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)


def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup first; if that is empty, retries `typing.get_type_hints`
    uncached (presumably to guard against a stale cached empty result)

    # Returns:
     - `dict[str, Any]` : the (non-empty) type hints of `cls`

    # Raises:
     - `TypeError` : if the hints cannot be resolved (`TypeError`/`NameError` from
       `typing.get_type_hints`) or are empty. This also applies to classes which are not
       dataclasses.
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # `dataclasses.fields` itself raises `TypeError` for non-dataclasses; don't let
        # that replace the informative error raised below
        try:
            fields_repr: str = f"{dataclasses.fields(cls) = }"  # type: ignore[arg-type]
        except TypeError:
            fields_repr = f"{cls} is not a dataclass, so `dataclasses.fields` is unavailable"
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f" {fields_repr}\n"
            + f" {e = }"
        ) from e

    return cls_type_hints

538 

539 

class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10 (`dataclasses.dataclass` gained `kw_only` in 3.10)"

    pass

544 

545 

class FieldError(ValueError):
    """base class for field errors"""


class NotSerializableFieldException(FieldError):
    """field is not a `SerializableField`"""


class FieldSerializationError(FieldError):
    """error while serializing a field"""


class FieldLoadingError(FieldError):
    """error while loading a field"""

568 

569 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields. Every annotated field is backed by a
    `SerializableField`:
    - existing `SerializableField`s are kept as-is
    - plain `dataclasses.Field`s are converted
    - a plain default written in the class body (e.g. `a: int = 5` or `b: Optional[str] = None`)
      is kept as that field's default

    As with `dataclasses.dataclass`, mutable plain defaults such as `[]` are rejected; use
    `serializable_field(default_factory=...)` instead.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
     - `init : bool`
        (defaults to `True`)
     - `repr : bool`
        (defaults to `True`)
     - `eq : bool`
        passed to `dataclasses.dataclass`; note `__eq__` is afterwards always replaced by `dc_eq`
        (defaults to `True`)
     - `order : bool`
        (defaults to `False`)
     - `unsafe_hash : bool`
        (defaults to `False`)
     - `frozen : bool`
        (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
        **SerializableDataclass only:** which properties to add to the serialized data dict
        (defaults to `None`)
     - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
        (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found when loading (except, warn, ignore). If `ignore`, types are not validated at all
     - `**kwargs`
        passed through to `dataclasses.dataclass` (e.g. `kw_only`)

    # Returns:
     - `_type_`
        the decorated class (or, when `_cls` is `None`, a decorator which produces it)

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support it there
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
     - `TypeError` : on `load`, if field values do not match their type hints and `on_typecheck_mismatch` is `except`
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # make sure every annotated field is backed by a `SerializableField`
        for field_name in cls.__annotations__:
            field_value = getattr(cls, field_name, None)
            if isinstance(field_value, SerializableField):
                continue
            if isinstance(field_value, dataclasses.Field):
                # Convert the field to a SerializableField while preserving properties
                field_value = SerializableField.from_Field(field_value)
            elif field_name in cls.__dict__:
                # plain default written in this class body, e.g. `a: int = 5` or
                # `b: Optional[str] = None` -- keep it instead of making the field required.
                # (only the class's own namespace counts, so inherited methods such as
                # `diff` or `load` are never mistaken for defaults)
                field_value = serializable_field(default=field_value)
            else:
                # no default: Create a new SerializableField
                field_value = serializable_field()
            setattr(cls, field_name, field_value)

        # special check: `kw_only` was added to `dataclasses.dataclass` in python 3.10.
        # compare with `== True` explicitly since `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    # reset per field, so an error message never shows a previous field's value
                    value: Any = dataclasses.MISSING
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if the value (e.g. a nested `SerializableDataclass`) has a serialization method, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # otherwise, if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    result[prop] = getattr(self, prop)
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__} "
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # only fields present in the data and initialized via `__init__` are loaded
                if not ((field.name in data) and field.init):
                    continue

                # get the value, we will be processing it
                value: Any = data[field.name]

                # get the type hint for the field
                field_type_hint: Any = cls_type_hints.get(field.name, None)

                # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                if field.deserialize_fn:
                    # if it has a deserialization function, use that
                    value = field.deserialize_fn(value)
                elif field.loading_fn:
                    # if it has a loading function, use that (it receives the whole data dict)
                    value = field.loading_fn(data)
                elif (
                    field_type_hint is not None
                    and hasattr(field_type_hint, "load")
                    and callable(field_type_hint.load)
                ):
                    # if no loading function but has a type hint with a load method, use that
                    if isinstance(value, typing.Mapping):
                        value = field_type_hint.load(value)
                    else:
                        raise FieldLoadingError(
                            f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                        )
                # else: assume no loading needs to happen, keep `value` as-is

                # store the value in the constructor kwargs
                ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output = cls(**ctor_kwargs)

            # validate the types of the fields if needed, and report mismatches
            # through `on_typecheck_mismatch` (previously the result was discarded)
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output, on_typecheck_error=on_typecheck_error
                    )
                )
                if not all(fields_valid.values()):
                    mismatches: list[str] = [
                        f"{k}:\texpected {cls_type_hints.get(k)!r}, but got value {getattr(output, k)!r} of type {type(getattr(output, k))!r}"
                        for k, is_valid in fields_valid.items()
                        if not is_valid
                    ]
                    on_typecheck_mismatch.process(
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(mismatches),
                        except_cls=TypeError,
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)