docs for muutils v0.9.1
View Source on GitHub

muutils.json_serialize.serializable_dataclass

save and load objects to and from json or compatible formats in a recoverable way

d = dataclasses.asdict(my_obj) will give you a dict, but if some fields are not json-serializable, you will get an error when you call json.dumps(d). This module provides a way around that.

Instead, you define your class:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: type(x).__name__,
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True

  1"""save and load objects to and from json or compatible formats in a recoverable way
  2
  3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
  4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
  5
  6Instead, you define your class:
  7
  8```python
  9@serializable_dataclass
 10class MyClass(SerializableDataclass):
 11    a: int
 12    b: str
 13```
 14
 15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
 16
 17    >>> my_obj = MyClass(a=1, b="q")
 18    >>> s = json.dumps(my_obj.serialize())
 19    >>> s
 20    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
 21    >>> read_obj = MyClass.load(json.loads(s))
 22    >>> read_obj == my_obj
 23    True
 24
 25This isn't too impressive on its own, but it gets more useful when you have nested classes,
 26or fields that are not json-serializable by default:
 27
 28```python
 29@serializable_dataclass
 30class NestedClass(SerializableDataclass):
 31    x: str
 32    y: MyClass
 33    act_fun: torch.nn.Module = serializable_field(
 34        default=torch.nn.ReLU(),
 35        serialization_fn=lambda x: type(x).__name__,
 36        deserialize_fn=lambda x: getattr(torch.nn, x)(),
 37    )
 38```
 39
 40which gives us:
 41
 42    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
 43    >>> s = json.dumps(nc.serialize())
 44    >>> s
 45    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
 46    >>> read_nc = NestedClass.load(json.loads(s))
 47    >>> read_nc == nc
 48    True
 49
 50"""
 51
 52from __future__ import annotations
 53
 54import abc
 55import dataclasses
 56import functools
 57import json
 58import sys
 59import typing
 60import warnings
 61from typing import Any, Optional, Type, TypeVar, overload, TYPE_CHECKING
 62
 63from muutils.errormode import ErrorMode
 64from muutils.validate_type import validate_type
 65from muutils.json_serialize.serializable_field import (
 66    SerializableField,
 67    serializable_field,
 68)
 69from muutils.json_serialize.types import _FORMAT_KEY
 70from muutils.json_serialize.util import (
 71    JSONdict,
 72    array_safe_eq,
 73    dc_eq,
 74)
 75
 76# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
 77
 78# For type checkers: always use typing_extensions which they can resolve
 79# At runtime: use stdlib if available (3.11+), else typing_extensions, else mock
 80if TYPE_CHECKING:
 81    from typing_extensions import dataclass_transform, Self
 82else:
 83    if sys.version_info >= (3, 11):
 84        from typing import dataclass_transform, Self
 85    else:
 86        try:
 87            from typing_extensions import dataclass_transform, Self
 88        except Exception:
 89            from muutils.json_serialize.dataclass_transform_mock import (
 90                dataclass_transform,
 91            )
 92
 93            Self = TypeVar("Self")
 94
 95T_SerializeableDataclass = TypeVar(
 96    "T_SerializeableDataclass", bound="SerializableDataclass"
 97)
 98
 99
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"


class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `zanj_register_loader_serializable_dataclass` will not work"
110
111
_zanj_loading_needs_import: bool = True
"flag to keep track of whether ZANJ still needs to be imported -- set to `False` once `zanj_register_loader_serializable_dataclass` has imported it successfully"


def zanj_register_loader_serializable_dataclass(
    cls: typing.Type[T_SerializeableDataclass],
):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T_SerializeableDataclass]`
       the decorated class to register a loader for

    # Returns:
     - the registered `LoaderHandler`, or `None` if ZANJ is not installed

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    # the import is attempted on every call: the imported names are local to this
    # function, so skipping the import based on `_zanj_loading_needs_import` would leave
    # them unbound. after the first success the module lives in `sys.modules`, so
    # repeated imports are cheap
    try:
        from zanj.loading import (  # type: ignore[import]  # pyright: ignore[reportMissingImports]
            LoaderHandler,  # pyright: ignore[reportUnknownVariableType]
            register_loader_handler,  # pyright: ignore[reportUnknownVariableType]
        )
    except ImportError:
        # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter
        # warnings.warn(
        #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
        #     ZanjMissingWarning,
        # )
        return None

    _zanj_loading_needs_import = False

    # must match the `_FORMAT_KEY` value written by the generated `serialize` method
    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            # guard against a non-string format value, which has no `.startswith`
            and isinstance(json_item.get(_FORMAT_KEY), str)
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
158
159
# default handling when the type-checking machinery itself throws: `EXCEPT`
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
# default handling when a field's value does not match its type hint: `WARN`
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
162
163
class FieldIsNotInitOrSerializeWarning(UserWarning):
    "emitted when a field with `assert_type` enabled is skipped by type validation because it is not `init` or not `serialize`"
166
167
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is called by `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`

    # Raises:
     (all via `on_typecheck_error.process`, i.e. only when it is `except`)
     - `TypeError` : if the type hints for the class, or for this field, cannot be obtained
     - `ValueError` : if the type validation itself throws
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve the field object from its name if needed
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(_field, SerializableField), (
        f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"
    )

    # get field type hints. `get_cls_type_hints` raises `TypeError` when the class hints
    # cannot be resolved at all, and the lookup raises `KeyError` when this field has no
    # hint. both go through `on_typecheck_error` so that `warn`/`ignore` are respected.
    # the hints are kept in a local so the error message does not re-call (and possibly
    # re-raise from) `get_cls_type_hints`
    cls_type_hints: dict[str, Any] | None = None
    try:
        cls_type_hints = get_cls_type_hints(self.__class__)
        field_type_hint: Any = cls_type_hints[_field.name]
    except (KeyError, TypeError) as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"get_cls_type_hints(self.__class__) = {cls_type_hints!r}\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        if _field.custom_typecheck_fn is None:
            # validate the type with the default type validator
            type_is_valid = validate_type(value, field_type_hint)
        else:
            # NOTE: the custom validator is given the type hint, not the value
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
258
259
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """check the type of every field of a `SerializableDataclass` instance

    each field is checked with `self.validate_field_type` (backed by
    `SerializableDataclass__validate_field_type`). a field whose check raises is recorded as
    `False`, and after all fields are checked the collected exceptions are reported together
    through `on_typecheck_error`

    # Returns:
     - `dict[str, bool]`
       mapping from each field name to whether that field's type is valid
    """
    error_mode: ErrorMode = ErrorMode.from_any(on_typecheck_error)

    fields_seq: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    validity: dict[str, bool] = {}
    caught: dict[str, Exception] = {}

    for fld in fields_seq:
        try:
            is_valid = self.validate_field_type(fld, error_mode)
        except Exception as err:
            is_valid = False
            caught[fld.name] = err
        validity[fld.name] = is_valid

    if caught:
        all_names = [fld.name for fld in fields_seq]
        details = "\n\t".join(f"{name}:\t{err}" for name, err in caught.items())
        error_mode.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {all_names}"
            + "\n\t"
            + details,
            except_cls=ValueError,
            # ExceptionGroup is unavailable before python 3.11, so chain from the first collected exception
            except_from=next(iter(caught.values())),
        )

    return validity
295
296
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """return whether every field of a `SerializableDataclass` passes type validation

    thin wrapper over `SerializableDataclass__validate_fields_types__dict`, which checks each
    field via `SerializableDataclass__validate_field_type`
    """
    per_field = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())
307
308
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: type(x).__name__,
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @overload
    @classmethod
    def load(cls, data: dict[str, Any]) -> Self: ...

    @overload
    @classmethod
    def load(cls, data: Self) -> Self: ...

    @classmethod
    def load(cls, data: dict[str, Any] | Self) -> Self:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        # delegate to `dc_eq` for field-wise comparison
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the instance"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values. non-dataclass fields that are
           absent from the serialized output (e.g. `serialize=False`) are skipped
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           maps each differing field name to either a nested diff (for `SerializableDataclass`
           fields) or `{"self": ..., "other": ...}` holding the raw (non-serialized) values.
           empty if there are no differences

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff. `==` may raise for some field
        # types; in that case fall through to the field-wise comparison below
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: JSONdict = self.serialize()
            ser_other: JSONdict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                if of_serialized:
                    # fields excluded from serialization have no serialized value to
                    # compare; indexing would otherwise raise `KeyError`
                    if (
                        field_name not in ser_self  # pyright: ignore[reportPossiblyUnboundVariable]
                        or field_name not in ser_other  # pyright: ignore[reportPossiblyUnboundVariable]
                    ):
                        continue
                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
                else:
                    self_value_s = self_value
                    other_value_s = other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        keys of `nested_dict` that are not field names are ignored. for a field currently
        holding a `SerializableDataclass`, a mapping value is applied recursively; any other
        value (e.g. a replacement instance, or `None`) is assigned to the field directly

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            if field_name not in nested_dict:
                continue

            new_value: Any = nested_dict[field_name]
            self_value = getattr(self, field_name)

            # only recurse when there is a nested mapping to apply; a non-mapping value
            # cannot be used as `nested_dict`, so it replaces the field instead
            if isinstance(self_value, SerializableDataclass) and isinstance(
                new_value, typing.Mapping
            ):
                self_value.update_from_nested_dict(new_value)  # type: ignore[arg-type]
            else:
                setattr(self, field_name, new_value)

    def __copy__(self) -> "SerializableDataclass":
        "copy by serializing and loading the instance to json -- note this produces a deep copy, not a shallow one"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json (`memo` is unused)"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
518
519
# cache this so we don't have to keep getting it
# classes are normally hashable (by identity), so they are valid cache keys. note that the
# cached hints are not refreshed if a class's annotations change after the first lookup
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)


def get_cls_type_hints(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries `get_cls_type_hints_cached` first, and falls back to an uncached
    `typing.get_type_hints` call if that returned no hints

    # Parameters:
     - `cls : Type[T_SerializeableDataclass]`
       class to get the type hints of

    # Returns:
     - `dict[str, Any]`
       mapping of attribute names to resolved type hints

    # Raises:
     - `TypeError` : if resolving the hints raises `TypeError`, `NameError`, or `ValueError`,
       or if the class has no type hints at all
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            # NOTE(review): bypasses the cache; presumably guards against a stale empty
            # cached result -- confirm
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # describe the fields defensively: `dataclasses.fields` raises `TypeError` for
        # non-dataclasses, which would otherwise replace this informative error
        fields_repr: str
        try:
            fields_repr = repr(dataclasses.fields(cls))  # type: ignore[arg-type]
        except TypeError:
            fields_repr = "<not a dataclass>"
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  dataclasses.fields(cls) = {fields_repr}\n"
            + f"  {e = }"
        ) from e

    return cls_type_hints
547
548
class KWOnlyError(NotImplementedError):
    """kw-only dataclasses are not supported in python <3.10

    `dataclasses.dataclass` only accepts the `kw_only` argument from python 3.10 onwards
    """


class FieldError(ValueError):
    "base class for field errors"


class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"


class FieldSerializationError(FieldError):
    "error while serializing a field"


class FieldLoadingError(FieldError):
    "error while loading a field"


class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint (also a `TypeError`, so it can be caught as either)"
583
584
585@dataclass_transform(
586    field_specifiers=(serializable_field, SerializableField),
587)
588def serializable_dataclass(
589    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
590    _cls=None,  # type: ignore
591    *,
592    init: bool = True,
593    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
594    eq: bool = True,
595    order: bool = False,
596    unsafe_hash: bool = False,
597    frozen: bool = False,
598    properties_to_serialize: Optional[list[str]] = None,
599    register_handler: bool = True,
600    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
601    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
602    methods_no_override: list[str] | None = None,
603    **kwargs: Any,
604) -> Any:
605    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**
606
607    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
608
609    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`
610
611    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
612
613    Examines PEP 526 `__annotations__` to determine fields.
614
615    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.
616
617    ```python
618    @serializable_dataclass(kw_only=True)
619    class Myclass(SerializableDataclass):
620        a: int
621        b: str
622    ```
623    ```python
624    >>> Myclass(a=1, b="q").serialize()
625    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
626    ```
627
628    # Parameters:
629
630    - `_cls : _type_`
631       class to decorate. don't pass this arg, just use this as a decorator
632       (defaults to `None`)
633    - `init : bool`
634       whether to add an `__init__` method
635       *(passed to dataclasses.dataclass)*
636       (defaults to `True`)
637    - `repr : bool`
638       whether to add a `__repr__` method
639       *(passed to dataclasses.dataclass)*
640       (defaults to `True`)
641    - `order : bool`
642       whether to add rich comparison methods
643       *(passed to dataclasses.dataclass)*
644       (defaults to `False`)
645    - `unsafe_hash : bool`
646       whether to add a `__hash__` method
647       *(passed to dataclasses.dataclass)*
648       (defaults to `False`)
649    - `frozen : bool`
650       whether to make the class frozen
651       *(passed to dataclasses.dataclass)*
652       (defaults to `False`)
653    - `properties_to_serialize : Optional[list[str]]`
654       which properties to add to the serialized data dict
655       **SerializableDataclass only**
656       (defaults to `None`)
657    - `register_handler : bool`
658        if true, register the class with ZANJ for loading
659        **SerializableDataclass only**
660        (defaults to `True`)
661    - `on_typecheck_error : ErrorMode`
662        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
663        **SerializableDataclass only**
664    - `on_typecheck_mismatch : ErrorMode`
665        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
666        **SerializableDataclass only**
667    - `methods_no_override : list[str]|None`
668        list of methods that should not be overridden by the decorator
669        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
670        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
671        **SerializableDataclass only**
672        (defaults to `None`)
673    - `**kwargs`
674        *(passed to dataclasses.dataclass)*
675
676    # Returns:
677
678    - `_type_`
679       the decorated class
680
681    # Raises:
682
683    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
684    - `NotSerializableFieldException` : if a field is not a `SerializableField`
685    - `FieldSerializationError` : if there is an error serializing a field
686    - `AttributeError` : if a property is not found on the class
687    - `FieldLoadingError` : if there is an error loading a field
688    """
689    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
690    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
691    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
692
693    if properties_to_serialize is None:
694        _properties_to_serialize: list = list()
695    else:
696        _properties_to_serialize = properties_to_serialize
697
698    def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]:
699        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
700        for field_name, field_type in cls.__annotations__.items():
701            field_value = getattr(cls, field_name, None)
702            if not isinstance(field_value, SerializableField):
703                if isinstance(field_value, dataclasses.Field):
704                    # Convert the field to a SerializableField while preserving properties
705                    field_value = SerializableField.from_Field(field_value)
706                else:
707                    # Create a new SerializableField
708                    field_value = serializable_field()
709                setattr(cls, field_name, field_value)
710
711        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
712        if sys.version_info < (3, 10):
713            if "kw_only" in kwargs:
714                if kwargs["kw_only"] == True:  # noqa: E712
715                    raise KWOnlyError(
716                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
717                    )
718                else:
719                    del kwargs["kw_only"]
720
721        # call `dataclasses.dataclass` to set some stuff up
722        cls = dataclasses.dataclass(  # type: ignore[call-overload]
723            cls,
724            init=init,
725            repr=repr,
726            eq=eq,
727            order=order,
728            unsafe_hash=unsafe_hash,
729            frozen=frozen,
730            **kwargs,
731        )
732
733        # copy these to the class
734        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
735
736        # ======================================================================
737        # define `serialize` func
738        # done locally since it depends on args to the decorator
739        # ======================================================================
740        def serialize(self: Any) -> dict[str, Any]:
741            result: dict[str, Any] = {
742                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
743            }
744            # for each field in the class
745            for field in dataclasses.fields(self):  # type: ignore[arg-type]
746                # need it to be our special SerializableField
747                if not isinstance(field, SerializableField):
748                    raise NotSerializableFieldException(
749                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
750                        f"but a {type(field)} "
751                        "this state should be inaccessible, please report this bug!"
752                    )
753
754                # try to save it
755                if field.serialize:
756                    value: Any = None  # init before try in case getattr raises
757                    try:
758                        # get the val
759                        value = getattr(self, field.name)
760                        # if it is a serializable dataclass, serialize it
761                        if isinstance(value, SerializableDataclass):
762                            value = value.serialize()
763                        # if the value has a serialization function, use that
764                        if hasattr(value, "serialize") and callable(value.serialize):  # pyright: ignore[reportAttributeAccessIssue]
765                            value = value.serialize()  # pyright: ignore[reportAttributeAccessIssue]
766                        # if the field has a serialization function, use that
767                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
768                        elif field.serialization_fn:
769                            value = field.serialization_fn(value)
770
771                        # store the value in the result
772                        result[field.name] = value
773                    except Exception as e:
774                        raise FieldSerializationError(
775                            "\n".join(
776                                [
777                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
778                                    f"{field = }",
779                                    f"{value or '<unavailable>' = }",
780                                    f"{self = }",
781                                ]
782                            )
783                        ) from e
784
785            # store each property if we can get it
786            for prop in self._properties_to_serialize:
787                if hasattr(cls, prop):
788                    value = getattr(self, prop)
789                    result[prop] = value
790                else:
791                    raise AttributeError(
792                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
793                        + f"but it is in {self._properties_to_serialize = }"
794                        + f"\n{self = }"
795                    )
796
797            return result
798
799        # ======================================================================
800        # define `load` func
801        # done locally since it depends on args to the decorator
802        # ======================================================================
803        # mypy thinks this isnt a classmethod
804        @classmethod  # type: ignore[misc]
805        def load(
806            cls: type[T_SerializeableDataclass],
807            data: dict[str, Any] | T_SerializeableDataclass,
808        ) -> T_SerializeableDataclass:
809            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
810            if isinstance(data, cls):
811                return data
812
813            assert isinstance(data, typing.Mapping), (
814                f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
815            )
816
817            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
818
819            # initialize dict for keeping what we will pass to the constructor
820            ctor_kwargs: dict[str, Any] = dict()
821
822            # iterate over the fields of the class
823            # mypy doesn't recognize @dataclass_transform for dataclasses.fields()
824            # https://github.com/python/mypy/issues/16241
825            for field in dataclasses.fields(cls):  # type: ignore[arg-type]
826                # check if the field is a SerializableField
827                assert isinstance(field, SerializableField), (
828                    f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
829                )
830
831                # check if the field is in the data and if it should be initialized
832                if (field.name in data) and field.init:
833                    # get the value, we will be processing it
834                    value: Any = data[field.name]
835
836                    # get the type hint for the field
837                    field_type_hint: Any = cls_type_hints.get(field.name, None)
838
839                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
840                    if field.deserialize_fn:
841                        # if it has a deserialization function, use that
842                        value = field.deserialize_fn(value)
843                    elif field.loading_fn:
844                        # if it has a loading function, use that
845                        value = field.loading_fn(data)
846                    elif (
847                        field_type_hint is not None
848                        and hasattr(field_type_hint, "load")
849                        and callable(field_type_hint.load)
850                    ):
851                        # if no loading function but has a type hint with a load method, use that
852                        if isinstance(value, dict):
853                            value = field_type_hint.load(value)
854                        else:
855                            raise FieldLoadingError(
856                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
857                            )
858                    else:
859                        # assume no loading needs to happen, keep `value` as-is
860                        pass
861
862                    # store the value in the constructor kwargs
863                    ctor_kwargs[field.name] = value
864
865            # create a new instance of the class with the constructor kwargs
866            output: T_SerializeableDataclass = cls(**ctor_kwargs)
867
868            # validate the types of the fields if needed
869            if on_typecheck_mismatch != ErrorMode.IGNORE:
870                fields_valid: dict[str, bool] = (
871                    SerializableDataclass__validate_fields_types__dict(
872                        output,
873                        on_typecheck_error=on_typecheck_error,
874                    )
875                )
876
877                # if there are any fields that are not valid, raise an error
878                if not all(fields_valid.values()):
879                    msg: str = (
880                        f"Type mismatch in fields of {cls.__name__}:\n"
881                        + "\n".join(
882                            [
883                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
884                                for k, v in fields_valid.items()
885                                if not v
886                            ]
887                        )
888                    )
889
890                    on_typecheck_mismatch.process(
891                        msg, except_cls=FieldTypeMismatchError
892                    )
893
894            # return the new instance
895            return output
896
897        _methods_no_override: set[str]
898        if methods_no_override is None:
899            _methods_no_override = set()
900        else:
901            _methods_no_override = set(methods_no_override)
902
903        if _methods_no_override - {
904            "__eq__",
905            "serialize",
906            "load",
907            "validate_fields_types",
908        }:
909            warnings.warn(
910                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
911            )
912
913        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
914        if "serialize" not in _methods_no_override:
915            # type is `Callable[[T], dict]`
916            cls.serialize = serialize  # type: ignore[attr-defined, method-assign]
917        if "load" not in _methods_no_override:
918            # type is `Callable[[dict], T]`
919            cls.load = load  # type: ignore[attr-defined, method-assign, assignment]
920
921        if "validate_field_type" not in _methods_no_override:
922            # type is `Callable[[T, ErrorMode], bool]`
923            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined, method-assign]
924
925        if "__eq__" not in _methods_no_override:
926            # type is `Callable[[T, T], bool]`
927            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
928
929        # Register the class with ZANJ
930        if register_handler:
931            zanj_register_loader_serializable_dataclass(cls)
932
933        return cls
934
935    if _cls is None:
936        return wrap
937    else:
938        return wrap(_cls)

class CantGetTypeHintsWarning(builtins.UserWarning):
101class CantGetTypeHintsWarning(UserWarning):
102    "special warning for when we can't get type hints"
103
104    pass

special warning for when we can't get type hints

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
class ZanjMissingWarning(builtins.UserWarning):
107class ZanjMissingWarning(UserWarning):
108    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"
109
110    pass

special warning for when ZANJ is missing -- register_loader_serializable_dataclass will not work

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def zanj_register_loader_serializable_dataclass(cls: Type[~T_SerializeableDataclass]):
117def zanj_register_loader_serializable_dataclass(
118    cls: typing.Type[T_SerializeableDataclass],
119):
120    """Register a serializable dataclass with the ZANJ import
121
122    this allows `ZANJ().read()` to load the class and not just return plain dicts
123
124
125    # TODO: there is some duplication here with register_loader_handler
126    """
127    global _zanj_loading_needs_import
128
129    if _zanj_loading_needs_import:
130        try:
131            from zanj.loading import (  # type: ignore[import]  # pyright: ignore[reportMissingImports]
132                LoaderHandler,  # pyright: ignore[reportUnknownVariableType]
133                register_loader_handler,  # pyright: ignore[reportUnknownVariableType]
134            )
135        except ImportError:
136            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter
137            # warnings.warn(
138            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
139            #     ZanjMissingWarning,
140            # )
141            return
142
143    _format: str = f"{cls.__name__}(SerializableDataclass)"
144    lh: LoaderHandler = LoaderHandler(  # pyright: ignore[reportPossiblyUnboundVariable]
145        check=lambda json_item, path=None, z=None: (  # type: ignore
146            isinstance(json_item, dict)
147            and _FORMAT_KEY in json_item
148            and json_item[_FORMAT_KEY].startswith(_format)
149        ),
150        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
151        uid=_format,
152        source_pckg=cls.__module__,
153        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
154    )
155
156    register_loader_handler(lh)  # pyright: ignore[reportPossiblyUnboundVariable]
157
158    return lh

Register a serializable dataclass with ZANJ as a loader handler (via ZANJ's register_loader_handler)

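This allows ZANJ().read() to load instances of the class instead of just returning plain dicts. Returns the registered LoaderHandler, or None if ZANJ could not be imported (in which case registration is silently skipped).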

TODO: there is some duplication here with register_loader_handler

class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):
165class FieldIsNotInitOrSerializeWarning(UserWarning):
166    pass

Warning emitted when a field is not init or not serialize, and therefore will not be type checked.

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
169def SerializableDataclass__validate_field_type(
170    self: SerializableDataclass,
171    field: SerializableField | str,
172    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
173) -> bool:
174    """given a dataclass, check the field matches the type hint
175
176    this function is written to `SerializableDataclass.validate_field_type`
177
178    # Parameters:
179     - `self : SerializableDataclass`
180       `SerializableDataclass` instance
181     - `field : SerializableField | str`
182        field to validate, will get from `self.__dataclass_fields__` if an `str`
183     - `on_typecheck_error : ErrorMode`
184        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
185       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)
186
187    # Returns:
188     - `bool`
189        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
190    """
191    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
192
193    # get field
194    _field: SerializableField
195    if isinstance(field, str):
196        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
197    else:
198        _field = field
199
200    # do nothing case
201    if not _field.assert_type:
202        return True
203
204    # if field is not `init` or not `serialize`, skip but warn
205    # TODO: how to handle fields which are not `init` or `serialize`?
206    if not _field.init or not _field.serialize:
207        warnings.warn(
208            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
209            FieldIsNotInitOrSerializeWarning,
210        )
211        return True
212
213    assert isinstance(_field, SerializableField), (
214        f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"
215    )
216
217    # get field type hints
218    try:
219        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
220    except KeyError as e:
221        on_typecheck_error.process(
222            (
223                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
224                + f"{get_cls_type_hints(self.__class__) = }\n"
225                + f"Python version is {sys.version_info = }. You can:\n"
226                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
227                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
228                + "  - use python 3.9.x or higher\n"
229                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
230            ),
231            except_cls=TypeError,
232            except_from=e,
233        )
234        return False
235
236    # get the value
237    value: Any = getattr(self, _field.name)
238
239    # validate the type
240    try:
241        type_is_valid: bool
242        # validate the type with the default type validator
243        if _field.custom_typecheck_fn is None:
244            type_is_valid = validate_type(value, field_type_hint)
245        # validate the type with a custom type validator
246        else:
247            type_is_valid = _field.custom_typecheck_fn(field_type_hint)
248
249        return type_is_valid
250
251    except Exception as e:
252        on_typecheck_error.process(
253            "exception while validating type: "
254            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
255            except_cls=ValueError,
256            except_from=e,
257        )
258        return False

given a dataclass, check the field matches the type hint

This function backs the SerializableDataclass.validate_field_type method.

Parameters:

  • self : SerializableDataclass — the SerializableDataclass instance
  • field : SerializableField | str — the field to validate; if a str, it is looked up in self.__dataclass_fields__
  • on_typecheck_error : ErrorMode — what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, the function returns False (defaults to _DEFAULT_ON_TYPECHECK_ERROR)

Returns:

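  • bool — True if the field type is correct; False if the field type is incorrect, or if an exception is thrown and on_typecheck_error is ignore. The check is skipped and True is returned if the field has assert_type disabled, or if it is not init or not serialize (the latter also emits a FieldIsNotInitOrSerializeWarning).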
def SerializableDataclass__validate_fields_types__dict( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> dict[str, bool]:
261def SerializableDataclass__validate_fields_types__dict(
262    self: SerializableDataclass,
263    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
264) -> dict[str, bool]:
265    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
266
267    returns a dict of field names to bools, where the bool is if the field type is valid
268    """
269    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
270
271    # if except, bundle the exceptions
272    results: dict[str, bool] = dict()
273    exceptions: dict[str, Exception] = dict()
274
275    # for each field in the class
276    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
277    for field in cls_fields:
278        try:
279            results[field.name] = self.validate_field_type(field, on_typecheck_error)
280        except Exception as e:
281            results[field.name] = False
282            exceptions[field.name] = e
283
284    # figure out what to do with the exceptions
285    if len(exceptions) > 0:
286        on_typecheck_error.process(
287            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
288            + "\n\t"
289            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
290            except_cls=ValueError,
291            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
292            except_from=list(exceptions.values())[0],
293        )
294
295    return results

Validate the types of all the fields on a SerializableDataclass by calling SerializableDataclass__validate_field_type (through self.validate_field_type) for each field.

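Returns a dict mapping each field name to a bool indicating whether that field's type is valid. A field whose validation raised an exception is marked False, and all such exceptions are reported together through on_typecheck_error.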

def SerializableDataclass__validate_fields_types( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
298def SerializableDataclass__validate_fields_types(
299    self: SerializableDataclass,
300    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
301) -> bool:
302    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
303    return all(
304        SerializableDataclass__validate_fields_types__dict(
305            self, on_typecheck_error=on_typecheck_error
306        ).values()
307    )

Validate the types of all the fields on a SerializableDataclass; returns True only if every field's type is valid. Calls SerializableDataclass__validate_field_type for each field.

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
310@dataclass_transform(
311    field_specifiers=(serializable_field, SerializableField),
312)
313class SerializableDataclass(abc.ABC):
314    """Base class for serializable dataclasses
315
316    only for linting and type checking, still need to call `serializable_dataclass` decorator
317
318    # Usage:
319
320    ```python
321    @serializable_dataclass
322    class MyClass(SerializableDataclass):
323        a: int
324        b: str
325    ```
326
327    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
328
329        >>> my_obj = MyClass(a=1, b="q")
330        >>> s = json.dumps(my_obj.serialize())
331        >>> s
332        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
333        >>> read_obj = MyClass.load(json.loads(s))
334        >>> read_obj == my_obj
335        True
336
337    This isn't too impressive on its own, but it gets more useful when you have nested classses,
338    or fields that are not json-serializable by default:
339
340    ```python
341    @serializable_dataclass
342    class NestedClass(SerializableDataclass):
343        x: str
344        y: MyClass
345        act_fun: torch.nn.Module = serializable_field(
346            default=torch.nn.ReLU(),
347            serialization_fn=lambda x: str(x),
348            deserialize_fn=lambda x: getattr(torch.nn, x)(),
349        )
350    ```
351
352    which gives us:
353
354        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
355        >>> s = json.dumps(nc.serialize())
356        >>> s
357        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
358        >>> read_nc = NestedClass.load(json.loads(s))
359        >>> read_nc == nc
360        True
361    """
362
363    def serialize(self) -> dict[str, Any]:
364        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
365        raise NotImplementedError(
366            f"decorate {self.__class__ = } with `@serializable_dataclass`"
367        )
368
369    @overload
370    @classmethod
371    def load(cls, data: dict[str, Any]) -> Self: ...
372
373    @overload
374    @classmethod
375    def load(cls, data: Self) -> Self: ...
376
377    @classmethod
378    def load(cls, data: dict[str, Any] | Self) -> Self:
379        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
380        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
381
382    def validate_fields_types(
383        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
384    ) -> bool:
385        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
386        return SerializableDataclass__validate_fields_types(
387            self, on_typecheck_error=on_typecheck_error
388        )
389
390    def validate_field_type(
391        self,
392        field: "SerializableField|str",
393        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
394    ) -> bool:
395        """given a dataclass, check the field matches the type hint"""
396        return SerializableDataclass__validate_field_type(
397            self, field, on_typecheck_error=on_typecheck_error
398        )
399
400    def __eq__(self, other: Any) -> bool:
401        return dc_eq(self, other)
402
403    def __hash__(self) -> int:
404        "hashes the json-serialized representation of the class"
405        return hash(json.dumps(self.serialize()))
406
407    def diff(
408        self, other: "SerializableDataclass", of_serialized: bool = False
409    ) -> dict[str, Any]:
410        """get a rich and recursive diff between two instances of a serializable dataclass
411
412        ```python
413        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
414        {'b': {'self': 2, 'other': 3}}
415        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
416        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
417        ```
418
419        # Parameters:
420         - `other : SerializableDataclass`
421           other instance to compare against
422         - `of_serialized : bool`
423           if true, compare serialized data and not raw values
424           (defaults to `False`)
425
426        # Returns:
427         - `dict[str, Any]`
428
429
430        # Raises:
431         - `ValueError` : if the instances are not of the same type
432         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
433        """
434        # match types
435        if type(self) is not type(other):
436            raise ValueError(
437                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
438            )
439
440        # initialize the diff result
441        diff_result: dict = {}
442
443        # if they are the same, return the empty diff
444        try:
445            if self == other:
446                return diff_result
447        except Exception:
448            pass
449
450        # if we are working with serialized data, serialize the instances
451        if of_serialized:
452            ser_self: JSONdict = self.serialize()
453            ser_other: JSONdict = other.serialize()
454
455        # for each field in the class
456        for field in dataclasses.fields(self):  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
457            # skip fields that are not for comparison
458            if not field.compare:
459                continue
460
461            # get values
462            field_name: str = field.name
463            self_value = getattr(self, field_name)
464            other_value = getattr(other, field_name)
465
466            # if the values are both serializable dataclasses, recurse
467            if isinstance(self_value, SerializableDataclass) and isinstance(
468                other_value, SerializableDataclass
469            ):
470                nested_diff: dict = self_value.diff(
471                    other_value, of_serialized=of_serialized
472                )
473                if nested_diff:
474                    diff_result[field_name] = nested_diff
475            # only support serializable dataclasses
476            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
477                other_value
478            ):
479                raise ValueError("Non-serializable dataclass is not supported")
480            else:
481                # get the values of either the serialized or the actual values
482                if of_serialized:
483                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
484                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
485                else:
486                    self_value_s = self_value
487                    other_value_s = other_value
488                # compare the values
489                if not array_safe_eq(self_value_s, other_value_s):
490                    diff_result[field_name] = {"self": self_value, "other": other_value}
491
492        # return the diff result
493        return diff_result
494
495    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
496        """update the instance from a nested dict, useful for configuration from command line args
497
498        # Parameters:
499            - `nested_dict : dict[str, Any]`
500                nested dict to update the instance with
501        """
502        for field in dataclasses.fields(self):  # type: ignore[arg-type]
503            field_name: str = field.name
504            self_value = getattr(self, field_name)
505
506            if field_name in nested_dict:
507                if isinstance(self_value, SerializableDataclass):
508                    self_value.update_from_nested_dict(nested_dict[field_name])
509                else:
510                    setattr(self, field_name, nested_dict[field_name])
511
512    def __copy__(self) -> "SerializableDataclass":
513        "deep copy by serializing and loading the instance to json"
514        return self.__class__.load(json.loads(json.dumps(self.serialize())))
515
516    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
517        "deep copy by serializing and loading the instance to json"
518        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

This base class exists only for linting and type checking; you still need to apply the serializable_dataclass decorator.

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]:
363    def serialize(self) -> dict[str, Any]:
364        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
365        raise NotImplementedError(
366            f"decorate {self.__class__ = } with `@serializable_dataclass`"
367        )

Returns the instance as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], Self]) -> Self:
377    @classmethod
378    def load(cls, data: dict[str, Any] | Self) -> Self:
379        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
380        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

Takes an appropriately structured dict (or an existing instance of the class) and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
382    def validate_fields_types(
383        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
384    ) -> bool:
385        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
386        return SerializableDataclass__validate_fields_types(
387            self, on_typecheck_error=on_typecheck_error
388        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
390    def validate_field_type(
391        self,
392        field: "SerializableField|str",
393        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
394    ) -> bool:
395        """given a dataclass, check the field matches the type hint"""
396        return SerializableDataclass__validate_field_type(
397            self, field, on_typecheck_error=on_typecheck_error
398        )

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
407    def diff(
408        self, other: "SerializableDataclass", of_serialized: bool = False
409    ) -> dict[str, Any]:
410        """get a rich and recursive diff between two instances of a serializable dataclass
411
412        ```python
413        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
414        {'b': {'self': 2, 'other': 3}}
415        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
416        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
417        ```
418
419        # Parameters:
420         - `other : SerializableDataclass`
421           other instance to compare against
422         - `of_serialized : bool`
423           if true, compare serialized data and not raw values
424           (defaults to `False`)
425
426        # Returns:
427         - `dict[str, Any]`
428
429
430        # Raises:
431         - `ValueError` : if the instances are not of the same type
432         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
433        """
434        # match types
435        if type(self) is not type(other):
436            raise ValueError(
437                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
438            )
439
440        # initialize the diff result
441        diff_result: dict = {}
442
443        # if they are the same, return the empty diff
444        try:
445            if self == other:
446                return diff_result
447        except Exception:
448            pass
449
450        # if we are working with serialized data, serialize the instances
451        if of_serialized:
452            ser_self: JSONdict = self.serialize()
453            ser_other: JSONdict = other.serialize()
454
455        # for each field in the class
456        for field in dataclasses.fields(self):  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
457            # skip fields that are not for comparison
458            if not field.compare:
459                continue
460
461            # get values
462            field_name: str = field.name
463            self_value = getattr(self, field_name)
464            other_value = getattr(other, field_name)
465
466            # if the values are both serializable dataclasses, recurse
467            if isinstance(self_value, SerializableDataclass) and isinstance(
468                other_value, SerializableDataclass
469            ):
470                nested_diff: dict = self_value.diff(
471                    other_value, of_serialized=of_serialized
472                )
473                if nested_diff:
474                    diff_result[field_name] = nested_diff
475            # only support serializable dataclasses
476            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
477                other_value
478            ):
479                raise ValueError("Non-serializable dataclass is not supported")
480            else:
481                # get the values of either the serialized or the actual values
482                if of_serialized:
483                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
484                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
485                else:
486                    self_value_s = self_value
487                    other_value_s = other_value
488                # compare the values
489                if not array_safe_eq(self_value_s, other_value_s):
490                    diff_result[field_name] = {"self": self_value, "other": other_value}
491
492        # return the diff result
493        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • other : SerializableDataclass other instance to compare against
  • of_serialized : bool if true, compare serialized data and not raw values (defaults to False)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if a field's values on both instances are dataclasses.dataclass instances but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
495    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
496        """update the instance from a nested dict, useful for configuration from command line args
497
498        # Parameters:
499            - `nested_dict : dict[str, Any]`
500                nested dict to update the instance with
501        """
502        for field in dataclasses.fields(self):  # type: ignore[arg-type]
503            field_name: str = field.name
504            self_value = getattr(self, field_name)
505
506            if field_name in nested_dict:
507                if isinstance(self_value, SerializableDataclass):
508                    self_value.update_from_nested_dict(nested_dict[field_name])
509                else:
510                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

  • nested_dict : dict[str, Any] nested dict to update the instance with
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[~T_SerializeableDataclass]) -> dict[str, typing.Any]:
523@functools.lru_cache(typed=True)
524def get_cls_type_hints_cached(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
525    "cached typing.get_type_hints for a class"
526    return typing.get_type_hints(cls)

cached typing.get_type_hints for a class

def get_cls_type_hints(cls: Type[~T_SerializeableDataclass]) -> dict[str, typing.Any]:
529def get_cls_type_hints(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]:
530    "helper function to get type hints for a class"
531    cls_type_hints: dict[str, Any]
532    try:
533        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
534        if len(cls_type_hints) == 0:
535            cls_type_hints = typing.get_type_hints(cls)
536
537        if len(cls_type_hints) == 0:
538            raise ValueError(f"empty type hints for {cls.__name__ = }")
539    except (TypeError, NameError, ValueError) as e:
540        raise TypeError(
541            f"Cannot get type hints for {cls = }\n"
542            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
543            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
544            + f"  {e = }"
545        ) from e
546
547    return cls_type_hints

helper function to get type hints for a class

class KWOnlyError(builtins.NotImplementedError):
550class KWOnlyError(NotImplementedError):
551    "kw-only dataclasses are not supported in python <3.9"
552
553    pass

kw-only dataclasses are not supported in python <3.10

Inherited Members
builtins.NotImplementedError
NotImplementedError
builtins.BaseException
with_traceback
add_note
args
class FieldError(builtins.ValueError):
556class FieldError(ValueError):
557    "base class for field errors"
558
559    pass

base class for field errors

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class NotSerializableFieldException(FieldError):
562class NotSerializableFieldException(FieldError):
563    "field is not a `SerializableField`"
564
565    pass

field is not a SerializableField

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldSerializationError(FieldError):
568class FieldSerializationError(FieldError):
569    "error while serializing a field"
570
571    pass

error while serializing a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldLoadingError(FieldError):
574class FieldLoadingError(FieldError):
575    "error while loading a field"
576
577    pass

error while loading a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldTypeMismatchError(FieldError, builtins.TypeError):
580class FieldTypeMismatchError(FieldError, TypeError):
581    "error when a field type does not match the type hint"
582
583    pass

error when a field type does not match the type hint

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, methods_no_override: list[str] | None = None, **kwargs: Any) -> Any:
586@dataclass_transform(
587    field_specifiers=(serializable_field, SerializableField),
588)
589def serializable_dataclass(
590    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
591    _cls=None,  # type: ignore
592    *,
593    init: bool = True,
594    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
595    eq: bool = True,
596    order: bool = False,
597    unsafe_hash: bool = False,
598    frozen: bool = False,
599    properties_to_serialize: Optional[list[str]] = None,
600    register_handler: bool = True,
601    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
602    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
603    methods_no_override: list[str] | None = None,
604    **kwargs: Any,
605) -> Any:
606    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**
607
608    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
609
610    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`
611
612    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
613
614    Examines PEP 526 `__annotations__` to determine fields.
615
616    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.
617
618    ```python
619    @serializable_dataclass(kw_only=True)
620    class Myclass(SerializableDataclass):
621        a: int
622        b: str
623    ```
624    ```python
625    >>> Myclass(a=1, b="q").serialize()
626    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
627    ```
628
629    # Parameters:
630
631    - `_cls : _type_`
632       class to decorate. don't pass this arg, just use this as a decorator
633       (defaults to `None`)
634    - `init : bool`
635       whether to add an `__init__` method
636       *(passed to dataclasses.dataclass)*
637       (defaults to `True`)
638    - `repr : bool`
639       whether to add a `__repr__` method
640       *(passed to dataclasses.dataclass)*
641       (defaults to `True`)
642    - `order : bool`
643       whether to add rich comparison methods
644       *(passed to dataclasses.dataclass)*
645       (defaults to `False`)
646    - `unsafe_hash : bool`
647       whether to add a `__hash__` method
648       *(passed to dataclasses.dataclass)*
649       (defaults to `False`)
650    - `frozen : bool`
651       whether to make the class frozen
652       *(passed to dataclasses.dataclass)*
653       (defaults to `False`)
654    - `properties_to_serialize : Optional[list[str]]`
655       which properties to add to the serialized data dict
656       **SerializableDataclass only**
657       (defaults to `None`)
658    - `register_handler : bool`
659        if true, register the class with ZANJ for loading
660        **SerializableDataclass only**
661        (defaults to `True`)
662    - `on_typecheck_error : ErrorMode`
663        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
664        **SerializableDataclass only**
665    - `on_typecheck_mismatch : ErrorMode`
666        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
667        **SerializableDataclass only**
668    - `methods_no_override : list[str]|None`
669        list of methods that should not be overridden by the decorator
670        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
671        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
672        **SerializableDataclass only**
673        (defaults to `None`)
674    - `**kwargs`
675        *(passed to dataclasses.dataclass)*
676
677    # Returns:
678
679    - `_type_`
680       the decorated class
681
682    # Raises:
683
684    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
685    - `NotSerializableFieldException` : if a field is not a `SerializableField`
686    - `FieldSerializationError` : if there is an error serializing a field
687    - `AttributeError` : if a property is not found on the class
688    - `FieldLoadingError` : if there is an error loading a field
689    """
690    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
691    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
692    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
693
694    if properties_to_serialize is None:
695        _properties_to_serialize: list = list()
696    else:
697        _properties_to_serialize = properties_to_serialize
698
699    def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]:
700        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
701        for field_name, field_type in cls.__annotations__.items():
702            field_value = getattr(cls, field_name, None)
703            if not isinstance(field_value, SerializableField):
704                if isinstance(field_value, dataclasses.Field):
705                    # Convert the field to a SerializableField while preserving properties
706                    field_value = SerializableField.from_Field(field_value)
707                else:
708                    # Create a new SerializableField
709                    field_value = serializable_field()
710                setattr(cls, field_name, field_value)
711
712        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
713        if sys.version_info < (3, 10):
714            if "kw_only" in kwargs:
715                if kwargs["kw_only"] == True:  # noqa: E712
716                    raise KWOnlyError(
717                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
718                    )
719                else:
720                    del kwargs["kw_only"]
721
722        # call `dataclasses.dataclass` to set some stuff up
723        cls = dataclasses.dataclass(  # type: ignore[call-overload]
724            cls,
725            init=init,
726            repr=repr,
727            eq=eq,
728            order=order,
729            unsafe_hash=unsafe_hash,
730            frozen=frozen,
731            **kwargs,
732        )
733
734        # copy these to the class
735        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
736
737        # ======================================================================
738        # define `serialize` func
739        # done locally since it depends on args to the decorator
740        # ======================================================================
741        def serialize(self: Any) -> dict[str, Any]:
742            result: dict[str, Any] = {
743                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
744            }
745            # for each field in the class
746            for field in dataclasses.fields(self):  # type: ignore[arg-type]
747                # need it to be our special SerializableField
748                if not isinstance(field, SerializableField):
749                    raise NotSerializableFieldException(
750                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
751                        f"but a {type(field)} "
752                        "this state should be inaccessible, please report this bug!"
753                    )
754
755                # try to save it
756                if field.serialize:
757                    value: Any = None  # init before try in case getattr raises
758                    try:
759                        # get the val
760                        value = getattr(self, field.name)
761                        # if it is a serializable dataclass, serialize it
762                        if isinstance(value, SerializableDataclass):
763                            value = value.serialize()
764                        # if the value has a serialization function, use that
765                        if hasattr(value, "serialize") and callable(value.serialize):  # pyright: ignore[reportAttributeAccessIssue]
766                            value = value.serialize()  # pyright: ignore[reportAttributeAccessIssue]
767                        # if the field has a serialization function, use that
768                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
769                        elif field.serialization_fn:
770                            value = field.serialization_fn(value)
771
772                        # store the value in the result
773                        result[field.name] = value
774                    except Exception as e:
775                        raise FieldSerializationError(
776                            "\n".join(
777                                [
778                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
779                                    f"{field = }",
780                                    f"{value or '<unavailable>' = }",
781                                    f"{self = }",
782                                ]
783                            )
784                        ) from e
785
786            # store each property if we can get it
787            for prop in self._properties_to_serialize:
788                if hasattr(cls, prop):
789                    value = getattr(self, prop)
790                    result[prop] = value
791                else:
792                    raise AttributeError(
793                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
794                        + f"but it is in {self._properties_to_serialize = }"
795                        + f"\n{self = }"
796                    )
797
798            return result
799
800        # ======================================================================
801        # define `load` func
802        # done locally since it depends on args to the decorator
803        # ======================================================================
804        # mypy thinks this isnt a classmethod
805        @classmethod  # type: ignore[misc]
806        def load(
807            cls: type[T_SerializeableDataclass],
808            data: dict[str, Any] | T_SerializeableDataclass,
809        ) -> T_SerializeableDataclass:
810            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
811            if isinstance(data, cls):
812                return data
813
814            assert isinstance(data, typing.Mapping), (
815                f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
816            )
817
818            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
819
820            # initialize dict for keeping what we will pass to the constructor
821            ctor_kwargs: dict[str, Any] = dict()
822
823            # iterate over the fields of the class
824            # mypy doesn't recognize @dataclass_transform for dataclasses.fields()
825            # https://github.com/python/mypy/issues/16241
826            for field in dataclasses.fields(cls):  # type: ignore[arg-type]
827                # check if the field is a SerializableField
828                assert isinstance(field, SerializableField), (
829                    f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
830                )
831
832                # check if the field is in the data and if it should be initialized
833                if (field.name in data) and field.init:
834                    # get the value, we will be processing it
835                    value: Any = data[field.name]
836
837                    # get the type hint for the field
838                    field_type_hint: Any = cls_type_hints.get(field.name, None)
839
840                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
841                    if field.deserialize_fn:
842                        # if it has a deserialization function, use that
843                        value = field.deserialize_fn(value)
844                    elif field.loading_fn:
845                        # if it has a loading function, use that
846                        value = field.loading_fn(data)
847                    elif (
848                        field_type_hint is not None
849                        and hasattr(field_type_hint, "load")
850                        and callable(field_type_hint.load)
851                    ):
852                        # if no loading function but has a type hint with a load method, use that
853                        if isinstance(value, dict):
854                            value = field_type_hint.load(value)
855                        else:
856                            raise FieldLoadingError(
857                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
858                            )
859                    else:
860                        # assume no loading needs to happen, keep `value` as-is
861                        pass
862
863                    # store the value in the constructor kwargs
864                    ctor_kwargs[field.name] = value
865
866            # create a new instance of the class with the constructor kwargs
867            output: T_SerializeableDataclass = cls(**ctor_kwargs)
868
869            # validate the types of the fields if needed
870            if on_typecheck_mismatch != ErrorMode.IGNORE:
871                fields_valid: dict[str, bool] = (
872                    SerializableDataclass__validate_fields_types__dict(
873                        output,
874                        on_typecheck_error=on_typecheck_error,
875                    )
876                )
877
878                # if there are any fields that are not valid, raise an error
879                if not all(fields_valid.values()):
880                    msg: str = (
881                        f"Type mismatch in fields of {cls.__name__}:\n"
882                        + "\n".join(
883                            [
884                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
885                                for k, v in fields_valid.items()
886                                if not v
887                            ]
888                        )
889                    )
890
891                    on_typecheck_mismatch.process(
892                        msg, except_cls=FieldTypeMismatchError
893                    )
894
895            # return the new instance
896            return output
897
898        _methods_no_override: set[str]
899        if methods_no_override is None:
900            _methods_no_override = set()
901        else:
902            _methods_no_override = set(methods_no_override)
903
904        if _methods_no_override - {
905            "__eq__",
906            "serialize",
907            "load",
908            "validate_fields_types",
909        }:
910            warnings.warn(
911                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
912            )
913
914        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
915        if "serialize" not in _methods_no_override:
916            # type is `Callable[[T], dict]`
917            cls.serialize = serialize  # type: ignore[attr-defined, method-assign]
918        if "load" not in _methods_no_override:
919            # type is `Callable[[dict], T]`
920            cls.load = load  # type: ignore[attr-defined, method-assign, assignment]
921
922        if "validate_field_type" not in _methods_no_override:
923            # type is `Callable[[T, ErrorMode], bool]`
924            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined, method-assign]
925
926        if "__eq__" not in _methods_no_override:
927            # type is `Callable[[T, T], bool]`
928            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
929
930        # Register the class with ZANJ
931        if register_handler:
932            zanj_register_loader_serializable_dataclass(cls)
933
934        return cls
935
936    if _cls is None:
937        return wrap
938    else:
939        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
  • repr : bool whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
  • order : bool whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
  • unsafe_hash : bool whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
  • frozen : bool whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
  • properties_to_serialize : Optional[list[str]] which properties to add to the serialized data dict SerializableDataclass only (defaults to None)
  • register_handler : bool if true, register the class with ZANJ for loading SerializableDataclass only (defaults to True)
  • on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false SerializableDataclass only
  • on_typecheck_mismatch : ErrorMode what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True SerializableDataclass only
  • methods_no_override : list[str]|None list of methods that should not be overridden by the decorator. By default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you'd rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence SerializableDataclass only (defaults to None)
  • **kwargs (passed to dataclasses.dataclass)

Returns:

  • _type_ the decorated class

Raises: