Coverage for muutils\json_serialize\serializable_dataclass.py: 55%
242 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-15 21:53 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-15 21:53 -0600
1"""save and load objects to and from json or compatible formats in a recoverable way
3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
6Instead, you define your class:
8```python
9@serializable_dataclass
10class MyClass(SerializableDataclass):
11 a: int
12 b: str
13```
15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
17 >>> my_obj = MyClass(a=1, b="q")
18 >>> s = json.dumps(my_obj.serialize())
19 >>> s
20 '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
21 >>> read_obj = MyClass.load(json.loads(s))
22 >>> read_obj == my_obj
23 True
25This isn't too impressive on its own, but it gets more useful when you have nested classes,
26or fields that are not json-serializable by default:
28```python
29@serializable_dataclass
30class NestedClass(SerializableDataclass):
31 x: str
32 y: MyClass
33 act_fun: torch.nn.Module = serializable_field(
34 default=torch.nn.ReLU(),
35 serialization_fn=lambda x: str(x),
36 deserialize_fn=lambda x: getattr(torch.nn, x)(),
37 )
38```
40which gives us:
42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
43 >>> s = json.dumps(nc.serialize())
44 >>> s
45 '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
46 >>> read_nc = NestedClass.load(json.loads(s))
47 >>> read_nc == nc
48 True
50"""
52from __future__ import annotations
54import abc
55import dataclasses
56import functools
57import json
58import sys
59import typing
60import warnings
61from typing import Any, Optional, Type, TypeVar
63from muutils.errormode import ErrorMode
64from muutils.validate_type import validate_type
65from muutils.json_serialize.serializable_field import (
66 SerializableField,
67 serializable_field,
68)
69from muutils.json_serialize.util import array_safe_eq, dc_eq
71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
74def _dataclass_transform_mock(
75 *,
76 eq_default: bool = True,
77 order_default: bool = False,
78 kw_only_default: bool = False,
79 frozen_default: bool = False,
80 field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
81 **kwargs: Any,
82) -> typing.Callable:
83 "mock `typing.dataclass_transform` for python <3.11"
85 def decorator(cls_or_fn):
86 cls_or_fn.__dataclass_transform__ = {
87 "eq_default": eq_default,
88 "order_default": order_default,
89 "kw_only_default": kw_only_default,
90 "frozen_default": frozen_default,
91 "field_specifiers": field_specifiers,
92 "kwargs": kwargs,
93 }
94 return cls_or_fn
96 return decorator
99dataclass_transform: typing.Callable
100if sys.version_info < (3, 11):
101 dataclass_transform = _dataclass_transform_mock
102else:
103 dataclass_transform = typing.dataclass_transform
106T = TypeVar("T")
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass
121_zanj_loading_needs_import: bool = True
122"flag to keep track of if we have successfully imported ZANJ"
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
       class to register; its `__name__`, `__module__` and `load` are used to build the handler

    # Returns:
     - the `LoaderHandler` that was registered, or `None` (after emitting a `ZanjMissingWarning`)
       if `zanj.loading` cannot be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    # NOTE(review): the flag is never set to `False` in this function, so the import is attempted
    # on every call; if it were `False`, the names imported below would be unbound here
    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            warnings.warn(
                "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
                ZanjMissingWarning,
            )
            return

    # same prefix that `serialize()` writes into the `__format__` key
    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
# default handling when a field value does not match its type hint
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
# default handling when the type check itself throws an exception
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
class FieldIsNotInitOrSerializeWarning(UserWarning):
    "warning for when a field is not `init` or not `serialize`, and so is skipped by type validation"

    pass
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
       field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
       what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        # the class-level hints were found, but this field has no entry in them
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + " - use python 3.9.x or higher\n"
                + " - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        # only reached when `on_typecheck_error` did not raise
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            # NOTE(review): the custom function receives the type hint only, not the value -- confirm this is intended
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        # only reached when `on_typecheck_error` did not raise
        return False
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid

    # Parameters:
     - `self : SerializableDataclass`
       instance whose fields are validated
     - `on_typecheck_error : ErrorMode`
       passed to each per-field check, and also used to report the exceptions collected from those checks
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `dict[str, bool]`
       one entry per dataclass field; `False` for fields whose check raised
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            # keep going so that every failing field is reported together below
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns `True` only if every entry of `SerializableDataclass__validate_fields_types__dict` is `True`
    """
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        "equality is delegated to `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances are equal; otherwise maps each differing field name to either
           `{"self": ..., "other": ...}` (holding the raw attribute values) or a nested diff

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                # NOTE(review): a field that is absent from the serialized dict would raise `KeyError` here -- confirm
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
         - `nested_dict : dict[str, Any]`
           nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    # nested serializable dataclasses are updated in place, recursively
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
# cache this so we don't have to keep getting it
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)


def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup first and retries uncached if that came back empty

    # Raises:
     - `TypeError` : if the hints cannot be resolved, or are empty (chained from the underlying error)
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if not cls_type_hints:
            cls_type_hints = typing.get_type_hints(cls)

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # `dataclasses.fields` raises `TypeError` for anything that is not a dataclass; guard it so
        # that building the message cannot replace the error we actually want to report
        fields_desc: str
        try:
            fields_desc = repr(dataclasses.fields(cls))  # type: ignore[arg-type]
        except TypeError:
            fields_desc = "<not a dataclass>"

        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f" dataclasses.fields(cls) = {fields_desc}\n"
            + f" {e = }"
        ) from e

    return cls_type_hints
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass
class FieldError(ValueError):
    "base class for field errors"

    pass


class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass


class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass


class FieldLoadingError(FieldError):
    "error while loading a field"

    pass
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `eq : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
       **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
       **SerializableDataclass only:** what to do if a type mismatch is found while loading (except, warn, ignore). If `ignore`, no type validation is run in `load`

    # Returns:
     - `_type_`
       the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field, or a type mismatch is found and `on_typecheck_mismatch` is `except`
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Replace plain class-level defaults / `dataclasses.Field`s with `SerializableField`s
        # for every annotated name, so that `dataclasses.dataclass` picks those up
        for field_name in cls.__annotations__:
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        # (note: it is given the whole `data` mapping, not just this field's value)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: T = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output, on_typecheck_error=on_typecheck_error
                    )
                )
                # act on the result: previously it was discarded, so a mismatch was never reported
                mismatched: list[str] = [
                    name for name, is_valid in fields_valid.items() if not is_valid
                ]
                if mismatched:
                    on_typecheck_mismatch.process(
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            f"{name}:\texpected {cls_type_hints.get(name)!r}, "
                            f"but got value {getattr(output, name, dataclasses.MISSING)!r} "
                            f"of type {type(getattr(output, name, dataclasses.MISSING))!r}"
                            for name in mismatched
                        ),
                        except_cls=FieldLoadingError,
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # support both `@serializable_dataclass` and `@serializable_dataclass(...)`
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)