docs for muutils v0.6.13
View Source on GitHub

muutils.json_serialize.serializable_dataclass

save and load objects to and from json or compatible formats in a recoverable way

d = dataclasses.asdict(my_obj) will give you a dict, but if some fields are not json-serializable, you will get an error when you call json.dumps(d). This module provides a way around that.

Instead, you define your class:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True

  1"""save and load objects to and from json or compatible formats in a recoverable way
  2
  3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
  4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
  5
  6Instead, you define your class:
  7
  8```python
  9@serializable_dataclass
 10class MyClass(SerializableDataclass):
 11    a: int
 12    b: str
 13```
 14
 15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
 16
 17    >>> my_obj = MyClass(a=1, b="q")
 18    >>> s = json.dumps(my_obj.serialize())
 19    >>> s
 20    '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
 21    >>> read_obj = MyClass.load(json.loads(s))
 22    >>> read_obj == my_obj
 23    True
 24
 25This isn't too impressive on its own, but it gets more useful when you have nested classes,
 26or fields that are not json-serializable by default:
 27
 28```python
 29@serializable_dataclass
 30class NestedClass(SerializableDataclass):
 31    x: str
 32    y: MyClass
 33    act_fun: torch.nn.Module = serializable_field(
 34        default=torch.nn.ReLU(),
 35        serialization_fn=lambda x: str(x),
 36        deserialize_fn=lambda x: getattr(torch.nn, x)(),
 37    )
 38```
 39
 40which gives us:
 41
 42    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
 43    >>> s = json.dumps(nc.serialize())
 44    >>> s
 45    '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
 46    >>> read_nc = NestedClass.load(json.loads(s))
 47    >>> read_nc == nc
 48    True
 49
 50"""
 51
 52from __future__ import annotations
 53
 54import abc
 55import dataclasses
 56import functools
 57import json
 58import sys
 59import typing
 60import warnings
 61from typing import Any, Optional, Type, TypeVar
 62
 63from muutils.errormode import ErrorMode
 64from muutils.validate_type import validate_type
 65from muutils.json_serialize.serializable_field import (
 66    SerializableField,
 67    serializable_field,
 68)
 69from muutils.json_serialize.util import array_safe_eq, dc_eq
 70
 71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
 72
 73
 74def _dataclass_transform_mock(
 75    *,
 76    eq_default: bool = True,
 77    order_default: bool = False,
 78    kw_only_default: bool = False,
 79    frozen_default: bool = False,
 80    field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
 81    **kwargs: Any,
 82) -> typing.Callable:
 83    "mock `typing.dataclass_transform` for python <3.11"
 84
 85    def decorator(cls_or_fn):
 86        cls_or_fn.__dataclass_transform__ = {
 87            "eq_default": eq_default,
 88            "order_default": order_default,
 89            "kw_only_default": kw_only_default,
 90            "frozen_default": frozen_default,
 91            "field_specifiers": field_specifiers,
 92            "kwargs": kwargs,
 93        }
 94        return cls_or_fn
 95
 96    return decorator
 97
 98
 99dataclass_transform: typing.Callable
100if sys.version_info < (3, 11):
101    dataclass_transform = _dataclass_transform_mock
102else:
103    dataclass_transform = typing.dataclass_transform
104
105
106T = TypeVar("T")
107
108
class CantGetTypeHintsWarning(UserWarning):
    "warning category for the case where type hints cannot be obtained"
113
114
class ZanjMissingWarning(UserWarning):
    "warning category used when [`ZANJ`](https://github.com/mivanit/ZANJ) cannot be imported -- `register_loader_serializable_dataclass` will not work"
119
120
_zanj_loading_needs_import: bool = True
"`True` until `zanj.loading` has been imported successfully by `zanj_register_loader_serializable_dataclass`"


def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
       class to register, its `load` classmethod is used as the loader

    # Returns:
     - the `LoaderHandler` which was registered,
       or `None` (after emitting a `ZanjMissingWarning`) if `zanj.loading` cannot be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    # the import statement is executed on every call, regardless of the flag:
    # `LoaderHandler` and `register_loader_handler` are local to this function,
    # so skipping the import would leave them unbound further down.
    # once `zanj.loading` is loaded, repeating the import only hits the module cache
    try:
        from zanj.loading import (  # type: ignore[import]
            LoaderHandler,
            register_loader_handler,
        )
    except ImportError:
        warnings.warn(
            "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            ZanjMissingWarning,
        )
        return None

    # record that ZANJ is importable
    _zanj_loading_needs_import = False

    # the handler matches any dict whose `__format__` starts with this string
    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
164
165
# default `ErrorMode` for the `on_typecheck_mismatch` argument of `serializable_dataclass`
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
# default `ErrorMode` for the `on_typecheck_error` arguments used throughout this module
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
168
169
class FieldIsNotInitOrSerializeWarning(UserWarning):
    "warning category used when a field is skipped during type validation because it is not `init` or not `serialize`"
