docs for muutils v0.6.13

muutils.json_serialize

submodule for serializing things to json in a recoverable way

you can throw any object into muutils.json_serialize.json_serialize and it will return a JSONitem, meaning a bool, int, float, str, None, list of JSONitems, or a dict mapping strings to JSONitems.

The goal: if you just want to store something as relatively human-readable JSON and don't care much about recovering it, you can throw it into json_serialize and it will just work. If you want to do so in a recoverable way, check out ZANJ.

It does so by looking through DEFAULT_HANDLERS: an object is kept as-is if it's already valid JSON, then a .serialize() method is tried, and then a number of special cases are handled. You can add handlers by initializing a JsonSerializer object and passing a sequence of them as handlers_pre.

additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and using serializable_field in place of dataclasses.field when defining non-standard fields.

This module plays nicely with and is a dependency of the ZANJ library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
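The handler chain described above (valid values pass through, a `.serialize()` method is preferred, containers recurse, with a fallback for everything else) can be sketched in plain Python. This is an illustrative re-implementation, not muutils' actual code:

```python
def toy_json_serialize(obj):
    """illustrative sketch of the DEFAULT_HANDLERS chain (not muutils' code):
    base types pass through, a .serialize() method is preferred,
    dicts/lists/tuples recurse, and repr() is the fallback"""
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj  # already a valid JSONitem
    if hasattr(obj, "serialize") and callable(obj.serialize):
        # objects that know how to serialize themselves
        return toy_json_serialize(obj.serialize())
    if isinstance(obj, dict):
        # JSON object keys must be strings
        return {str(k): toy_json_serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # tuples become lists, since JSON has no tuple type
        return [toy_json_serialize(x) for x in obj]
    return repr(obj)  # fallback for unrecognized objects

print(toy_json_serialize({"a": (1, 2), "b": None}))
# {'a': [1, 2], 'b': None}
```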


 1"""submodule for serializing things to json in a recoverable way
 2
 3you can throw *any* object into `muutils.json_serialize.json_serialize`
 4and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping strings to `JSONitem`s.
 5
 6The goal: if you just want to store something as relatively human-readable JSON and don't care much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).
 7
 8It does so by looking through `DEFAULT_HANDLERS`: an object is kept as-is if it's already valid JSON, then a `.serialize()` method is tried, and then a number of special cases are handled. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them as `handlers_pre`.
 9
10additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.
11
12This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
13
14"""
15
16from __future__ import annotations
17
18from muutils.json_serialize.array import arr_metadata, load_array
19from muutils.json_serialize.json_serialize import (
20    BASE_HANDLERS,
21    JsonSerializer,
22    json_serialize,
23)
24from muutils.json_serialize.serializable_dataclass import (
25    SerializableDataclass,
26    serializable_dataclass,
27    serializable_field,
28)
29from muutils.json_serialize.util import try_catch, JSONitem, dc_eq
30
31__all__ = [
32    # submodules
33    "array",
34    "json_serialize",
35    "serializable_dataclass",
36    "serializable_field",
37    "util",
38    # imports
39    "arr_metadata",
40    "load_array",
41    "BASE_HANDLERS",
42    "JSONitem",
43    "JsonSerializer",
44    "json_serialize",
45    "try_catch",
46    "JSONitem",
47    "dc_eq",
48    "serializable_dataclass",
49    "serializable_field",
50    "SerializableDataclass",
51]

def json_serialize( obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]:
330def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
331    """serialize object to json-serializable object with default config"""
332    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)

serialize object to json-serializable object with default config

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, **kwargs):
571@dataclass_transform(
572    field_specifiers=(serializable_field, SerializableField),
573)
574def serializable_dataclass(
575    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
576    _cls=None,  # type: ignore
577    *,
578    init: bool = True,
579    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
580    eq: bool = True,
581    order: bool = False,
582    unsafe_hash: bool = False,
583    frozen: bool = False,
584    properties_to_serialize: Optional[list[str]] = None,
585    register_handler: bool = True,
586    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
587    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
588    **kwargs,
589):
590    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`
591
592    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
593
594    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs
595
596    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
597
598    Examines PEP 526 `__annotations__` to determine fields.
599
600    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.
601
602    ```python
603    @serializable_dataclass(kw_only=True)
604    class Myclass(SerializableDataclass):
605        a: int
606        b: str
607    ```
608    ```python
609    >>> Myclass(a=1, b="q").serialize()
610    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
611    ```
612
613    # Parameters:
614     - `_cls : _type_`
615       class to decorate. don't pass this arg, just use this as a decorator
616       (defaults to `None`)
617     - `init : bool`
618       (defaults to `True`)
619     - `repr : bool`
620       (defaults to `True`)
621     - `order : bool`
622       (defaults to `False`)
623     - `unsafe_hash : bool`
624       (defaults to `False`)
625     - `frozen : bool`
626       (defaults to `False`)
627     - `properties_to_serialize : Optional[list[str]]`
628       **SerializableDataclass only:** which properties to add to the serialized data dict
629       (defaults to `None`)
630     - `register_handler : bool`
631        **SerializableDataclass only:** if true, register the class with ZANJ for loading
632       (defaults to `True`)
633     - `on_typecheck_error : ErrorMode`
634        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return `False`
635     - `on_typecheck_mismatch : ErrorMode`
636        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
637
638    # Returns:
639     - `_type_`
640       the decorated class
641
642    # Raises:
643     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
644     - `NotSerializableFieldException` : if a field is not a `SerializableField`
645     - `FieldSerializationError` : if there is an error serializing a field
646     - `AttributeError` : if a property is not found on the class
647     - `FieldLoadingError` : if there is an error loading a field
648    """
649    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
650    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
651    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
652
653    if properties_to_serialize is None:
654        _properties_to_serialize: list = list()
655    else:
656        _properties_to_serialize = properties_to_serialize
657
658    def wrap(cls: Type[T]) -> Type[T]:
659        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
660        for field_name, field_type in cls.__annotations__.items():
661            field_value = getattr(cls, field_name, None)
662            if not isinstance(field_value, SerializableField):
663                if isinstance(field_value, dataclasses.Field):
664                    # Convert the field to a SerializableField while preserving properties
665                    field_value = SerializableField.from_Field(field_value)
666                else:
667                    # Create a new SerializableField
668                    field_value = serializable_field()
669                setattr(cls, field_name, field_value)
670
671        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
672        if sys.version_info < (3, 10):
673            if "kw_only" in kwargs:
674                if kwargs["kw_only"] == True:  # noqa: E712
675                    raise KWOnlyError("kw_only is not supported in python <3.10")
676                else:
677                    del kwargs["kw_only"]
678
679        # call `dataclasses.dataclass` to set some stuff up
680        cls = dataclasses.dataclass(  # type: ignore[call-overload]
681            cls,
682            init=init,
683            repr=repr,
684            eq=eq,
685            order=order,
686            unsafe_hash=unsafe_hash,
687            frozen=frozen,
688            **kwargs,
689        )
690
691        # copy these to the class
692        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
693
694        # ======================================================================
695        # define `serialize` func
696        # done locally since it depends on args to the decorator
697        # ======================================================================
698        def serialize(self) -> dict[str, Any]:
699            result: dict[str, Any] = {
700                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
701            }
702            # for each field in the class
703            for field in dataclasses.fields(self):  # type: ignore[arg-type]
704                # need it to be our special SerializableField
705                if not isinstance(field, SerializableField):
706                    raise NotSerializableFieldException(
707                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
708                        f"but a {type(field)}; "
709                        "this state should be inaccessible, please report this bug!"
710                    )
711
712                # try to save it
713                if field.serialize:
714                    try:
715                        # get the val
716                        value = getattr(self, field.name)
717                        # if it is a serializable dataclass, serialize it
718                        if isinstance(value, SerializableDataclass):
719                            value = value.serialize()
720                        # if the value has a serialization function, use that
721                        if hasattr(value, "serialize") and callable(value.serialize):
722                            value = value.serialize()
723                        # if the field has a serialization function, use that
724                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
725                        elif field.serialization_fn:
726                            value = field.serialization_fn(value)
727
728                        # store the value in the result
729                        result[field.name] = value
730                    except Exception as e:
731                        raise FieldSerializationError(
732                            "\n".join(
733                                [
734                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
735                                    f"{field = }",
736                                    f"{value = }",
737                                    f"{self = }",
738                                ]
739                            )
740                        ) from e
741
742            # store each property if we can get it
743            for prop in self._properties_to_serialize:
744                if hasattr(cls, prop):
745                    value = getattr(self, prop)
746                    result[prop] = value
747                else:
748                    raise AttributeError(
749                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
750                        + f" but it is in {self._properties_to_serialize = }"
751                        + f"\n{self = }"
752                    )
753
754            return result
755
756        # ======================================================================
757        # define `load` func
758        # done locally since it depends on args to the decorator
759        # ======================================================================
760        # mypy thinks this isnt a classmethod
761        @classmethod  # type: ignore[misc]
762        def load(cls, data: dict[str, Any] | T) -> T:
763            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
764            if isinstance(data, cls):
765                return data
766
767            assert isinstance(
768                data, typing.Mapping
769            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
770
771            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
772
773            # initialize dict for keeping what we will pass to the constructor
774            ctor_kwargs: dict[str, Any] = dict()
775
776            # iterate over the fields of the class
777            for field in dataclasses.fields(cls):
778                # check if the field is a SerializableField
779                assert isinstance(
780                    field, SerializableField
781                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
782
783                # check if the field is in the data and if it should be initialized
784                if (field.name in data) and field.init:
785                    # get the value, we will be processing it
786                    value: Any = data[field.name]
787
788                    # get the type hint for the field
789                    field_type_hint: Any = cls_type_hints.get(field.name, None)
790
791                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
792                    if field.deserialize_fn:
793                        # if it has a deserialization function, use that
794                        value = field.deserialize_fn(value)
795                    elif field.loading_fn:
796                        # if it has a loading function, use that
797                        value = field.loading_fn(data)
798                    elif (
799                        field_type_hint is not None
800                        and hasattr(field_type_hint, "load")
801                        and callable(field_type_hint.load)
802                    ):
803                        # if no loading function but has a type hint with a load method, use that
804                        if isinstance(value, dict):
805                            value = field_type_hint.load(value)
806                        else:
807                            raise FieldLoadingError(
808                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
809                            )
810                    else:
811                        # assume no loading needs to happen, keep `value` as-is
812                        pass
813
814                    # store the value in the constructor kwargs
815                    ctor_kwargs[field.name] = value
816
817            # create a new instance of the class with the constructor kwargs
818            output: cls = cls(**ctor_kwargs)
819
820            # validate the types of the fields if needed
821            if on_typecheck_mismatch != ErrorMode.IGNORE:
822                output.validate_fields_types(on_typecheck_error=on_typecheck_error)
823
824            # return the new instance
825            return output
826
827        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
828        # type is `Callable[[T], dict]`
829        cls.serialize = serialize  # type: ignore[attr-defined]
830        # type is `Callable[[dict], T]`
831        cls.load = load  # type: ignore[attr-defined]
832        # type is `Callable[[T, ErrorMode], bool]`
833        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]
834
835        # type is `Callable[[T, T], bool]`
836        cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
837
838        # Register the class with ZANJ
839        if register_handler:
840            zanj_register_loader_serializable_dataclass(cls)
841
842        return cls
843
844    if _cls is None:
845        return wrap
846    else:
847        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool (defaults to True)
  • repr : bool (defaults to True)
  • order : bool (defaults to False)
  • unsafe_hash : bool (defaults to False)
  • frozen : bool (defaults to False)
  • properties_to_serialize : Optional[list[str]] SerializableDataclass only: which properties to add to the serialized data dict (defaults to None)
  • register_handler : bool SerializableDataclass only: if true, register the class with ZANJ for loading (defaults to True)
  • on_typecheck_error : ErrorMode SerializableDataclass only: what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return False
  • on_typecheck_mismatch : ErrorMode SerializableDataclass only: what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True

Returns:

  • _type_ the decorated class

Raises:

  • KWOnlyError : only raised if kw_only is True and python version is <3.10, since dataclasses.dataclass does not support this
  • NotSerializableFieldException : if a field is not a SerializableField
  • FieldSerializationError : if there is an error serializing a field
  • AttributeError : if a property is not found on the class
  • FieldLoadingError : if there is an error loading a field
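The serialized output format documented above can be mimicked with plain `dataclasses` from the standard library. This is a sketch of what the generated `.serialize()` method produces, not the actual decorator machinery:

```python
import dataclasses

# stdlib-only sketch of what the generated `serialize()` method produces,
# following the {"__format__": "ClassName(SerializableDataclass)", **fields}
# shape shown in the docstring above (not muutils' actual code)
@dataclasses.dataclass
class Myclass:
    a: int
    b: str

    def serialize(self) -> dict:
        result = {"__format__": f"{type(self).__name__}(SerializableDataclass)"}
        for f in dataclasses.fields(self):
            result[f.name] = getattr(self, f.name)
        return result

print(Myclass(a=1, b="q").serialize())
# {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
```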
def serializable_field( *_args, default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, metadata: Optional[mappingproxy] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any) -> Any:
188def serializable_field(
189    *_args,
190    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
191    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
192    init: bool = True,
193    repr: bool = True,
194    hash: Optional[bool] = None,
195    compare: bool = True,
196    metadata: Optional[types.MappingProxyType] = None,
197    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
198    serialize: bool = True,
199    serialization_fn: Optional[Callable[[Any], Any]] = None,
200    deserialize_fn: Optional[Callable[[Any], Any]] = None,
201    assert_type: bool = True,
202    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
203    **kwargs: Any,
204) -> Any:
205    """Create a new `SerializableField`
206
207    ```
208    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
209    default_factory: Callable[[], Sfield_T]
210    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
211    init: bool = True,
212    repr: bool = True,
213    hash: Optional[bool] = None,
214    compare: bool = True,
215    metadata: types.MappingProxyType | None = None,
216    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
217    # ----------------------------------------------------------------------
218    # new in `SerializableField`, not in `dataclasses.Field`
219    serialize: bool = True,
220    serialization_fn: Optional[Callable[[Any], Any]] = None,
221    loading_fn: Optional[Callable[[Any], Any]] = None,
222    deserialize_fn: Optional[Callable[[Any], Any]] = None,
223    assert_type: bool = True,
224    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
225    ```
226
227    # new Parameters:
228    - `serialize`: whether to serialize this field when serializing the class
229    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
230    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
231    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
232
233    # Gotchas:
234    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that simply converts the stored value back, you'd write:
235
236    ```python
237    class MyClass:
238        my_field: int = serializable_field(
239            serialization_fn=lambda x: str(x),
240            loading_fn=lambda x: int(x["my_field"])
241        )
242    ```
243
244    using `deserialize_fn` instead:
245
246    ```python
247    class MyClass:
248        my_field: int = serializable_field(
249            serialization_fn=lambda x: str(x),
250            deserialize_fn=lambda x: int(x)
251        )
252    ```
253
254    In the above code, `my_field` is an int but will be serialized as a string.
255
256    note that if not using ZANJ, and you have a class inside a container, you MUST provide
257    `serialization_fn` and `loading_fn` to serialize and load the container.
258    ZANJ will automatically do this for you.
259    """
260    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
261    return SerializableField(
262        default=default,
263        default_factory=default_factory,
264        init=init,
265        repr=repr,
266        hash=hash,
267        compare=compare,
268        metadata=metadata,
269        kw_only=kw_only,
270        serialize=serialize,
271        serialization_fn=serialization_fn,
272        deserialize_fn=deserialize_fn,
273        assert_type=assert_type,
274        custom_typecheck_fn=custom_typecheck_fn,
275        **kwargs,
276    )

Create a new SerializableField

default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,

new Parameters:

  • serialize: whether to serialize this field when serializing the class
  • serialization_fn: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in muutils.json_serialize.json_serialize
  • loading_fn: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
  • deserialize_fn: new alternative to loading_fn. takes only the field's value, not the whole class. if both loading_fn and deserialize_fn are provided, an error will be raised.

Gotchas:

  • loading_fn takes the dict of the class, not the field. if you wanted a loading_fn that simply converts the stored value back, you'd write:
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )

using deserialize_fn instead:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )

In the above code, my_field is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide serialization_fn and loading_fn to serialize and load the container. ZANJ will automatically do this for you.

def arr_metadata(arr) -> dict[str, list[int] | str | int]:
43def arr_metadata(arr) -> dict[str, list[int] | str | int]:
44    """get metadata for a numpy array"""
45    return {
46        "shape": list(arr.shape),
47        "dtype": (
48            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
49        ),
50        "n_elements": array_n_elements(arr),
51    }

get metadata for a numpy array

def load_array( arr: Union[bool, int, float, str, list, Dict[str, Any], NoneType], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> Any:
162def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
163    """load a json-serialized array, infer the mode if not specified"""
164    # return arr if its already a numpy array
165    if isinstance(arr, np.ndarray) and array_mode is None:
166        return arr
167
168    # try to infer the array_mode
169    array_mode_inferred: ArrayMode = infer_array_mode(arr)
170    if array_mode is None:
171        array_mode = array_mode_inferred
172    elif array_mode != array_mode_inferred:
173        warnings.warn(
174            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
175        )
176
177    # actually load the array
178    if array_mode == "array_list_meta":
179        assert isinstance(
180            arr, typing.Mapping
181        ), f"invalid list format: {type(arr) = }\n{arr = }"
182
183        data = np.array(arr["data"], dtype=arr["dtype"])
184        if tuple(arr["shape"]) != tuple(data.shape):
185            raise ValueError(f"invalid shape: {arr}")
186        return data
187
188    elif array_mode == "array_hex_meta":
189        assert isinstance(
190            arr, typing.Mapping
191        ), f"invalid list format: {type(arr) = }\n{arr = }"
192
193        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
194        return data.reshape(arr["shape"])
195
196    elif array_mode == "array_b64_meta":
197        assert isinstance(
198            arr, typing.Mapping
199        ), f"invalid list format: {type(arr) = }\n{arr = }"
200
201        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
202        return data.reshape(arr["shape"])
203
204    elif array_mode == "list":
205        assert isinstance(
206            arr, typing.Sequence
207        ), f"invalid list format: {type(arr) = }\n{arr = }"
208
209        return np.array(arr)
210    elif array_mode == "external":
211        # assume ZANJ has taken care of it
212        assert isinstance(arr, typing.Mapping)
213        if "data" not in arr:
214            raise KeyError(
215                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
216            )
217        return arr["data"]
218    elif array_mode == "zero_dim":
219        assert isinstance(arr, typing.Mapping)
220        data = np.array(arr["data"])
221        if tuple(arr["shape"]) != tuple(data.shape):
222            raise ValueError(f"invalid shape: {arr}")
223        return data
224    else:
225        raise ValueError(f"invalid array_mode: {array_mode}")

load a json-serialized array, infer the mode if not specified
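As a concrete illustration of the `array_b64_meta` payload shape consumed by `load_array` above (the `data`/`dtype`/`shape` keys come from the source listing), here is a stdlib-only round trip for a small float64 vector; `struct` stands in for numpy's `frombuffer` so the sketch runs without numpy:

```python
import base64
import struct

# toy round trip for the "array_b64_meta" payload shape used by load_array;
# struct replaces np.frombuffer so this is stdlib-only (illustrative sketch)
values = [1.0, 2.0, 3.0]
payload = {
    "data": base64.b64encode(struct.pack("<3d", *values)).decode("ascii"),
    "dtype": "float64",
    "shape": [3],
}
# decoding: base64 -> raw bytes -> three little-endian float64 values
restored = list(struct.unpack("<3d", base64.b64decode(payload["data"])))
print(restored)  # [1.0, 2.0, 3.0]
```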

BASE_HANDLERS = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'))
JSONitem = typing.Union[bool, int, float, str, list, typing.Dict[str, typing.Any], NoneType]
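A recursive membership test for this union can be sketched as follows (illustrative helper, not part of muutils):

```python
def is_jsonitem(x) -> bool:
    """sketch of a membership test for the JSONitem union above:
    bool/int/float/str/None, lists of JSONitems, or str-keyed dicts of JSONitems"""
    if x is None or isinstance(x, (bool, int, float, str)):
        return True
    if isinstance(x, list):
        return all(is_jsonitem(v) for v in x)
    if isinstance(x, dict):
        # JSON object keys must be strings
        return all(isinstance(k, str) and is_jsonitem(v) for k, v in x.items())
    return False

print(is_jsonitem({"a": [1, None, 2.5]}))  # True
```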
class JsonSerializer:
234class JsonSerializer:
235    """Json serialization class (holds configs)
236
237    # Parameters:
238    - `array_mode : ArrayMode`
239    how to write arrays
240    (defaults to `"array_list_meta"`)
241    - `error_mode : ErrorMode`
242    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
243    (defaults to `"except"`)
244    - `handlers_pre : MonoTuple[SerializerHandler]`
245    handlers to use before the default handlers
246    (defaults to `tuple()`)
247    - `handlers_default : MonoTuple[SerializerHandler]`
248    default handlers to use
249    (defaults to `DEFAULT_HANDLERS`)
250    - `write_only_format : bool`
251    changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
252    (defaults to `False`)
253
254    # Raises:
255    - `ValueError`: on init, if `args` is not empty
256    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
257
258    """
259
260    def __init__(
261        self,
262        *args,
263        array_mode: ArrayMode = "array_list_meta",
264        error_mode: ErrorMode = ErrorMode.EXCEPT,
265        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
266        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
267        write_only_format: bool = False,
268    ):
269        if len(args) > 0:
270            raise ValueError(
271                f"JsonSerializer takes no positional arguments!\n{args = }"
272            )
273
274        self.array_mode: ArrayMode = array_mode
275        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
276        self.write_only_format: bool = write_only_format
277        # join up the handlers
278        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
279            handlers_default
280        )
281
282    def json_serialize(
283        self,
284        obj: Any,
285        path: ObjectPath = tuple(),
286    ) -> JSONitem:
287        try:
288            for handler in self.handlers:
289                if handler.check(self, obj, path):
290                    output: JSONitem = handler.serialize_func(self, obj, path)
291                    if self.write_only_format:
292                        if isinstance(output, dict) and "__format__" in output:
293                            new_fmt: JSONitem = output.pop("__format__")
294                            output["__write_format__"] = new_fmt
295                    return output
296
297            raise ValueError(f"no handler found for object with {type(obj) = }")
298
299        except Exception as e:
300            if self.error_mode == "except":
301                obj_str: str = repr(obj)
302                if len(obj_str) > 1000:
303                    obj_str = obj_str[:1000] + "..."
304                raise SerializationException(
305                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
306                ) from e
307            elif self.error_mode == "warn":
308                warnings.warn(
309                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
310                )
311
312            return repr(obj)
313
314    def hashify(
315        self,
316        obj: Any,
317        path: ObjectPath = tuple(),
318        force: bool = True,
319    ) -> Hashableitem:
320        """try to turn any object into something hashable"""
321        data = self.json_serialize(obj, path=path)
322
323        # recursive hashify, turning dicts and lists into tuples
324        return _recursive_hashify(data, force=force)

Json serialization class (holds configs)

Parameters:

  • array_mode : ArrayMode how to write arrays (defaults to "array_list_meta")
  • error_mode : ErrorMode what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to "except")
  • handlers_pre : MonoTuple[SerializerHandler] handlers to use before the default handlers (defaults to tuple())
  • handlers_default : MonoTuple[SerializerHandler] default handlers to use (defaults to DEFAULT_HANDLERS)
  • write_only_format : bool changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading) (defaults to False)

Raises:

  • ValueError: on init, if args is not empty
  • SerializationException: on json_serialize(), if any error occurs when trying to serialize an object and error_mode is set to ErrorMode.EXCEPT
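The dispatch loop in `json_serialize` can be sketched in a few lines of plain Python. `Handler` below is a stand-in for `SerializerHandler` with the same field names but simplified signatures (no `self`/`path` arguments), so this is an illustrative sketch, not the library's actual API:

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class Handler:
    # stand-in for SerializerHandler: same field names, simplified signatures
    check: Callable[[Any], bool]
    serialize_func: Callable[[Any], Any]
    uid: str

HANDLERS = [
    Handler(lambda o: isinstance(o, (bool, int, float, str, type(None))),
            lambda o: o, "base types"),
    Handler(lambda o: isinstance(o, dict),
            lambda o: {str(k): sketch_serialize(v) for k, v in o.items()},
            "dictionaries"),
    Handler(lambda o: isinstance(o, (list, tuple)),
            lambda o: [sketch_serialize(v) for v in o],
            "(list, tuple) -> list"),
]

def sketch_serialize(obj: Any) -> Any:
    # the first handler whose check passes wins, as in JsonSerializer.json_serialize
    for h in HANDLERS:
        if h.check(obj):
            return h.serialize_func(obj)
    raise ValueError(f"no handler found for object with {type(obj) = }")
```

Prepending a custom `Handler` to the front of the list mirrors what passing `handlers_pre` does in the real class.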
JsonSerializer( *args, array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta', error_mode: muutils.errormode.ErrorMode = ErrorMode.Except, handlers_pre: None = (), handlers_default: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings')), write_only_format: bool = False)
260    def __init__(
261        self,
262        *args,
263        array_mode: ArrayMode = "array_list_meta",
264        error_mode: ErrorMode = ErrorMode.EXCEPT,
265        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
266        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
267        write_only_format: bool = False,
268    ):
269        if len(args) > 0:
270            raise ValueError(
271                f"JsonSerializer takes no positional arguments!\n{args = }"
272            )
273
274        self.array_mode: ArrayMode = array_mode
275        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
276        self.write_only_format: bool = write_only_format
277        # join up the handlers
278        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
279            handlers_default
280        )
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
write_only_format: bool
handlers: MonoTuple[SerializerHandler]
def json_serialize( self, obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]:
282    def json_serialize(
283        self,
284        obj: Any,
285        path: ObjectPath = tuple(),
286    ) -> JSONitem:
287        try:
288            for handler in self.handlers:
289                if handler.check(self, obj, path):
290                    output: JSONitem = handler.serialize_func(self, obj, path)
291                    if self.write_only_format:
292                        if isinstance(output, dict) and "__format__" in output:
293                            new_fmt: JSONitem = output.pop("__format__")
294                            output["__write_format__"] = new_fmt
295                    return output
296
297            raise ValueError(f"no handler found for object with {type(obj) = }")
298
299        except Exception as e:
300            if self.error_mode == "except":
301                obj_str: str = repr(obj)
302                if len(obj_str) > 1000:
303                    obj_str = obj_str[:1000] + "..."
304                raise SerializationException(
305                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
306                ) from e
307            elif self.error_mode == "warn":
308                warnings.warn(
309                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
310                )
311
312            return repr(obj)
def hashify( self, obj: Any, path: tuple[typing.Union[str, int], ...] = (), force: bool = True) -> Union[bool, int, float, str, tuple]:
314    def hashify(
315        self,
316        obj: Any,
317        path: ObjectPath = tuple(),
318        force: bool = True,
319    ) -> Hashableitem:
320        """try to turn any object into something hashable"""
321        data = self.json_serialize(obj, path=path)
322
323        # recursive hashify, turning dicts and lists into tuples
324        return _recursive_hashify(data, force=force)

try to turn any object into something hashable
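The second step of `hashify`, `_recursive_hashify`, is not shown above; a minimal sketch of the dict/list-to-tuple conversion it performs (ignoring the real function's `force` handling for still-unhashable values) might look like:

```python
def recursive_hashify(data):
    # dicts become tuples of (key, value) pairs, lists/tuples become tuples;
    # anything else is assumed to be an already-hashable JSON scalar
    if isinstance(data, dict):
        return tuple((k, recursive_hashify(v)) for k, v in data.items())
    if isinstance(data, (list, tuple)):
        return tuple(recursive_hashify(v) for v in data)
    return data
```

After `json_serialize` has reduced an object to JSON scalars, dicts, and lists, this conversion yields something `hash()` accepts.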

def try_catch(func: Callable):
81def try_catch(func: Callable):
82    """wraps the function to catch exceptions, returns serialized error message on exception
83
84    returned func will return normal result on success, or error message on exception
85    """
86
87    @functools.wraps(func)
88    def newfunc(*args, **kwargs):
89        try:
90            return func(*args, **kwargs)
91        except Exception as e:
92            return f"{e.__class__.__name__}: {e}"
93
94    return newfunc

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception
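For example, wrapping a function that may raise (the decorator below is the `try_catch` source shown above, reproduced verbatim; `divide` is a hypothetical function for illustration):

```python
import functools

def try_catch(func):
    # identical to the source listing above
    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"
    return newfunc

@try_catch
def divide(a: float, b: float) -> float:
    return a / b
```

`divide(6, 3)` returns `2.0`, while `divide(1, 0)` returns the string `'ZeroDivisionError: division by zero'` instead of raising.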

def dc_eq( dc1, dc2, except_when_class_mismatch: bool = False, false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False) -> bool:
175def dc_eq(
176    dc1,
177    dc2,
178    except_when_class_mismatch: bool = False,
179    false_when_class_mismatch: bool = True,
180    except_when_field_mismatch: bool = False,
181) -> bool:
182    """
183    checks if two dataclasses which (might) hold numpy arrays are equal
184
185    # Parameters:
186
187    - `dc1`: the first dataclass
188    - `dc2`: the second dataclass
189    - `except_when_class_mismatch: bool`
190        if `True`, will throw `TypeError` if the classes are different.
191        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
192        (default: `False`)
193    - `false_when_class_mismatch: bool`
194        only relevant if `except_when_class_mismatch` is `False`.
195        if `True`, will return `False` if the classes are different.
196        if `False`, will attempt to compare the fields.
197    - `except_when_field_mismatch: bool`
198        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
199        if `True`, will throw `TypeError` if the fields are different.
200        (default: `False`)
201
202    # Returns:
203    - `bool`: True if the dataclasses are equal, False otherwise
204
205    # Raises:
206    - `TypeError`: if the dataclasses are of different classes
207    - `AttributeError`: if the dataclasses have different fields
208
209    ```
210              [START]
211
212           ┌───────────┐  ┌─────────┐
213           │dc1 is dc2?├─►│ classes │
214           └──┬────────┘No│ match?  │
215      ────    │           ├─────────┤
216     (True)◄──┘Yes        │No       │Yes
217      ────                ▼         ▼
218          ┌────────────────┐ ┌────────────┐
219          │ except when    │ │ fields keys│
220          │ class mismatch?│ │ match?     │
221          ├───────────┬────┘ ├───────┬────┘
222          │Yes        │No    │No     │Yes
223          ▼           ▼      ▼       ▼
224     ───────────  ┌──────────┐  ┌────────┐
225    { raise     } │ except   │  │ field  │
226    { TypeError } │ when     │  │ values │
227     ───────────  │ field    │  │ match? │
228                  │ mismatch?│  ├────┬───┘
229                  ├───────┬──┘  │    │Yes
230                  │Yes    │No   │No  ▼
231                  ▼       ▼     │   ────
232     ───────────────     ─────  │  (True)
233    { raise         }   (False)◄┘   ────
234    { AttributeError}    ─────
235     ───────────────
236    ```
237
238    """
239    if dc1 is dc2:
240        return True
241
242    if dc1.__class__ is not dc2.__class__:
243        if except_when_class_mismatch:
244            # if the classes don't match, raise an error
245            raise TypeError(
246                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
247            )
248        if except_when_field_mismatch:
249            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
250            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
251            fields_match: bool = set(dc1_fields) == set(dc2_fields)
252            if not fields_match:
253                # if the fields don't match, raise an error
254                raise AttributeError(
255                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
256                )
257        return False
258
259    return all(
260        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
261        for fld in dataclasses.fields(dc1)
262        if fld.compare
263    )

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

  • dc1: the first dataclass
  • dc2: the second dataclass
  • except_when_class_mismatch: bool if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
  • false_when_class_mismatch: bool only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
  • except_when_field_mismatch: bool only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw TypeError if the fields are different. (default: False)

Returns:

  • bool: True if the dataclasses are equal, False otherwise

Raises:

  • TypeError: if the dataclasses are of different classes
  • AttributeError: if the dataclasses have different fields
          [START]
             ▼
       ┌───────────┐  ┌─────────┐
       │dc1 is dc2?├─►│ classes │
       └──┬────────┘No│ match?  │
  ────    │           ├─────────┤
 (True)◄──┘Yes        │No       │Yes
  ────                ▼         ▼
      ┌────────────────┐ ┌────────────┐
      │ except when    │ │ fields keys│
      │ class mismatch?│ │ match?     │
      ├───────────┬────┘ ├───────┬────┘
      │Yes        │No    │No     │Yes
      ▼           ▼      ▼       ▼
 ───────────  ┌──────────┐  ┌────────┐
{ raise     } │ except   │  │ field  │
{ TypeError } │ when     │  │ values │
 ───────────  │ field    │  │ match? │
              │ mismatch?│  ├────┬───┘
              ├───────┬──┘  │    │Yes
              │Yes    │No   │No  ▼
              ▼       ▼     │   ────
 ───────────────     ─────  │  (True)
{ raise         }   (False)◄┘   ────
{ AttributeError}    ─────
 ───────────────
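The right-hand branch of the diagram (same class, compare fields) can be sketched with plain dataclasses. In this sketch, `==` stands in for `array_safe_eq` (so no numpy support here), and the hypothetical `Point` class shows that fields declared with `compare=False` are skipped, as in the real `dc_eq`:

```python
import dataclasses

@dataclasses.dataclass
class Point:
    x: int
    y: int
    # compare=False fields are ignored by the equality check
    note: str = dataclasses.field(default="", compare=False)

def dc_eq_sketch(dc1, dc2) -> bool:
    # the same-class happy path of dc_eq, with plain == in place of array_safe_eq
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return False
    return all(
        getattr(dc1, fld.name) == getattr(dc2, fld.name)
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
```

`Point(1, 2, "left")` and `Point(1, 2, "right")` compare equal because `note` has `compare=False`.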
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
316@dataclass_transform(
317    field_specifiers=(serializable_field, SerializableField),
318)
319class SerializableDataclass(abc.ABC):
320    """Base class for serializable dataclasses
321
322    only for linting and type checking, still need to call `serializable_dataclass` decorator
323
324    # Usage:
325
326    ```python
327    @serializable_dataclass
328    class MyClass(SerializableDataclass):
329        a: int
330        b: str
331    ```
332
333    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
334
335        >>> my_obj = MyClass(a=1, b="q")
336        >>> s = json.dumps(my_obj.serialize())
337        >>> s
338        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
339        >>> read_obj = MyClass.load(json.loads(s))
340        >>> read_obj == my_obj
341        True
342
343    This isn't too impressive on its own, but it gets more useful when you have nested classes,
344    or fields that are not json-serializable by default:
345
346    ```python
347    @serializable_dataclass
348    class NestedClass(SerializableDataclass):
349        x: str
350        y: MyClass
351        act_fun: torch.nn.Module = serializable_field(
352            default=torch.nn.ReLU(),
353            serialization_fn=lambda x: str(x),
354            deserialize_fn=lambda x: getattr(torch.nn, x)(),
355        )
356    ```
357
358    which gives us:
359
360        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
361        >>> s = json.dumps(nc.serialize())
362        >>> s
363        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
364        >>> read_nc = NestedClass.load(json.loads(s))
365        >>> read_nc == nc
366        True
367    """
368
369    def serialize(self) -> dict[str, Any]:
370        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
371        raise NotImplementedError(
372            f"decorate {self.__class__ = } with `@serializable_dataclass`"
373        )
374
375    @classmethod
376    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
377        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
378        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
379
380    def validate_fields_types(
381        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
382    ) -> bool:
383        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
384        return SerializableDataclass__validate_fields_types(
385            self, on_typecheck_error=on_typecheck_error
386        )
387
388    def validate_field_type(
389        self,
390        field: "SerializableField|str",
391        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
392    ) -> bool:
393        """given a dataclass, check the field matches the type hint"""
394        return SerializableDataclass__validate_field_type(
395            self, field, on_typecheck_error=on_typecheck_error
396        )
397
398    def __eq__(self, other: Any) -> bool:
399        return dc_eq(self, other)
400
401    def __hash__(self) -> int:
402        "hashes the json-serialized representation of the class"
403        return hash(json.dumps(self.serialize()))
404
405    def diff(
406        self, other: "SerializableDataclass", of_serialized: bool = False
407    ) -> dict[str, Any]:
408        """get a rich and recursive diff between two instances of a serializable dataclass
409
410        ```python
411        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
412        {'b': {'self': 2, 'other': 3}}
413        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
414        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
415        ```
416
417        # Parameters:
418         - `other : SerializableDataclass`
419           other instance to compare against
420         - `of_serialized : bool`
421           if true, compare serialized data and not raw values
422           (defaults to `False`)
423
424        # Returns:
425         - `dict[str, Any]`
426
427
428        # Raises:
429         - `ValueError` : if the instances are not of the same type
430         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
431        """
432        # match types
433        if type(self) is not type(other):
434            raise ValueError(
435                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
436            )
437
438        # initialize the diff result
439        diff_result: dict = {}
440
441        # if they are the same, return the empty diff
442        if self == other:
443            return diff_result
444
445        # if we are working with serialized data, serialize the instances
446        if of_serialized:
447            ser_self: dict = self.serialize()
448            ser_other: dict = other.serialize()
449
450        # for each field in the class
451        for field in dataclasses.fields(self):  # type: ignore[arg-type]
452            # skip fields that are not for comparison
453            if not field.compare:
454                continue
455
456            # get values
457            field_name: str = field.name
458            self_value = getattr(self, field_name)
459            other_value = getattr(other, field_name)
460
461            # if the values are both serializable dataclasses, recurse
462            if isinstance(self_value, SerializableDataclass) and isinstance(
463                other_value, SerializableDataclass
464            ):
465                nested_diff: dict = self_value.diff(
466                    other_value, of_serialized=of_serialized
467                )
468                if nested_diff:
469                    diff_result[field_name] = nested_diff
470            # only support serializable dataclasses
471            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
472                other_value
473            ):
474                raise ValueError("Non-serializable dataclass is not supported")
475            else:
476                # get the values of either the serialized or the actual values
477                self_value_s = ser_self[field_name] if of_serialized else self_value
478                other_value_s = ser_other[field_name] if of_serialized else other_value
479                # compare the values
480                if not array_safe_eq(self_value_s, other_value_s):
481                    diff_result[field_name] = {"self": self_value, "other": other_value}
482
483        # return the diff result
484        return diff_result
485
486    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
487        """update the instance from a nested dict, useful for configuration from command line args
488
489        # Parameters:
490            - `nested_dict : dict[str, Any]`
491                nested dict to update the instance with
492        """
493        for field in dataclasses.fields(self):  # type: ignore[arg-type]
494            field_name: str = field.name
495            self_value = getattr(self, field_name)
496
497            if field_name in nested_dict:
498                if isinstance(self_value, SerializableDataclass):
499                    self_value.update_from_nested_dict(nested_dict[field_name])
500                else:
501                    setattr(self, field_name, nested_dict[field_name])
502
503    def __copy__(self) -> "SerializableDataclass":
504        "deep copy by serializing and loading the instance to json"
505        return self.__class__.load(json.loads(json.dumps(self.serialize())))
506
507    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
508        "deep copy by serializing and loading the instance to json"
509        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]:
369    def serialize(self) -> dict[str, Any]:
370        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
371        raise NotImplementedError(
372            f"decorate {self.__class__ = } with `@serializable_dataclass`"
373        )

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
375    @classmethod
376    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
377        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
378        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
380    def validate_fields_types(
381        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
382    ) -> bool:
383        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
384        return SerializableDataclass__validate_fields_types(
385            self, on_typecheck_error=on_typecheck_error
386        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
388    def validate_field_type(
389        self,
390        field: "SerializableField|str",
391        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
392    ) -> bool:
393        """given a dataclass, check the field matches the type hint"""
394        return SerializableDataclass__validate_field_type(
395            self, field, on_typecheck_error=on_typecheck_error
396        )

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
405    def diff(
406        self, other: "SerializableDataclass", of_serialized: bool = False
407    ) -> dict[str, Any]:
408        """get a rich and recursive diff between two instances of a serializable dataclass
409
410        ```python
411        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
412        {'b': {'self': 2, 'other': 3}}
413        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
414        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
415        ```
416
417        # Parameters:
418         - `other : SerializableDataclass`
419           other instance to compare against
420         - `of_serialized : bool`
421           if true, compare serialized data and not raw values
422           (defaults to `False`)
423
424        # Returns:
425         - `dict[str, Any]`
426
427
428        # Raises:
429         - `ValueError` : if the instances are not of the same type
430         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
431        """
432        # match types
433        if type(self) is not type(other):
434            raise ValueError(
435                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
436            )
437
438        # initialize the diff result
439        diff_result: dict = {}
440
441        # if they are the same, return the empty diff
442        if self == other:
443            return diff_result
444
445        # if we are working with serialized data, serialize the instances
446        if of_serialized:
447            ser_self: dict = self.serialize()
448            ser_other: dict = other.serialize()
449
450        # for each field in the class
451        for field in dataclasses.fields(self):  # type: ignore[arg-type]
452            # skip fields that are not for comparison
453            if not field.compare:
454                continue
455
456            # get values
457            field_name: str = field.name
458            self_value = getattr(self, field_name)
459            other_value = getattr(other, field_name)
460
461            # if the values are both serializable dataclasses, recurse
462            if isinstance(self_value, SerializableDataclass) and isinstance(
463                other_value, SerializableDataclass
464            ):
465                nested_diff: dict = self_value.diff(
466                    other_value, of_serialized=of_serialized
467                )
468                if nested_diff:
469                    diff_result[field_name] = nested_diff
470            # only support serializable dataclasses
471            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
472                other_value
473            ):
474                raise ValueError("Non-serializable dataclass is not supported")
475            else:
476                # get the values of either the serialized or the actual values
477                self_value_s = ser_self[field_name] if of_serialized else self_value
478                other_value_s = ser_other[field_name] if of_serialized else other_value
479                # compare the values
480                if not array_safe_eq(self_value_s, other_value_s):
481                    diff_result[field_name] = {"self": self_value, "other": other_value}
482
483        # return the diff result
484        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • other : SerializableDataclass other instance to compare against
  • of_serialized : bool if true, compare serialized data and not raw values (defaults to False)

Returns:

  • dict[str, Any]: nested dict mapping each differing field name to {'self': ..., 'other': ...} (empty if the instances are equal)

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
486    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
487        """update the instance from a nested dict, useful for configuration from command line args
488
489        # Parameters:
490            - `nested_dict : dict[str, Any]`
491                nested dict to update the instance with
492        """
493        for field in dataclasses.fields(self):  # type: ignore[arg-type]
494            field_name: str = field.name
495            self_value = getattr(self, field_name)
496
497            if field_name in nested_dict:
498                if isinstance(self_value, SerializableDataclass):
499                    self_value.update_from_nested_dict(nested_dict[field_name])
500                else:
501                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

- `nested_dict : dict[str, Any]`
    nested dict to update the instance with
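The same update logic can be sketched on plain dataclasses. The `Config`/`Inner` classes below are hypothetical examples, and the sketch recurses via `dataclasses.is_dataclass` where the real method checks for `SerializableDataclass` instances:

```python
import dataclasses

@dataclasses.dataclass
class Inner:
    lr: float = 0.1

@dataclasses.dataclass
class Config:
    name: str = "run"
    inner: Inner = dataclasses.field(default_factory=Inner)

def update_from_nested_dict(obj, nested: dict) -> None:
    # mirrors the loop in the method above: recurse into nested dataclasses,
    # overwrite plain fields in place
    for fld in dataclasses.fields(obj):
        if fld.name in nested:
            current = getattr(obj, fld.name)
            if dataclasses.is_dataclass(current):
                update_from_nested_dict(current, nested[fld.name])
            else:
                setattr(obj, fld.name, nested[fld.name])
```

This is the pattern that makes the method useful for configuration from command-line args: parse overrides into a nested dict, then apply them, e.g. `update_from_nested_dict(cfg, {"inner": {"lr": 0.01}})`.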