muutils.json_serialize.serializable_dataclass
save and load objects to and from json or compatible formats in a recoverable way
d = dataclasses.asdict(my_obj) will give you a dict, but if some fields are not json-serializable,
you will get an error when you call json.dumps(d). This module provides a way around that.
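As a rough illustration of the failure described above, the sketch below uses only the standard library. The `Plain` class and its `tags` field are hypothetical names chosen for this example; any field holding a non-JSON type, such as a `set`, should behave the same way.

```python
import dataclasses
import json


@dataclasses.dataclass
class Plain:
    # hypothetical example class, not part of muutils
    a: int
    tags: set  # sets have no JSON representation


# `asdict` succeeds: it only converts the dataclass into a plain dict
d = dataclasses.asdict(Plain(a=1, tags={"x"}))

# `json.dumps` is where the problem surfaces
try:
    json.dumps(d)
    dumps_failed = False
except TypeError as err:
    dumps_failed = True
    print(f"json.dumps failed: {err}")
```

The conversion to a dict is not the issue; the error only appears once the dict is handed to the JSON encoder.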
Instead, you define your class:
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
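The per-field `serialization_fn` / `deserialize_fn` idea can be tried without `torch`. The following is a stdlib-only sketch of the mechanism and is not how muutils implements it: `ser_field`, `to_jsonable`, and `from_jsonable` are hypothetical helpers that keep the converters in `dataclasses.field` metadata, and `Config` is a made-up class.

```python
import dataclasses
import json
from pathlib import Path
from typing import Any, Callable, Optional


def ser_field(
    default: Any = dataclasses.MISSING,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
) -> Any:
    """hypothetical helper: store per-field converters in the field metadata"""
    return dataclasses.field(
        default=default,
        metadata={"ser": serialization_fn, "de": deserialize_fn},
    )


def to_jsonable(obj: Any) -> dict:
    """apply each field's serializer (if any) and collect the results in a dict"""
    out = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        fn = f.metadata.get("ser")
        out[f.name] = fn(value) if fn else value
    return out


def from_jsonable(cls: Any, data: dict) -> Any:
    """apply each field's deserializer (if any) and call the constructor"""
    kwargs = {}
    for f in dataclasses.fields(cls):
        if f.name in data:
            fn = f.metadata.get("de")
            kwargs[f.name] = fn(data[f.name]) if fn else data[f.name]
    return cls(**kwargs)


@dataclasses.dataclass
class Config:
    name: str
    # Path is not JSON-serializable, so it gets a converter in each direction
    out_dir: Path = ser_field(
        default=Path("runs"), serialization_fn=str, deserialize_fn=Path
    )


cfg = Config(name="exp1", out_dir=Path("results/exp1"))
s = json.dumps(to_jsonable(cfg))
restored = from_jsonable(Config, json.loads(s))
print(restored == cfg)
```

Pairing each serializer with an inverse deserializer is what makes the round trip recoverable; a serializer alone would leave `out_dir` as a plain string after loading.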
"""save and load objects to and from json or compatible formats in a recoverable way

`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
you will get an error when you call `json.dumps(d)`. This module provides a way around that.

Instead, you define your class:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

This isn't too impressive on its own, but it gets more useful when you have nested classes,
or fields that are not json-serializable by default:

```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```

which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True

"""

from __future__ import annotations

import abc
import dataclasses
import functools
import json
import sys
import typing
import warnings
from typing import Any, Optional, Type, TypeVar, overload, TYPE_CHECKING

from muutils.errormode import ErrorMode
from muutils.validate_type import validate_type
from muutils.json_serialize.serializable_field import (
    SerializableField,
    serializable_field,
)
from muutils.json_serialize.types import _FORMAT_KEY
from muutils.json_serialize.util import (
    JSONdict,
    array_safe_eq,
    dc_eq,
)

# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access

# For type checkers: always use typing_extensions which they can resolve
# At runtime: use stdlib if available (3.11+), else typing_extensions, else mock
if TYPE_CHECKING:
    from typing_extensions import dataclass_transform, Self
else:
    if sys.version_info >= (3, 11):
        from typing import dataclass_transform, Self
    else:
        try:
            from typing_extensions import dataclass_transform, Self
        except Exception:
            from muutils.json_serialize.dataclass_transform_mock import (
                dataclass_transform,
            )

            Self = TypeVar("Self")

T_SerializeableDataclass = TypeVar(
    "T_SerializeableDataclass", bound="SerializableDataclass"
)


class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass


class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass


_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"


def zanj_register_loader_serializable_dataclass(
    cls: typing.Type[T_SerializeableDataclass],
):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import] # pyright: ignore[reportMissingImports]
                LoaderHandler,  # pyright: ignore[reportUnknownVariableType]
                register_loader_handler,  # pyright: ignore[reportUnknownVariableType]
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesn't matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(  # pyright: ignore[reportPossiblyUnboundVariable]
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)  # pyright: ignore[reportPossiblyUnboundVariable]

    return lh


_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT


class FieldIsNotInitOrSerializeWarning(UserWarning):
    pass


def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
    - `self : SerializableDataclass`
        `SerializableDataclass` instance
    - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
        (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
    - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(_field, SerializableField), (
        f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"
    )

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False


def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results


def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @overload
    @classmethod
    def load(cls, data: dict[str, Any]) -> Self: ...

    @overload
    @classmethod
    def load(cls, data: Self) -> Self: ...

    @classmethod
    def load(cls, data: dict[str, Any] | Self) -> Self:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`


        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: JSONdict = self.serialize()
            ser_other: JSONdict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                if of_serialized:
                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                else:
                    self_value_s = self_value
                    other_value_s = other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))


# cache this so we don't have to keep getting it
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)


def get_cls_type_hints(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints


class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass


class FieldError(ValueError):
    "base class for field errors"

    pass


class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass


class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass


class FieldLoadingError(FieldError):
    "error while loading a field"

    pass


class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs: Any,
) -> Any:
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
    - `init : bool`
        whether to add an `__init__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `repr : bool`
        whether to add a `__repr__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `order : bool`
        whether to add rich comparison methods
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `unsafe_hash : bool`
        whether to add a `__hash__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `frozen : bool`
        whether to make the class frozen
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
        which properties to add to the serialized data dict
        **SerializableDataclass only**
        (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
        the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self: Any) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    value: Any = None  # init before try in case getattr raises
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):  # pyright: ignore[reportAttributeAccessIssue]
                            value = value.serialize()  # pyright: ignore[reportAttributeAccessIssue]
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value or '<unavailable>' = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__} "
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isn't a classmethod
        @classmethod  # type: ignore[misc]
        def load(
            cls: type[T_SerializeableDataclass],
            data: dict[str, Any] | T_SerializeableDataclass,
        ) -> T_SerializeableDataclass:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(data, typing.Mapping), (
                f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
            )

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            # mypy doesn't recognize @dataclass_transform for dataclasses.fields()
            # https://github.com/python/mypy/issues/16241
            for field in dataclasses.fields(cls):  # type: ignore[arg-type]
                # check if the field is a SerializableField
                assert isinstance(field, SerializableField), (
                    f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
                )

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: T_SerializeableDataclass = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so that's why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined, method-assign]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined, method-assign, assignment]

        # key matches the documented name and the known-methods set above
        if "validate_fields_types" not in _methods_no_override:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined, method-assign]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
```python
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass
```
special warning for when we can't get type hints
```python
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass
```
special warning for when ZANJ is missing -- register_loader_serializable_dataclass will not work
```python
def zanj_register_loader_serializable_dataclass(
    cls: typing.Type[T_SerializeableDataclass],
):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import] # pyright: ignore[reportMissingImports]
                LoaderHandler,  # pyright: ignore[reportUnknownVariableType]
                register_loader_handler,  # pyright: ignore[reportUnknownVariableType]
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(  # pyright: ignore[reportPossiblyUnboundVariable]
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)  # pyright: ignore[reportPossiblyUnboundVariable]

    return lh
```
Register a serializable dataclass with the ZANJ import
this allows ZANJ().read() to load the class and not just return plain dicts
TODO: there is some duplication here with register_loader_handler
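The check-then-load idea behind the handler can be illustrated without ZANJ installed. The following is only a sketch: `Handler`, `REGISTRY`, `register`, `read`, and the `FORMAT_KEY` value are hypothetical stand-ins, not ZANJ's or muutils' actual API.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict

# placeholder value (assumption); the real module keeps its own `_FORMAT_KEY`
FORMAT_KEY = "__format__"


@dataclass
class Handler:
    """hypothetical stand-in for a loader handler: a predicate plus a loader"""

    uid: str
    check: Callable[[Any], bool]
    load: Callable[[Any], Any]


# hypothetical registry keyed by handler uid
REGISTRY: Dict[str, Handler] = {}


def register(handler: Handler) -> None:
    REGISTRY[handler.uid] = handler


def read(item: Any) -> Any:
    """return the first matching handler's result, else the item unchanged"""
    for handler in REGISTRY.values():
        if handler.check(item):
            return handler.load(item)
    return item


@dataclass
class Point:
    x: int
    y: int


fmt = "Point(SerializableDataclass)"
register(
    Handler(
        uid=fmt,
        # match dicts whose format tag starts with the class's format string
        check=lambda item: isinstance(item, dict)
        and str(item.get(FORMAT_KEY, "")).startswith(fmt),
        load=lambda item: Point(x=item["x"], y=item["y"]),
    )
)
```

With this registry, a tagged dict is turned back into a `Point`, while an untagged dict is returned as a plain dict.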
Base class for warnings generated by user code.
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
```python
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(_field, SerializableField), (
        f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"
    )

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
```
given a dataclass, check the field matches the type hint
this function is written to SerializableDataclass.validate_field_type
Parameters:
- `self : SerializableDataclass`
`SerializableDataclass` instance
- `field : SerializableField | str`
field to validate, will get from `self.__dataclass_fields__` if an `str`
- `on_typecheck_error : ErrorMode`
what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)
Returns:
- `bool`
if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
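As a rough illustration of the check described above, the sketch below validates one field of a plain `dataclasses.dataclass` against its resolved type hint. `field_matches_hint` and `Point` are hypothetical names, and the bare `isinstance` test is a simplification: muutils delegates to `validate_type`, which also handles generic hints.

```python
import dataclasses
import typing


@dataclasses.dataclass
class Point:
    x: int
    label: str


def field_matches_hint(obj: typing.Any, field_name: str) -> bool:
    """hypothetical helper: compare one field's value with its type hint

    only plain classes such as `int` or `str` are handled here
    """
    hint = typing.get_type_hints(type(obj))[field_name]
    try:
        return isinstance(getattr(obj, field_name), hint)
    except TypeError:
        # e.g. a parameterized generic such as `list[int]`; treated as
        # "not valid", similar to the `ignore` behaviour described above
        return False
```

Dataclasses do not enforce annotations at construction time, so `Point(x="1", label="a")` is accepted by Python and only a check like this one flags the mismatch.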
```python
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results
```
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
returns a dict of field names to bools, where the bool is if the field type is valid
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @overload
    @classmethod
    def load(cls, data: dict[str, Any]) -> Self: ...

    @overload
    @classmethod
    def load(cls, data: Self) -> Self: ...

    @classmethod
    def load(cls, data: dict[str, Any] | Self) -> Self:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: JSONdict = self.serialize()
            ser_other: JSONdict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                if of_serialized:
                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                else:
                    self_value_s = self_value
                    other_value_s = other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
````
Base class for serializable dataclasses
only for linting and type checking, still need to call serializable_dataclass decorator
Usage:
```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```
and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
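The per-field `serialization_fn` / `deserialize_fn` idea can be tried without muutils or torch. The following is a minimal standard-library sketch, not the library's implementation: `sketch_field`, `serialize`, `load`, and `Config` are hypothetical names, and the functions are stored in `dataclasses.field` metadata purely for illustration.

```python
import dataclasses
import json
import math
from typing import Any, Callable


def sketch_field(
    serialization_fn: Callable[[Any], Any],
    deserialize_fn: Callable[[Any], Any],
    **kwargs: Any,
) -> Any:
    """hypothetical helper: attach (de)serialization functions as field metadata"""
    return dataclasses.field(
        metadata={"ser": serialization_fn, "de": deserialize_fn}, **kwargs
    )


@dataclasses.dataclass
class Config:
    name: str
    # a function is not JSON-serializable, so store its name instead
    act_fun: Callable[[float], float] = sketch_field(
        serialization_fn=lambda f: f.__name__,
        deserialize_fn=lambda s: getattr(math, s),
        default=math.tanh,
    )


def serialize(obj: Any) -> dict:
    """apply each field's serialization function, if it has one"""
    out = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        ser = f.metadata.get("ser")
        out[f.name] = ser(value) if ser else value
    return out


def load(cls: type, data: dict) -> Any:
    """rebuild an instance, applying each field's deserialization function"""
    kwargs = {}
    for f in dataclasses.fields(cls):
        if f.name in data:
            de = f.metadata.get("de")
            kwargs[f.name] = de(data[f.name]) if de else data[f.name]
    return cls(**kwargs)


cfg = Config(name="q", act_fun=math.sin)
s = json.dumps(serialize(cfg))
restored = load(Config, json.loads(s))
```

Unlike the real `serialize()`, this sketch does not add a format key and does not recurse into nested classes.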
```python
    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )
```
returns the class as a dict, implemented by using @serializable_dataclass decorator
```python
    @classmethod
    def load(cls, data: dict[str, Any] | Self) -> Self:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
```python
    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
```
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
```python
    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
```
given a dataclass, check the field matches the type hint
````python
    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: JSONdict = self.serialize()
            ser_other: JSONdict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                if of_serialized:
                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                else:
                    self_value_s = self_value
                    other_value_s = other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result
````
get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
Parameters:
- `other : SerializableDataclass`
other instance to compare against
- `of_serialized : bool`
if true, compare serialized data and not raw values (defaults to `False`)
Returns:
dict[str, Any]
Raises:
- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
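The recursive structure of the diff can be sketched on plain dataclasses. This is an illustration only: the free function `diff` and the classes `Inner` and `Outer` are hypothetical, and unlike the method documented above the sketch recurses into any dataclass rather than raising for non-serializable ones.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Inner:
    a: int
    b: int


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner


def diff(left: Any, right: Any) -> dict:
    """hypothetical sketch: nested dict of fields whose values differ"""
    if type(left) is not type(right):
        raise ValueError(
            f"Instances must be of the same type, got {type(left)} and {type(right)}"
        )
    result: dict = {}
    for f in dataclasses.fields(left):
        # skip fields excluded from comparison
        if not f.compare:
            continue
        lv, rv = getattr(left, f.name), getattr(right, f.name)
        if dataclasses.is_dataclass(lv) and dataclasses.is_dataclass(rv):
            # recurse, and only record the nested diff if it is non-empty
            nested = diff(lv, rv)
            if nested:
                result[f.name] = nested
        elif lv != rv:
            result[f.name] = {"self": lv, "other": rv}
    return result
```

Equal instances produce an empty dict, and unchanged nested fields are omitted from the result.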
```python
    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])
```
update the instance from a nested dict, useful for configuration from command line args
Parameters:
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
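A partial nested update of this kind might look as follows on plain dataclasses. `Optim`, `Train`, and the free function below are hypothetical stand-ins for a configuration object; the sketch recurses into any dataclass-valued field, whereas the method above recurses only into `SerializableDataclass` values.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Optim:
    lr: float = 0.1
    steps: int = 100


@dataclasses.dataclass
class Train:
    name: str = "run"
    optim: Optim = dataclasses.field(default_factory=Optim)


def update_from_nested_dict(obj: Any, nested_dict: dict) -> None:
    """hypothetical sketch: overwrite only the fields present in `nested_dict`"""
    for f in dataclasses.fields(obj):
        if f.name not in nested_dict:
            continue
        current = getattr(obj, f.name)
        if dataclasses.is_dataclass(current):
            # descend so that sibling fields of the nested object are kept
            update_from_nested_dict(current, nested_dict[f.name])
        else:
            setattr(obj, f.name, nested_dict[f.name])


cfg = Train()
update_from_nested_dict(cfg, {"name": "sweep", "optim": {"lr": 0.01}})
```

Fields missing from the dict, such as `optim.steps` here, keep their existing values.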
```python
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)
```
cached typing.get_type_hints for a class
```python
def get_cls_type_hints(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints
```
helper function to get type hints for a class
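The caching half of this helper relies on classes being hashable, so `functools.lru_cache` can key on the class object. A small sketch with the hypothetical names `cached_hints` and `Example`:

```python
import functools
import typing


@functools.lru_cache(maxsize=None)
def cached_hints(cls: type) -> dict:
    """hypothetical helper: resolve a class's type hints once and reuse the result"""
    return typing.get_type_hints(cls)


class Example:
    a: int
    b: str


hints = cached_hints(Example)        # first call resolves the hints
hints_again = cached_hints(Example)  # second call is served from the cache
```

Both calls return the same dict object, so callers should treat the result as read-only.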
```python
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass
```
kw-only dataclasses are not supported in python <3.10
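The `kw_only` argument of `dataclasses.dataclass` was added in Python 3.10, which is why a version gate is needed on older interpreters. A sketch of such a gate, using the hypothetical wrapper name `dataclass_compat` and the built-in `NotImplementedError` in place of `KWOnlyError`:

```python
import dataclasses
import sys
from typing import Any


def dataclass_compat(cls: type, **kwargs: Any) -> type:
    """hypothetical wrapper: reject or drop `kw_only` on interpreters that lack it"""
    if sys.version_info < (3, 10) and "kw_only" in kwargs:
        if kwargs["kw_only"]:
            raise NotImplementedError("kw_only requires Python 3.10 or newer")
        # a falsy value is harmless, so drop it instead of failing
        kwargs = {k: v for k, v in kwargs.items() if k != "kw_only"}
    return dataclasses.dataclass(cls, **kwargs)


class Point:
    x: int
    y: int


Point = dataclass_compat(Point, kw_only=False)
```

On Python 3.10 and later the argument is passed through to `dataclasses.dataclass` unchanged.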
base class for field errors
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
```python
class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass
```
field is not a SerializableField
error while serializing a field
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
error while loading a field
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
```python
class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass
```
error when a field type does not match the type hint
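Because this error derives from both `FieldError` and `TypeError`, callers can catch it under either name. The sketch below re-declares the two classes locally, assuming `FieldError` subclasses `ValueError` as the inherited-members listing suggests; `check_type` is a hypothetical helper.

```python
from typing import Any


class FieldError(ValueError):
    "local re-declaration for illustration (assumed to subclass ValueError)"


class FieldTypeMismatchError(FieldError, TypeError):
    "local re-declaration mirroring the class shown above"


def check_type(value: Any, expected: type) -> None:
    """hypothetical helper: raise when `value` is not an instance of `expected`"""
    if not isinstance(value, expected):
        raise FieldTypeMismatchError(
            f"expected {expected.__name__}, got {type(value).__name__}"
        )


# the same error is caught by each of its base classes
caught = []
for exc_type in (TypeError, ValueError, FieldError):
    try:
        check_type("1", int)
    except exc_type:
        caught.append(exc_type.__name__)
```

This lets existing `except TypeError` handlers keep working while more specific code catches `FieldError`.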
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs: Any,
) -> Any:
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
     - `register_handler : bool`
        if true, register the class with ZANJ for loading
       **SerializableDataclass only**
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
       **SerializableDataclass only**
     - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
       **SerializableDataclass only**
     - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
       **SerializableDataclass only**
       (defaults to `None`)
     - `**kwargs`
       *(passed to dataclasses.dataclass)*

    # Returns:

     - `_type_`
       the decorated class

    # Raises:

     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self: Any) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
````
753 ) 754 755 # try to save it 756 if field.serialize: 757 value: Any = None # init before try in case getattr raises 758 try: 759 # get the val 760 value = getattr(self, field.name) 761 # if it is a serializable dataclass, serialize it 762 if isinstance(value, SerializableDataclass): 763 value = value.serialize() 764 # if the value has a serialization function, use that 765 if hasattr(value, "serialize") and callable(value.serialize): # pyright: ignore[reportAttributeAccessIssue] 766 value = value.serialize() # pyright: ignore[reportAttributeAccessIssue] 767 # if the field has a serialization function, use that 768 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 769 elif field.serialization_fn: 770 value = field.serialization_fn(value) 771 772 # store the value in the result 773 result[field.name] = value 774 except Exception as e: 775 raise FieldSerializationError( 776 "\n".join( 777 [ 778 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 779 f"{field = }", 780 f"{value or '<unavailable>' = }", 781 f"{self = }", 782 ] 783 ) 784 ) from e 785 786 # store each property if we can get it 787 for prop in self._properties_to_serialize: 788 if hasattr(cls, prop): 789 value = getattr(self, prop) 790 result[prop] = value 791 else: 792 raise AttributeError( 793 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 794 + f"but it is in {self._properties_to_serialize = }" 795 + f"\n{self = }" 796 ) 797 798 return result 799 800 # ====================================================================== 801 # define `load` func 802 # done locally since it depends on args to the decorator 803 # ====================================================================== 804 # mypy thinks this isnt a classmethod 805 @classmethod # type: ignore[misc] 806 def load( 807 cls: type[T_SerializeableDataclass], 808 data: dict[str, 
Any] | T_SerializeableDataclass, 809 ) -> T_SerializeableDataclass: 810 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 811 if isinstance(data, cls): 812 return data 813 814 assert isinstance(data, typing.Mapping), ( 815 f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 816 ) 817 818 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 819 820 # initialize dict for keeping what we will pass to the constructor 821 ctor_kwargs: dict[str, Any] = dict() 822 823 # iterate over the fields of the class 824 # mypy doesn't recognize @dataclass_transform for dataclasses.fields() 825 # https://github.com/python/mypy/issues/16241 826 for field in dataclasses.fields(cls): # type: ignore[arg-type] 827 # check if the field is a SerializableField 828 assert isinstance(field, SerializableField), ( 829 f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 830 ) 831 832 # check if the field is in the data and if it should be initialized 833 if (field.name in data) and field.init: 834 # get the value, we will be processing it 835 value: Any = data[field.name] 836 837 # get the type hint for the field 838 field_type_hint: Any = cls_type_hints.get(field.name, None) 839 840 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 841 if field.deserialize_fn: 842 # if it has a deserialization function, use that 843 value = field.deserialize_fn(value) 844 elif field.loading_fn: 845 # if it has a loading function, use that 846 value = field.loading_fn(data) 847 elif ( 848 field_type_hint is not None 849 and hasattr(field_type_hint, "load") 850 and callable(field_type_hint.load) 851 ): 852 # if no loading function but has a type hint with a load method, use that 853 if isinstance(value, dict): 
854 value = field_type_hint.load(value) 855 else: 856 raise FieldLoadingError( 857 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 858 ) 859 else: 860 # assume no loading needs to happen, keep `value` as-is 861 pass 862 863 # store the value in the constructor kwargs 864 ctor_kwargs[field.name] = value 865 866 # create a new instance of the class with the constructor kwargs 867 output: T_SerializeableDataclass = cls(**ctor_kwargs) 868 869 # validate the types of the fields if needed 870 if on_typecheck_mismatch != ErrorMode.IGNORE: 871 fields_valid: dict[str, bool] = ( 872 SerializableDataclass__validate_fields_types__dict( 873 output, 874 on_typecheck_error=on_typecheck_error, 875 ) 876 ) 877 878 # if there are any fields that are not valid, raise an error 879 if not all(fields_valid.values()): 880 msg: str = ( 881 f"Type mismatch in fields of {cls.__name__}:\n" 882 + "\n".join( 883 [ 884 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 885 for k, v in fields_valid.items() 886 if not v 887 ] 888 ) 889 ) 890 891 on_typecheck_mismatch.process( 892 msg, except_cls=FieldTypeMismatchError 893 ) 894 895 # return the new instance 896 return output 897 898 _methods_no_override: set[str] 899 if methods_no_override is None: 900 _methods_no_override = set() 901 else: 902 _methods_no_override = set(methods_no_override) 903 904 if _methods_no_override - { 905 "__eq__", 906 "serialize", 907 "load", 908 "validate_fields_types", 909 }: 910 warnings.warn( 911 f"Unknown methods in `methods_no_override`: {_methods_no_override = }" 912 ) 913 914 # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments 915 if "serialize" not in _methods_no_override: 916 # type is `Callable[[T], dict]` 917 cls.serialize = serialize # type: ignore[attr-defined, method-assign] 918 if "load" not in _methods_no_override: 919 # type 
is `Callable[[dict], T]` 920 cls.load = load # type: ignore[attr-defined, method-assign, assignment] 921 922 if "validate_field_type" not in _methods_no_override: 923 # type is `Callable[[T, ErrorMode], bool]` 924 cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined, method-assign] 925 926 if "__eq__" not in _methods_no_override: 927 # type is `Callable[[T, T], bool]` 928 cls.__eq__ = lambda self, other: dc_eq(self, other) # type: ignore[assignment] 929 930 # Register the class with ZANJ 931 if register_handler: 932 zanj_register_loader_serializable_dataclass(cls) 933 934 return cls 935 936 if _cls is None: 937 return wrap 938 else: 939 return wrap(_cls)
decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!
types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE
behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine fields.
If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
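The decorator can be applied either bare or with keyword arguments, as in the `kw_only=True` example above. The following is a minimal stdlib-only sketch of that calling pattern, not the muutils implementation; `tagged_dataclass`, `tag`, and `_tag` are hypothetical names used only for illustration.

```python
import dataclasses
from typing import Any, Callable, Optional, Union


def tagged_dataclass(
    _cls: Optional[type] = None, *, tag: str = "default", **kwargs: Any
) -> Union[type, Callable[[type], type]]:
    """Hypothetical decorator usable as `@tagged_dataclass` or `@tagged_dataclass(tag=...)`."""

    def wrap(cls: type) -> type:
        # any keyword not consumed here is forwarded to `dataclasses.dataclass`
        cls = dataclasses.dataclass(cls, **kwargs)
        cls._tag = tag  # hypothetical class attribute set by the decorator
        return cls

    # used with parentheses: no class yet, so return the wrapper
    if _cls is None:
        return wrap
    # used bare: the class itself was passed as the first argument
    return wrap(_cls)


@tagged_dataclass
class Bare:
    a: int


@tagged_dataclass(tag="custom", frozen=True)
class WithArgs:
    a: int
```

Making every option keyword-only (the `*` in the signature) is what lets the function tell the two call styles apart: a positional first argument can only be the class being decorated.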
Parameters:
- `_cls : _type_`
   class to decorate. don't pass this arg, just use this as a decorator
   (defaults to `None`)
- `init : bool`
   whether to add an `__init__` method
   *(passed to dataclasses.dataclass)*
   (defaults to `True`)
- `repr : bool`
   whether to add a `__repr__` method
   *(passed to dataclasses.dataclass)*
   (defaults to `True`)
- `order : bool`
   whether to add rich comparison methods
   *(passed to dataclasses.dataclass)*
   (defaults to `False`)
- `unsafe_hash : bool`
   whether to add a `__hash__` method
   *(passed to dataclasses.dataclass)*
   (defaults to `False`)
- `frozen : bool`
   whether to make the class frozen
   *(passed to dataclasses.dataclass)*
   (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
   which properties to add to the serialized data dict
   **SerializableDataclass only**
   (defaults to `None`)
- `register_handler : bool`
   if true, register the class with ZANJ for loading
   **SerializableDataclass only**
   (defaults to `True`)
- `on_typecheck_error : ErrorMode`
   what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
   **SerializableDataclass only**
- `on_typecheck_mismatch : ErrorMode`
   what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
   **SerializableDataclass only**
- `methods_no_override : list[str]|None`
   list of methods that should not be overridden by the decorator
   by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
   but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
   **SerializableDataclass only**
   (defaults to `None`)
- `**kwargs`
   *(passed to dataclasses.dataclass)*
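To illustrate what `properties_to_serialize` and a per-field serialization function do, here is a rough stdlib-only sketch. It is an assumption-laden stand-in, not the muutils code: the `Circle` class, the free function `serialize`, and the `"serialization_fn"` metadata key are all hypothetical (muutils stores the hook on `serializable_field` instead).

```python
import dataclasses
import json
from typing import Any


@dataclasses.dataclass
class Circle:
    radius: float
    # hypothetical per-field hook kept in `metadata`: a set is not JSON-serializable,
    # so it is converted to a sorted list
    tags: set = dataclasses.field(
        default_factory=set,
        metadata={"serialization_fn": lambda s: sorted(s)},
    )

    @property
    def diameter(self) -> float:
        return 2 * self.radius


def serialize(obj: Any, properties_to_serialize: tuple = ()) -> dict:
    result: dict = {}
    # regular fields: apply the per-field hook when one is defined
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        fn = field.metadata.get("serialization_fn")
        result[field.name] = fn(value) if fn else value
    # extra properties: computed values that are not dataclass fields
    for prop in properties_to_serialize:
        # a listed property that does not exist on the class is an error
        if not hasattr(type(obj), prop):
            raise AttributeError(f"cannot serialize property '{prop}'")
        result[prop] = getattr(obj, prop)
    return result


data = serialize(
    Circle(radius=1.5, tags={"b", "a"}),
    properties_to_serialize=("diameter",),
)
json_text = json.dumps(data)
```

Serialized properties are output-only in this sketch: they appear in the dict but are not constructor arguments, so a loader would have to skip them.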
Returns:
- `_type_`
   the decorated class
Raises:
- `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
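The three `on_typecheck_mismatch` modes listed under Parameters (except, warn, ignore) can be sketched with the standard library alone. This is a simplified illustration under stated assumptions: the `load` function, its `on_mismatch` string argument, and the use of `TypeError` are hypothetical stand-ins for `ErrorMode` and `FieldTypeMismatchError`, and the `isinstance` check only handles plain classes such as `int`, not generic type hints.

```python
import dataclasses
import typing
import warnings


@dataclasses.dataclass
class Point:
    x: int
    y: int


def load(cls, data: dict, on_mismatch: str = "except"):
    """Hypothetical loader: build `cls` from `data`, then check simple field types."""
    fields = dataclasses.fields(cls)
    obj = cls(**{f.name: data[f.name] for f in fields if f.name in data})

    if on_mismatch == "ignore":
        return obj  # validation is skipped entirely

    hints = typing.get_type_hints(cls)
    # names of fields whose value is not an instance of the annotated class
    bad = [f.name for f in fields if not isinstance(getattr(obj, f.name), hints[f.name])]
    if bad:
        msg = f"Type mismatch in fields of {cls.__name__}: {bad}"
        if on_mismatch == "except":
            raise TypeError(msg)
        warnings.warn(msg)  # "warn": report the problem but still return the object
    return obj


ok = load(Point, {"x": 1, "y": 2})
lenient = load(Point, {"x": "1", "y": 2}, on_mismatch="ignore")
```

Validating after construction, as here, means the object always exists by the time a mismatch is reported, which is what makes the warn mode possible.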