172
173
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`.
        `True` without any checking if the field has `assert_type` disabled, or is not `init` or not `serialize`

    # Raises:
     - `AssertionError` : if the resolved field is not a `SerializableField`
     - whatever `on_typecheck_error.process` raises (requested with `except_cls=TypeError` when the type hint for the field is missing,
       `except_cls=ValueError` when the validator itself throws)
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # check the field class first: everything below reads attributes
    # (`assert_type`, `serialize`, `custom_typecheck_fn`) which a plain `dataclasses.Field` does not have,
    # so checking any later would hide this message behind an `AttributeError`
    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        # the class has type hints, but none for this particular field
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        # only reached if `process` did not raise
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            # NOTE(review): the custom validator is handed the type hint, not `value` -- confirm this is the intended contract
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        # only reached if `process` did not raise
        return False
264
265
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """run `self.validate_field_type` on every dataclass field of `self`

    # Parameters:
     - `self : SerializableDataclass`
       instance whose fields get validated
     - `on_typecheck_error : ErrorMode`
       passed on to each per-field check, and also used to report the
       exceptions collected from those checks
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `dict[str, bool]`
       field name -> whether the type of that field is valid.
       a field whose check threw an exception maps to `False`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    all_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    validity: dict[str, bool] = {}
    # exceptions are collected per field and reported in one go at the end
    failures: dict[str, Exception] = {}

    for fld in all_fields:
        try:
            field_ok: bool = self.validate_field_type(fld, on_typecheck_error)
        except Exception as exc:
            failures[fld.name] = exc
            field_ok = False
        validity[fld.name] = field_ok

    if failures:
        field_names: list[str] = [x.name for x in all_fields]
        details: str = "\n\t".join([f"{k}:\t{v}" for k, v in failures.items()])
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {field_names}"
            + "\n\t"
            + details,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(failures.values())),
        )

    return validity
301
302
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """`True` if every field of `self` passes type validation

    thin wrapper which reduces the per-field results of
    `SerializableDataclass__validate_fields_types__dict` with `all`
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())
313
314
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator:
    `serialize` and `load` raise `NotImplementedError` on this base class

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        message: str = f"decorate {self.__class__ = } with `@serializable_dataclass`"
        raise NotImplementedError(message)

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        message: str = f"decorate {cls = } with `@serializable_dataclass`"
        raise NotImplementedError(message)

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """check all fields against their type hints, delegates to `SerializableDataclass__validate_fields_types`"""
        all_valid: bool = SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
        return all_valid

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check a single field against its type hint, delegates to `SerializableDataclass__validate_field_type`"""
        field_valid: bool = SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
        return field_valid

    def __eq__(self, other: Any) -> bool:
        "equality as decided by `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances compare equal. otherwise, one entry per differing field:
           a nested diff for fields holding serializable dataclasses,
           and `{"self": ..., "other": ...}` with the raw attribute values for everything else

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both sides must be exactly the same class
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        differences: dict = {}

        # equal instances have nothing to report
        if self == other:
            return differences

        # serialized views are only needed (and only created) in serialized mode
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison are excluded from the diff too
            if not field.compare:
                continue

            name: str = field.name
            mine = getattr(self, name)
            theirs = getattr(other, name)

            # serializable dataclasses on both sides: recurse
            if isinstance(mine, SerializableDataclass) and isinstance(
                theirs, SerializableDataclass
            ):
                sub_diff: dict = mine.diff(theirs, of_serialized=of_serialized)
                if sub_diff:
                    differences[name] = sub_diff
                continue

            # plain dataclasses on both sides are rejected
            if dataclasses.is_dataclass(mine) and dataclasses.is_dataclass(theirs):
                raise ValueError("Non-serializable dataclass is not supported")

            # pick what gets compared: the serialized or the raw values
            if of_serialized:
                mine_cmp, theirs_cmp = ser_self[name], ser_other[name]
            else:
                mine_cmp, theirs_cmp = mine, theirs

            # the reported values are always the raw ones
            if not array_safe_eq(mine_cmp, theirs_cmp):
                differences[name] = {"self": mine, "other": theirs}

        return differences

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields holding a `SerializableDataclass` are updated recursively,
        any other field present in the dict is overwritten via `setattr`

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            name: str = field.name
            current = getattr(self, name)

            # fields without an entry in the dict stay as they are
            if name not in nested_dict:
                continue

            new_value = nested_dict[name]
            if isinstance(current, SerializableDataclass):
                current.update_from_nested_dict(new_value)
            else:
                setattr(self, name, new_value)

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        as_json: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(as_json))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        as_json: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(as_json))
509
510
# memoized, so that repeated lookups for the same class do not redo the work
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(maxsize=128, typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "`typing.get_type_hints` for a class, behind an `lru_cache`"
    hints: dict[str, Any] = typing.get_type_hints(cls)
    return hints
517
518
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """type hints of a class, with a more helpful error if they can't be obtained

    tries the cached lookup first and falls back to a direct
    `typing.get_type_hints` call if that came back empty

    # Raises:
     - `TypeError` : if getting the hints fails with a `TypeError` or `NameError`,
       or if the hints are still empty after the fallback
    """
    try:
        hints: dict[str, Any] = get_cls_type_hints_cached(cls)  # type: ignore
        if not hints:
            # empty result from the cache, ask `typing` directly
            hints = typing.get_type_hints(cls)

        if not hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        message: str = (
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        )
        raise TypeError(message) from e

    return hints
538
539
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10, raised when `kw_only=True` is requested on such a version"
544
545
class FieldError(ValueError):
    "common parent of the field-related errors in this module"
550
551
class NotSerializableFieldException(FieldError):
    "raised for a field which is not a `SerializableField`"
556
557
class FieldSerializationError(FieldError):
    "raised when serializing a field fails"
562
563
class FieldLoadingError(FieldError):
    "raised when loading a field fails"
568
569
570@dataclass_transform(
571    field_specifiers=(serializable_field, SerializableField),
572)
573def serializable_dataclass(
574    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
575    _cls=None,  # type: ignore
576    *,
577    init: bool = True,
578    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
579    eq: bool = True,
580    order: bool = False,
581    unsafe_hash: bool = False,
582    frozen: bool = False,
583    properties_to_serialize: Optional[list[str]] = None,
584    register_handler: bool = True,
585    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
586    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
587    **kwargs,
588):
589    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`
590
591    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
592
593    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs
594
595    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
596
597    Examines PEP 526 `__annotations__` to determine fields.
598
599    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.
600
601    ```python
602    @serializable_dataclass(kw_only=True)
603    class Myclass(SerializableDataclass):
604        a: int
605        b: str
606    ```
607    ```python
608    >>> Myclass(a=1, b="q").serialize()
609    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
610    ```
611
612    # Parameters:
613     - `_cls : _type_`
614       class to decorate. don't pass this arg, just use this as a decorator
615       (defaults to `None`)
616     - `init : bool`
617       (defaults to `True`)
618     - `repr : bool`
619       (defaults to `True`)
620     - `order : bool`
621       (defaults to `False`)
622     - `unsafe_hash : bool`
623       (defaults to `False`)
624     - `frozen : bool`
625       (defaults to `False`)
626     - `properties_to_serialize : Optional[list[str]]`
627       **SerializableDataclass only:** which properties to add to the serialized data dict
628       (defaults to `None`)
629     - `register_handler : bool`
630        **SerializableDataclass only:** if true, register the class with ZANJ for loading
631       (defaults to `True`)
632     - `on_typecheck_error : ErrorMode`
633        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
634     - `on_typecheck_mismatch : ErrorMode`
635        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
636
637    # Returns:
638     - `_type_`
639       the decorated class
640
641    # Raises:
642     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
643     - `NotSerializableFieldException` : if a field is not a `SerializableField`
644     - `FieldSerializationError` : if there is an error serializing a field
645     - `AttributeError` : if a property is not found on the class
646     - `FieldLoadingError` : if there is an error loading a field
647    """
648    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
649    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
650    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
651
652    if properties_to_serialize is None:
653        _properties_to_serialize: list = list()
654    else:
655        _properties_to_serialize = properties_to_serialize
656
657    def wrap(cls: Type[T]) -> Type[T]:
658        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
659        for field_name, field_type in cls.__annotations__.items():
660            field_value = getattr(cls, field_name, None)
661            if not isinstance(field_value, SerializableField):
662                if isinstance(field_value, dataclasses.Field):
663                    # Convert the field to a SerializableField while preserving properties
664                    field_value = SerializableField.from_Field(field_value)
665                else:
666                    # Create a new SerializableField
667                    field_value = serializable_field()
668                setattr(cls, field_name, field_value)
669
670        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
671        if sys.version_info < (3, 10):
672            if "kw_only" in kwargs:
673                if kwargs["kw_only"] == True:  # noqa: E712
674                    raise KWOnlyError("kw_only is not supported in python >=3.9")
675                else:
676                    del kwargs["kw_only"]
677
678        # call `dataclasses.dataclass` to set some stuff up
679        cls = dataclasses.dataclass(  # type: ignore[call-overload]
680            cls,
681            init=init,
682            repr=repr,
683            eq=eq,
684            order=order,
685            unsafe_hash=unsafe_hash,
686            frozen=frozen,
687            **kwargs,
688        )
689
690        # copy these to the class
691        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
692
693        # ======================================================================
694        # define `serialize` func
695        # done locally since it depends on args to the decorator
696        # ======================================================================
697        def serialize(self) -> dict[str, Any]:
698            result: dict[str, Any] = {
699                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
700            }
701            # for each field in the class
702            for field in dataclasses.fields(self):  # type: ignore[arg-type]
703                # need it to be our special SerializableField
704                if not isinstance(field, SerializableField):
705                    raise NotSerializableFieldException(
706                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
707                        f"but a {type(field)} "
708                        "this state should be inaccessible, please report this bug!"
709                    )
710
711                # try to save it
712                if field.serialize:
713                    try:
714                        # get the val
715                        value = getattr(self, field.name)
716                        # if it is a serializable dataclass, serialize it
717                        if isinstance(value, SerializableDataclass):
718                            value = value.serialize()
719                        # if the value has a serialization function, use that
720                        if hasattr(value, "serialize") and callable(value.serialize):
721                            value = value.serialize()
722                        # if the field has a serialization function, use that
723                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
724                        elif field.serialization_fn:
725                            value = field.serialization_fn(value)
726
727                        # store the value in the result
728                        result[field.name] = value
729                    except Exception as e:
730                        raise FieldSerializationError(
731                            "\n".join(
732                                [
733                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
734                                    f"{field = }",
735                                    f"{value = }",
736                                    f"{self = }",
737                                ]
738                            )
739                        ) from e
740
741            # store each property if we can get it
742            for prop in self._properties_to_serialize:
743                if hasattr(cls, prop):
744                    value = getattr(self, prop)
745                    result[prop] = value
746                else:
747                    raise AttributeError(
748                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
749                        + f"but it is in {self._properties_to_serialize = }"
750                        + f"\n{self = }"
751                    )
752
753            return result
754
755        # ======================================================================
756        # define `load` func
757        # done locally since it depends on args to the decorator
758        # ======================================================================
759        # mypy thinks this isnt a classmethod
760        @classmethod  # type: ignore[misc]
761        def load(cls, data: dict[str, Any] | T) -> Type[T]:
762            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
763            if isinstance(data, cls):
764                return data
765
766            assert isinstance(
767                data, typing.Mapping
768            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
769
770            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
771
772            # initialize dict for keeping what we will pass to the constructor
773            ctor_kwargs: dict[str, Any] = dict()
774
775            # iterate over the fields of the class
776            for field in dataclasses.fields(cls):
777                # check if the field is a SerializableField
778                assert isinstance(
779                    field, SerializableField
780                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
781
782                # check if the field is in the data and if it should be initialized
783                if (field.name in data) and field.init:
784                    # get the value, we will be processing it
785                    value: Any = data[field.name]
786
787                    # get the type hint for the field
788                    field_type_hint: Any = cls_type_hints.get(field.name, None)
789
790                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
791                    if field.deserialize_fn:
792                        # if it has a deserialization function, use that
793                        value = field.deserialize_fn(value)
794                    elif field.loading_fn:
795                        # if it has a loading function, use that
796                        value = field.loading_fn(data)
797                    elif (
798                        field_type_hint is not None
799                        and hasattr(field_type_hint, "load")
800                        and callable(field_type_hint.load)
801                    ):
802                        # if no loading function but has a type hint with a load method, use that
803                        if isinstance(value, dict):
804                            value = field_type_hint.load(value)
805                        else:
806                            raise FieldLoadingError(
807                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
808                            )
809                    else:
810                        # assume no loading needs to happen, keep `value` as-is
811                        pass
812
813                    # store the value in the constructor kwargs
814                    ctor_kwargs[field.name] = value
815
816            # create a new instance of the class with the constructor kwargs
817            output: cls = cls(**ctor_kwargs)
818
819            # validate the types of the fields if needed
820            if on_typecheck_mismatch != ErrorMode.IGNORE:
821                output.validate_fields_types(on_typecheck_error=on_typecheck_error)
822
823            # return the new instance
824            return output
825
826        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
827        # type is `Callable[[T], dict]`
828        cls.serialize = serialize  # type: ignore[attr-defined]
829        # type is `Callable[[dict], T]`
830        cls.load = load  # type: ignore[attr-defined]
831        # type is `Callable[[T, ErrorMode], bool]`
832        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]
833
834        # type is `Callable[[T, T], bool]`
835        cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
836
837        # Register the class with ZANJ
838        if register_handler:
839            zanj_register_loader_serializable_dataclass(cls)
840
841        return cls
842
843    if _cls is None:
844        return wrap
845    else:
846        return wrap(_cls)

def dataclass_transform( *, eq_default: bool = True, order_default: bool = False, kw_only_default: bool = False, frozen_default: bool = False, field_specifiers: tuple[typing.Union[type[typing.Any], typing.Callable[..., typing.Any]], ...] = (), **kwargs: Any) -> <class '_IdentityCallable'>:
3275def dataclass_transform(
3276    *,
3277    eq_default: bool = True,
3278    order_default: bool = False,
3279    kw_only_default: bool = False,
3280    frozen_default: bool = False,
3281    field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
3282    **kwargs: Any,
3283) -> _IdentityCallable:
3284    """Decorator to mark an object as providing dataclass-like behaviour.
3285
3286    The decorator can be applied to a function, class, or metaclass.
3287
3288    Example usage with a decorator function::
3289
3290        @dataclass_transform()
3291        def create_model[T](cls: type[T]) -> type[T]:
3292            ...
3293            return cls
3294
3295        @create_model
3296        class CustomerModel:
3297            id: int
3298            name: str
3299
3300    On a base class::
3301
3302        @dataclass_transform()
3303        class ModelBase: ...
3304
3305        class CustomerModel(ModelBase):
3306            id: int
3307            name: str
3308
3309    On a metaclass::
3310
3311        @dataclass_transform()
3312        class ModelMeta(type): ...
3313
3314        class ModelBase(metaclass=ModelMeta): ...
3315
3316        class CustomerModel(ModelBase):
3317            id: int
3318            name: str
3319
3320    The ``CustomerModel`` classes defined above will
3321    be treated by type checkers similarly to classes created with
3322    ``@dataclasses.dataclass``.
3323    For example, type checkers will assume these classes have
3324    ``__init__`` methods that accept ``id`` and ``name``.
3325
3326    The arguments to this decorator can be used to customize this behavior:
3327    - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be
3328        ``True`` or ``False`` if it is omitted by the caller.
3329    - ``order_default`` indicates whether the ``order`` parameter is
3330        assumed to be True or False if it is omitted by the caller.
3331    - ``kw_only_default`` indicates whether the ``kw_only`` parameter is
3332        assumed to be True or False if it is omitted by the caller.
3333    - ``frozen_default`` indicates whether the ``frozen`` parameter is
3334        assumed to be True or False if it is omitted by the caller.
3335    - ``field_specifiers`` specifies a static list of supported classes
3336        or functions that describe fields, similar to ``dataclasses.field()``.
3337    - Arbitrary other keyword arguments are accepted in order to allow for
3338        possible future extensions.
3339
3340    At runtime, this decorator records its arguments in the
3341    ``__dataclass_transform__`` attribute on the decorated object.
3342    It has no other runtime effect.
3343
3344    See PEP 681 for more details.
3345    """
3346    def decorator(cls_or_fn):
3347        cls_or_fn.__dataclass_transform__ = {
3348            "eq_default": eq_default,
3349            "order_default": order_default,
3350            "kw_only_default": kw_only_default,
3351            "frozen_default": frozen_default,
3352            "field_specifiers": field_specifiers,
3353            "kwargs": kwargs,
3354        }
3355        return cls_or_fn
3356    return decorator

Decorator to mark an object as providing dataclass-like behaviour.

The decorator can be applied to a function, class, or metaclass.

Example usage with a decorator function::

@dataclass_transform()
def create_model[T](cls: type[T]) -> type[T]:
    ...
    return cls

@create_model
class CustomerModel:
    id: int
    name: str

On a base class::

@dataclass_transform()
class ModelBase: ...

class CustomerModel(ModelBase):
    id: int
    name: str

On a metaclass::

@dataclass_transform()
class ModelMeta(type): ...

class ModelBase(metaclass=ModelMeta): ...

class CustomerModel(ModelBase):
    id: int
    name: str

The CustomerModel classes defined above will be treated by type checkers similarly to classes created with @dataclasses.dataclass. For example, type checkers will assume these classes have __init__ methods that accept id and name.

The arguments to this decorator can be used to customize this behavior:

  • eq_default indicates whether the eq parameter is assumed to be True or False if it is omitted by the caller.
  • order_default indicates whether the order parameter is assumed to be True or False if it is omitted by the caller.
  • kw_only_default indicates whether the kw_only parameter is assumed to be True or False if it is omitted by the caller.
  • frozen_default indicates whether the frozen parameter is assumed to be True or False if it is omitted by the caller.
  • field_specifiers specifies a static list of supported classes or functions that describe fields, similar to dataclasses.field().
  • Arbitrary other keyword arguments are accepted in order to allow for possible future extensions.

At runtime, this decorator records its arguments in the __dataclass_transform__ attribute on the decorated object. It has no other runtime effect.

See PEP 681 for more details.

class CantGetTypeHintsWarning(builtins.UserWarning):
class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""

special warning for when we can't get type hints

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
class ZanjMissingWarning(builtins.UserWarning):
class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"""

special warning for when ZANJ is missing -- register_loader_serializable_dataclass will not work

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def zanj_register_loader_serializable_dataclass(cls: Type[~T]):
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
       class to register. items are matched by their `"__format__"` key and loaded via `cls.load`

    # Returns:
     - the `LoaderHandler` that was registered, or `None` if ZANJ could not be imported
       (a `ZanjMissingWarning` is emitted in that case)

    # TODO: there is some duplication here with register_loader_handler
    """
    # always import here: `LoaderHandler` and `register_loader_handler` are function-local names,
    # so skipping the import on any call would leave them unbound below (`UnboundLocalError`).
    # repeated imports are cheap since python caches modules in `sys.modules`
    try:
        from zanj.loading import (  # type: ignore[import]
            LoaderHandler,
            register_loader_handler,
        )
    except ImportError:
        warnings.warn(
            "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            ZanjMissingWarning,
        )
        return None

    # format string that `serialize()` stores under the `"__format__"` key
    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        # an item is handled if it is a dict whose `"__format__"` starts with our format string
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh

Register a serializable dataclass with the ZANJ import

this allows ZANJ().read() to load the class and not just return plain dicts

TODO: there is some duplication here with register_loader_handler

class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):
class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning for when a field is not `init` or not `serialize`, and so will not be type checked"""

Base class for warnings generated by user code.

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value of a single field on a dataclass instance matches its type hint

    this function backs `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       instance whose field is checked
     - `field : SerializableField | str`
       the field to check. a `str` is looked up in `self.__dataclass_fields__`
     - `on_typecheck_error : ErrorMode`
       how to react when the check itself raises, or when the type hint cannot be found
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` if the type is valid, or if the field is skipped (`assert_type` is falsy, or the field is not `init` or not `serialize`).
       `False` if the type is invalid, or if an error occurred and `on_typecheck_error` did not raise
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve the field object (the name `_field` is kept since it appears in the messages below)
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # type checking disabled for this field
    if not _field.assert_type:
        return True

    # fields that are not both `init` and `serialize` are skipped, with a warning
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the type hint for this field
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        missing_hint_message: str = (
            f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
            + f"{get_cls_type_hints(self.__class__) = }\n"
            + f"Python version is {sys.version_info = }. You can:\n"
            + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
            + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
            + "  - use python 3.9.x or higher\n"
            + "  - specify custom type validation function via `custom_typecheck_fn`\n"
        )
        on_typecheck_error.process(
            missing_hint_message,
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # current value of the field
    value: Any = getattr(self, _field.name)

    try:
        if _field.custom_typecheck_fn is None:
            # default validator: checks the value against the hint
            return validate_type(value, field_type_hint)
        # custom validator: note that it is called with the type hint only
        return _field.custom_typecheck_fn(field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

given a dataclass, check the field matches the type hint

this function is written to SerializableDataclass.validate_field_type

Parameters:

  • self : SerializableDataclass SerializableDataclass instance
  • field : SerializableField | str field to validate, will get from self.__dataclass_fields__ if an str
  • on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, the function will return False (defaults to _DEFAULT_ON_TYPECHECK_ERROR)

Returns:

  • bool if the field type is correct. False if the field type is incorrect or an exception is thrown and on_typecheck_error is ignore
def SerializableDataclass__validate_fields_types__dict( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> dict[str, bool]:
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """check every field of a `SerializableDataclass` via `self.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       instance to validate
     - `on_typecheck_error : ErrorMode`
       passed to each per-field check, and used to process the bundled exceptions (if any)
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `dict[str, bool]`
       field name -> whether the field type is valid. a field whose check raised is recorded as `False`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    all_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    validity: dict[str, bool] = {}
    # exceptions are collected per field and reported together after the loop
    failures: dict[str, Exception] = {}

    for fld in all_fields:
        try:
            validity[fld.name] = self.validate_field_type(fld, on_typecheck_error)
        except Exception as err:
            validity[fld.name] = False
            failures[fld.name] = err

    if failures:
        field_names: list = [fld.name for fld in all_fields]
        details: str = "\n\t".join(
            f"{name}:\t{exc}" for name, exc in failures.items()
        )
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {field_names}"
            + "\n\t"
            + details,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(failures.values())),
        )

    return validity

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

returns a dict of field names to bools, where the bool is if the field type is valid

def SerializableDataclass__validate_fields_types( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`

    returns `True` only if every entry of `SerializableDataclass__validate_fields_types__dict` is truthy
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        "equality is delegated to `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values.
           fields which are absent from both serialized dicts (i.e. not serialized) are skipped.
           the values stored in the diff are still the raw attribute values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances are equal. otherwise maps each differing field name to
           `{"self": ..., "other": ...}`, or to a nested diff for nested `SerializableDataclass` fields


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                if of_serialized:
                    # a field that is not serialized never makes it into the serialized dicts,
                    # so indexing them directly would raise a `KeyError` -- there is nothing to compare
                    if field_name not in ser_self and field_name not in ser_other:
                        continue
                    self_value_s = ser_self.get(field_name)
                    other_value_s = ser_other.get(field_name)
                else:
                    self_value_s = self_value
                    other_value_s = other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields whose current value is a `SerializableDataclass` are updated recursively,
        all other fields present in `nested_dict` are overwritten via `setattr`.
        keys of `nested_dict` which are not field names are ignored.

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]:
369    def serialize(self) -> dict[str, Any]:
370        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
371        raise NotImplementedError(
372            f"decorate {self.__class__ = } with `@serializable_dataclass`"
373        )

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
375    @classmethod
376    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
377        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
378        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
380    def validate_fields_types(
381        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
382    ) -> bool:
383        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
384        return SerializableDataclass__validate_fields_types(
385            self, on_typecheck_error=on_typecheck_error
386        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
388    def validate_field_type(
389        self,
390        field: "SerializableField|str",
391        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
392    ) -> bool:
393        """given a dataclass, check the field matches the type hint"""
394        return SerializableDataclass__validate_field_type(
395            self, field, on_typecheck_error=on_typecheck_error
396        )

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
405    def diff(
406        self, other: "SerializableDataclass", of_serialized: bool = False
407    ) -> dict[str, Any]:
408        """get a rich and recursive diff between two instances of a serializable dataclass
409
410        ```python
411        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
412        {'b': {'self': 2, 'other': 3}}
413        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
414        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
415        ```
416
417        # Parameters:
418         - `other : SerializableDataclass`
419           other instance to compare against
420         - `of_serialized : bool`
421           if true, compare serialized data and not raw values
422           (defaults to `False`)
423
424        # Returns:
425         - `dict[str, Any]`
426
427
428        # Raises:
429         - `ValueError` : if the instances are not of the same type
430         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
431        """
432        # match types
433        if type(self) is not type(other):
434            raise ValueError(
435                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
436            )
437
438        # initialize the diff result
439        diff_result: dict = {}
440
441        # if they are the same, return the empty diff
442        if self == other:
443            return diff_result
444
445        # if we are working with serialized data, serialize the instances
446        if of_serialized:
447            ser_self: dict = self.serialize()
448            ser_other: dict = other.serialize()
449
450        # for each field in the class
451        for field in dataclasses.fields(self):  # type: ignore[arg-type]
452            # skip fields that are not for comparison
453            if not field.compare:
454                continue
455
456            # get values
457            field_name: str = field.name
458            self_value = getattr(self, field_name)
459            other_value = getattr(other, field_name)
460
461            # if the values are both serializable dataclasses, recurse
462            if isinstance(self_value, SerializableDataclass) and isinstance(
463                other_value, SerializableDataclass
464            ):
465                nested_diff: dict = self_value.diff(
466                    other_value, of_serialized=of_serialized
467                )
468                if nested_diff:
469                    diff_result[field_name] = nested_diff
470            # only support serializable dataclasses
471            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
472                other_value
473            ):
474                raise ValueError("Non-serializable dataclass is not supported")
475            else:
476                # get the values of either the serialized or the actual values
477                self_value_s = ser_self[field_name] if of_serialized else self_value
478                other_value_s = ser_other[field_name] if of_serialized else other_value
479                # compare the values
480                if not array_safe_eq(self_value_s, other_value_s):
481                    diff_result[field_name] = {"self": self_value, "other": other_value}
482
483        # return the diff result
484        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • other : SerializableDataclass other instance to compare against
  • of_serialized : bool if true, compare serialized data and not raw values (defaults to False)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
486    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
487        """update the instance from a nested dict, useful for configuration from command line args
488
489        # Parameters:
490            - `nested_dict : dict[str, Any]`
491                nested dict to update the instance with
492        """
493        for field in dataclasses.fields(self):  # type: ignore[arg-type]
494            field_name: str = field.name
495            self_value = getattr(self, field_name)
496
497            if field_name in nested_dict:
498                if isinstance(self_value, SerializableDataclass):
499                    self_value.update_from_nested_dict(nested_dict[field_name])
500                else:
501                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

  • nested_dict : dict[str, Any] nested dict to update the instance with
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[~T]) -> dict[str, typing.Any]:
514@functools.lru_cache(typed=True)
515def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
516    "cached typing.get_type_hints for a class"
517    return typing.get_type_hints(cls)

cached typing.get_type_hints for a class

def get_cls_type_hints(cls: Type[~T]) -> dict[str, typing.Any]:
520def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
521    "helper function to get type hints for a class"
522    cls_type_hints: dict[str, Any]
523    try:
524        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
525        if len(cls_type_hints) == 0:
526            cls_type_hints = typing.get_type_hints(cls)
527
528        if len(cls_type_hints) == 0:
529            raise ValueError(f"empty type hints for {cls.__name__ = }")
530    except (TypeError, NameError, ValueError) as e:
531        raise TypeError(
532            f"Cannot get type hints for {cls = }\n"
533            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
534            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
535            + f"  {e = }"
536        ) from e
537
538    return cls_type_hints

helper function to get type hints for a class

class KWOnlyError(builtins.NotImplementedError):
541class KWOnlyError(NotImplementedError):
542    "kw-only dataclasses are not supported in python <3.9"
543
544    pass

kw-only dataclasses are not supported in python <3.10

Inherited Members
builtins.NotImplementedError
NotImplementedError
builtins.BaseException
with_traceback
add_note
args
class FieldError(builtins.ValueError):
547class FieldError(ValueError):
548    "base class for field errors"
549
550    pass

base class for field errors

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class NotSerializableFieldException(FieldError):
553class NotSerializableFieldException(FieldError):
554    "field is not a `SerializableField`"
555
556    pass

field is not a SerializableField

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldSerializationError(FieldError):
559class FieldSerializationError(FieldError):
560    "error while serializing a field"
561
562    pass

error while serializing a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldLoadingError(FieldError):
565class FieldLoadingError(FieldError):
566    "error while loading a field"
567
568    pass

error while loading a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, **kwargs):
571@dataclass_transform(
572    field_specifiers=(serializable_field, SerializableField),
573)
574def serializable_dataclass(
575    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
576    _cls=None,  # type: ignore
577    *,
578    init: bool = True,
579    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
580    eq: bool = True,
581    order: bool = False,
582    unsafe_hash: bool = False,
583    frozen: bool = False,
584    properties_to_serialize: Optional[list[str]] = None,
585    register_handler: bool = True,
586    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
587    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
588    **kwargs,
589):
590    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`
591
592    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
593
594    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs
595
596    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
597
598    Examines PEP 526 `__annotations__` to determine fields.
599
600    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.
601
602    ```python
603    @serializable_dataclass(kw_only=True)
604    class Myclass(SerializableDataclass):
605        a: int
606        b: str
607    ```
608    ```python
609    >>> Myclass(a=1, b="q").serialize()
610    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
611    ```
612
613    # Parameters:
614     - `_cls : _type_`
615       class to decorate. don't pass this arg, just use this as a decorator
616       (defaults to `None`)
617     - `init : bool`
618       (defaults to `True`)
619     - `repr : bool`
620       (defaults to `True`)
621     - `order : bool`
622       (defaults to `False`)
623     - `unsafe_hash : bool`
624       (defaults to `False`)
625     - `frozen : bool`
626       (defaults to `False`)
627     - `properties_to_serialize : Optional[list[str]]`
628       **SerializableDataclass only:** which properties to add to the serialized data dict
629       (defaults to `None`)
630     - `register_handler : bool`
631        **SerializableDataclass only:** if true, register the class with ZANJ for loading
632       (defaults to `True`)
633     - `on_typecheck_error : ErrorMode`
634        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
635     - `on_typecheck_mismatch : ErrorMode`
636        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
637
638    # Returns:
639     - `_type_`
640       the decorated class
641
642    # Raises:
643     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
644     - `NotSerializableFieldException` : if a field is not a `SerializableField`
645     - `FieldSerializationError` : if there is an error serializing a field
646     - `AttributeError` : if a property is not found on the class
647     - `FieldLoadingError` : if there is an error loading a field
648    """
649    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
650    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
651    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
652
653    if properties_to_serialize is None:
654        _properties_to_serialize: list = list()
655    else:
656        _properties_to_serialize = properties_to_serialize
657
658    def wrap(cls: Type[T]) -> Type[T]:
659        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
660        for field_name, field_type in cls.__annotations__.items():
661            field_value = getattr(cls, field_name, None)
662            if not isinstance(field_value, SerializableField):
663                if isinstance(field_value, dataclasses.Field):
664                    # Convert the field to a SerializableField while preserving properties
665                    field_value = SerializableField.from_Field(field_value)
666                else:
667                    # Create a new SerializableField
668                    field_value = serializable_field()
669                setattr(cls, field_name, field_value)
670
671        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
672        if sys.version_info < (3, 10):
673            if "kw_only" in kwargs:
674                if kwargs["kw_only"] == True:  # noqa: E712
675                    raise KWOnlyError("kw_only is not supported in python >=3.9")
676                else:
677                    del kwargs["kw_only"]
678
679        # call `dataclasses.dataclass` to set some stuff up
680        cls = dataclasses.dataclass(  # type: ignore[call-overload]
681            cls,
682            init=init,
683            repr=repr,
684            eq=eq,
685            order=order,
686            unsafe_hash=unsafe_hash,
687            frozen=frozen,
688            **kwargs,
689        )
690
691        # copy these to the class
692        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
693
694        # ======================================================================
695        # define `serialize` func
696        # done locally since it depends on args to the decorator
697        # ======================================================================
698        def serialize(self) -> dict[str, Any]:
699            result: dict[str, Any] = {
700                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
701            }
702            # for each field in the class
703            for field in dataclasses.fields(self):  # type: ignore[arg-type]
704                # need it to be our special SerializableField
705                if not isinstance(field, SerializableField):
706                    raise NotSerializableFieldException(
707                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
708                        f"but a {type(field)} "
709                        "this state should be inaccessible, please report this bug!"
710                    )
711
712                # try to save it
713                if field.serialize:
714                    try:
715                        # get the val
716                        value = getattr(self, field.name)
717                        # if it is a serializable dataclass, serialize it
718                        if isinstance(value, SerializableDataclass):
719                            value = value.serialize()
720                        # if the value has a serialization function, use that
721                        if hasattr(value, "serialize") and callable(value.serialize):
722                            value = value.serialize()
723                        # if the field has a serialization function, use that
724                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
725                        elif field.serialization_fn:
726                            value = field.serialization_fn(value)
727
728                        # store the value in the result
729                        result[field.name] = value
730                    except Exception as e:
731                        raise FieldSerializationError(
732                            "\n".join(
733                                [
734                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
735                                    f"{field = }",
736                                    f"{value = }",
737                                    f"{self = }",
738                                ]
739                            )
740                        ) from e
741
742            # store each property if we can get it
743            for prop in self._properties_to_serialize:
744                if hasattr(cls, prop):
745                    value = getattr(self, prop)
746                    result[prop] = value
747                else:
748                    raise AttributeError(
749                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
750                        + f"but it is in {self._properties_to_serialize = }"
751                        + f"\n{self = }"
752                    )
753
754            return result
755
756        # ======================================================================
757        # define `load` func
758        # done locally since it depends on args to the decorator
759        # ======================================================================
760        # mypy thinks this isnt a classmethod
761        @classmethod  # type: ignore[misc]
762        def load(cls, data: dict[str, Any] | T) -> Type[T]:
763            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
764            if isinstance(data, cls):
765                return data
766
767            assert isinstance(
768                data, typing.Mapping
769            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
770
771            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
772
773            # initialize dict for keeping what we will pass to the constructor
774            ctor_kwargs: dict[str, Any] = dict()
775
776            # iterate over the fields of the class
777            for field in dataclasses.fields(cls):
778                # check if the field is a SerializableField
779                assert isinstance(
780                    field, SerializableField
781                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
782
783                # check if the field is in the data and if it should be initialized
784                if (field.name in data) and field.init:
785                    # get the value, we will be processing it
786                    value: Any = data[field.name]
787
788                    # get the type hint for the field
789                    field_type_hint: Any = cls_type_hints.get(field.name, None)
790
791                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
792                    if field.deserialize_fn:
793                        # if it has a deserialization function, use that
794                        value = field.deserialize_fn(value)
795                    elif field.loading_fn:
796                        # if it has a loading function, use that
797                        value = field.loading_fn(data)
798                    elif (
799                        field_type_hint is not None
800                        and hasattr(field_type_hint, "load")
801                        and callable(field_type_hint.load)
802                    ):
803                        # if no loading function but has a type hint with a load method, use that
804                        if isinstance(value, dict):
805                            value = field_type_hint.load(value)
806                        else:
807                            raise FieldLoadingError(
808                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
809                            )
810                    else:
811                        # assume no loading needs to happen, keep `value` as-is
812                        pass
813
814                    # store the value in the constructor kwargs
815                    ctor_kwargs[field.name] = value
816
817            # create a new instance of the class with the constructor kwargs
818            output: cls = cls(**ctor_kwargs)
819
820            # validate the types of the fields if needed
821            if on_typecheck_mismatch != ErrorMode.IGNORE:
822                output.validate_fields_types(on_typecheck_error=on_typecheck_error)
823
824            # return the new instance
825            return output
826
827        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
828        # type is `Callable[[T], dict]`
829        cls.serialize = serialize  # type: ignore[attr-defined]
830        # type is `Callable[[dict], T]`
831        cls.load = load  # type: ignore[attr-defined]
832        # type is `Callable[[T, ErrorMode], bool]`
833        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]
834
835        # type is `Callable[[T, T], bool]`
836        cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
837
838        # Register the class with ZANJ
839        if register_handler:
840            zanj_register_loader_serializable_dataclass(cls)
841
842        return cls
843
844    if _cls is None:
845        return wrap
846    else:
847        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : Type[T] | None class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool (defaults to True)
  • repr : bool (defaults to True)
  • eq : bool (defaults to True)
  • order : bool (defaults to False)
  • unsafe_hash : bool (defaults to False)
  • frozen : bool (defaults to False)
  • properties_to_serialize : Optional[list[str]] SerializableDataclass only: which properties to add to the serialized data dict (defaults to None)
  • register_handler : bool SerializableDataclass only: if true, register the class with ZANJ for loading (defaults to True)
  • on_typecheck_error : ErrorMode SerializableDataclass only: what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false
  • on_typecheck_mismatch : ErrorMode SerializableDataclass only: what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True

Returns:

  • Type[T] the decorated class (when used with arguments, i.e. without _cls, a decorator that returns the decorated class)

Raises: