docs for muutils v0.9.1
View Source on GitHub

muutils.json_serialize

submodule for serializing things to json in a recoverable way

you can throw any object into muutils.json_serialize.json_serialize and it will return a JSONitem, meaning a bool, int, float, str, None, list of JSONitems, or a dict mapping to JSONitem.

The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into json_serialize and it will just work. If you want to do so in a recoverable way, check out ZANJ.

it will do so by looking in DEFAULT_HANDLERS, which will keep it as-is if it's already valid, then try to find a .serialize() method on the object, and then have a bunch of special cases. You can add handlers by initializing a JsonSerializer object and passing a sequence of them to handlers_pre

additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and serializable_field in place of dataclasses.field when defining non-standard fields.

This module plays nicely with and is a dependency of the ZANJ library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.


 1"""submodule for serializing things to json in a recoverable way
 2
 3you can throw *any* object into `muutils.json_serialize.json_serialize`
 4and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`.
 5
 6The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).
 7
 8it will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if it's already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`
 9
10additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields.
11
12This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
13
14"""
15
16from __future__ import annotations
17
18from muutils.json_serialize.array import arr_metadata, load_array
19from muutils.json_serialize.json_serialize import (
20    BASE_HANDLERS,
21    JsonSerializer,
22    json_serialize,
23)
24from muutils.json_serialize.serializable_dataclass import (
25    SerializableDataclass,
26    serializable_dataclass,
27    serializable_field,
28)
29from muutils.json_serialize.util import try_catch, JSONitem, dc_eq
30
31__all__ = [
32    # submodules
33    "array",
34    "json_serialize",
35    "serializable_dataclass",
36    "serializable_field",
37    "util",
38    # imports
39    "arr_metadata",
40    "load_array",
41    "BASE_HANDLERS",
42    "JSONitem",
43    "JsonSerializer",
44    "json_serialize",
45    "try_catch",
46    "JSONitem",
47    "dc_eq",
48    "serializable_dataclass",
49    "serializable_field",
50    "SerializableDataclass",
51]

def json_serialize( obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, NoneType, Sequence[ForwardRef('JSONitem')], Dict[str, ForwardRef('JSONitem')]]:
402def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem:  # pyright: ignore[reportAny]
403    """serialize object to json-serializable object with default config"""
404    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)

serialize object to json-serializable object with default config

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, methods_no_override: list[str] | None = None, **kwargs: Any) -> Any:
586@dataclass_transform(
587    field_specifiers=(serializable_field, SerializableField),
588)
589def serializable_dataclass(
590    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
591    _cls=None,  # type: ignore
592    *,
593    init: bool = True,
594    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
595    eq: bool = True,
596    order: bool = False,
597    unsafe_hash: bool = False,
598    frozen: bool = False,
599    properties_to_serialize: Optional[list[str]] = None,
600    register_handler: bool = True,
601    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
602    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
603    methods_no_override: list[str] | None = None,
604    **kwargs: Any,
605) -> Any:
606    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**
607
608    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
609
610    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`
611
612    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
613
614    Examines PEP 526 `__annotations__` to determine fields.
615
616    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.
617
618    ```python
619    @serializable_dataclass(kw_only=True)
620    class Myclass(SerializableDataclass):
621        a: int
622        b: str
623    ```
624    ```python
625    >>> Myclass(a=1, b="q").serialize()
626    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
627    ```
628
629    # Parameters:
630
631    - `_cls : _type_`
632       class to decorate. don't pass this arg, just use this as a decorator
633       (defaults to `None`)
634    - `init : bool`
635       whether to add an `__init__` method
636       *(passed to dataclasses.dataclass)*
637       (defaults to `True`)
638    - `repr : bool`
639       whether to add a `__repr__` method
640       *(passed to dataclasses.dataclass)*
641       (defaults to `True`)
642    - `order : bool`
643       whether to add rich comparison methods
644       *(passed to dataclasses.dataclass)*
645       (defaults to `False`)
646    - `unsafe_hash : bool`
647       whether to add a `__hash__` method
648       *(passed to dataclasses.dataclass)*
649       (defaults to `False`)
650    - `frozen : bool`
651       whether to make the class frozen
652       *(passed to dataclasses.dataclass)*
653       (defaults to `False`)
654    - `properties_to_serialize : Optional[list[str]]`
655       which properties to add to the serialized data dict
656       **SerializableDataclass only**
657       (defaults to `None`)
658    - `register_handler : bool`
659        if true, register the class with ZANJ for loading
660        **SerializableDataclass only**
661        (defaults to `True`)
662    - `on_typecheck_error : ErrorMode`
663        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
664        **SerializableDataclass only**
665    - `on_typecheck_mismatch : ErrorMode`
666        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
667        **SerializableDataclass only**
668    - `methods_no_override : list[str]|None`
669        list of methods that should not be overridden by the decorator
670        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
671        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
672        **SerializableDataclass only**
673        (defaults to `None`)
674    - `**kwargs`
675        *(passed to dataclasses.dataclass)*
676
677    # Returns:
678
679    - `_type_`
680       the decorated class
681
682    # Raises:
683
684    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
685    - `NotSerializableFieldException` : if a field is not a `SerializableField`
686    - `FieldSerializationError` : if there is an error serializing a field
687    - `AttributeError` : if a property is not found on the class
688    - `FieldLoadingError` : if there is an error loading a field
689    """
690    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
691    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
692    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
693
694    if properties_to_serialize is None:
695        _properties_to_serialize: list = list()
696    else:
697        _properties_to_serialize = properties_to_serialize
698
699    def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]:
700        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
701        for field_name, field_type in cls.__annotations__.items():
702            field_value = getattr(cls, field_name, None)
703            if not isinstance(field_value, SerializableField):
704                if isinstance(field_value, dataclasses.Field):
705                    # Convert the field to a SerializableField while preserving properties
706                    field_value = SerializableField.from_Field(field_value)
707                else:
708                    # Create a new SerializableField
709                    field_value = serializable_field()
710                setattr(cls, field_name, field_value)
711
712        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
713        if sys.version_info < (3, 10):
714            if "kw_only" in kwargs:
715                if kwargs["kw_only"] == True:  # noqa: E712
716                    raise KWOnlyError(
717                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
718                    )
719                else:
720                    del kwargs["kw_only"]
721
722        # call `dataclasses.dataclass` to set some stuff up
723        cls = dataclasses.dataclass(  # type: ignore[call-overload]
724            cls,
725            init=init,
726            repr=repr,
727            eq=eq,
728            order=order,
729            unsafe_hash=unsafe_hash,
730            frozen=frozen,
731            **kwargs,
732        )
733
734        # copy these to the class
735        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
736
737        # ======================================================================
738        # define `serialize` func
739        # done locally since it depends on args to the decorator
740        # ======================================================================
741        def serialize(self: Any) -> dict[str, Any]:
742            result: dict[str, Any] = {
743                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
744            }
745            # for each field in the class
746            for field in dataclasses.fields(self):  # type: ignore[arg-type]
747                # need it to be our special SerializableField
748                if not isinstance(field, SerializableField):
749                    raise NotSerializableFieldException(
750                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
751                        f"but a {type(field)} "
752                        "this state should be inaccessible, please report this bug!"
753                    )
754
755                # try to save it
756                if field.serialize:
757                    value: Any = None  # init before try in case getattr raises
758                    try:
759                        # get the val
760                        value = getattr(self, field.name)
761                        # if it is a serializable dataclass, serialize it
762                        if isinstance(value, SerializableDataclass):
763                            value = value.serialize()
764                        # if the value has a serialization function, use that
765                        if hasattr(value, "serialize") and callable(value.serialize):  # pyright: ignore[reportAttributeAccessIssue]
766                            value = value.serialize()  # pyright: ignore[reportAttributeAccessIssue]
767                        # if the field has a serialization function, use that
768                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
769                        elif field.serialization_fn:
770                            value = field.serialization_fn(value)
771
772                        # store the value in the result
773                        result[field.name] = value
774                    except Exception as e:
775                        raise FieldSerializationError(
776                            "\n".join(
777                                [
778                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
779                                    f"{field = }",
780                                    f"{value or '<unavailable>' = }",
781                                    f"{self = }",
782                                ]
783                            )
784                        ) from e
785
786            # store each property if we can get it
787            for prop in self._properties_to_serialize:
788                if hasattr(cls, prop):
789                    value = getattr(self, prop)
790                    result[prop] = value
791                else:
792                    raise AttributeError(
793                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
794                        + f"but it is in {self._properties_to_serialize = }"
795                        + f"\n{self = }"
796                    )
797
798            return result
799
800        # ======================================================================
801        # define `load` func
802        # done locally since it depends on args to the decorator
803        # ======================================================================
804        # mypy thinks this isnt a classmethod
805        @classmethod  # type: ignore[misc]
806        def load(
807            cls: type[T_SerializeableDataclass],
808            data: dict[str, Any] | T_SerializeableDataclass,
809        ) -> T_SerializeableDataclass:
810            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
811            if isinstance(data, cls):
812                return data
813
814            assert isinstance(data, typing.Mapping), (
815                f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
816            )
817
818            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
819
820            # initialize dict for keeping what we will pass to the constructor
821            ctor_kwargs: dict[str, Any] = dict()
822
823            # iterate over the fields of the class
824            # mypy doesn't recognize @dataclass_transform for dataclasses.fields()
825            # https://github.com/python/mypy/issues/16241
826            for field in dataclasses.fields(cls):  # type: ignore[arg-type]
827                # check if the field is a SerializableField
828                assert isinstance(field, SerializableField), (
829                    f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
830                )
831
832                # check if the field is in the data and if it should be initialized
833                if (field.name in data) and field.init:
834                    # get the value, we will be processing it
835                    value: Any = data[field.name]
836
837                    # get the type hint for the field
838                    field_type_hint: Any = cls_type_hints.get(field.name, None)
839
840                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
841                    if field.deserialize_fn:
842                        # if it has a deserialization function, use that
843                        value = field.deserialize_fn(value)
844                    elif field.loading_fn:
845                        # if it has a loading function, use that
846                        value = field.loading_fn(data)
847                    elif (
848                        field_type_hint is not None
849                        and hasattr(field_type_hint, "load")
850                        and callable(field_type_hint.load)
851                    ):
852                        # if no loading function but has a type hint with a load method, use that
853                        if isinstance(value, dict):
854                            value = field_type_hint.load(value)
855                        else:
856                            raise FieldLoadingError(
857                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
858                            )
859                    else:
860                        # assume no loading needs to happen, keep `value` as-is
861                        pass
862
863                    # store the value in the constructor kwargs
864                    ctor_kwargs[field.name] = value
865
866            # create a new instance of the class with the constructor kwargs
867            output: T_SerializeableDataclass = cls(**ctor_kwargs)
868
869            # validate the types of the fields if needed
870            if on_typecheck_mismatch != ErrorMode.IGNORE:
871                fields_valid: dict[str, bool] = (
872                    SerializableDataclass__validate_fields_types__dict(
873                        output,
874                        on_typecheck_error=on_typecheck_error,
875                    )
876                )
877
878                # if there are any fields that are not valid, raise an error
879                if not all(fields_valid.values()):
880                    msg: str = (
881                        f"Type mismatch in fields of {cls.__name__}:\n"
882                        + "\n".join(
883                            [
884                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
885                                for k, v in fields_valid.items()
886                                if not v
887                            ]
888                        )
889                    )
890
891                    on_typecheck_mismatch.process(
892                        msg, except_cls=FieldTypeMismatchError
893                    )
894
895            # return the new instance
896            return output
897
898        _methods_no_override: set[str]
899        if methods_no_override is None:
900            _methods_no_override = set()
901        else:
902            _methods_no_override = set(methods_no_override)
903
904        if _methods_no_override - {
905            "__eq__",
906            "serialize",
907            "load",
908            "validate_fields_types",
909        }:
910            warnings.warn(
911                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
912            )
913
914        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
915        if "serialize" not in _methods_no_override:
916            # type is `Callable[[T], dict]`
917            cls.serialize = serialize  # type: ignore[attr-defined, method-assign]
918        if "load" not in _methods_no_override:
919            # type is `Callable[[dict], T]`
920            cls.load = load  # type: ignore[attr-defined, method-assign, assignment]
921
922        if "validate_field_type" not in _methods_no_override:
923            # type is `Callable[[T, ErrorMode], bool]`
924            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined, method-assign]
925
926        if "__eq__" not in _methods_no_override:
927            # type is `Callable[[T, T], bool]`
928            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
929
930        # Register the class with ZANJ
931        if register_handler:
932            zanj_register_loader_serializable_dataclass(cls)
933
934        return cls
935
936    if _cls is None:
937        return wrap
938    else:
939        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
  • repr : bool whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
  • order : bool whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
  • unsafe_hash : bool whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
  • frozen : bool whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
  • properties_to_serialize : Optional[list[str]] which properties to add to the serialized data dict SerializableDataclass only (defaults to None)
  • register_handler : bool if true, register the class with ZANJ for loading SerializableDataclass only (defaults to True)
  • on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false SerializableDataclass only
  • on_typecheck_mismatch : ErrorMode what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True SerializableDataclass only
  • methods_no_override : list[str]|None list of methods that should not be overridden by the decorator by default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you'd rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence SerializableDataclass only (defaults to None)
  • **kwargs (passed to dataclasses.dataclass)

Returns:

  • _type_ the decorated class

Raises:

  • KWOnlyError : only raised if kw_only is True and python version is <3.10, since dataclasses.dataclass does not support this
  • NotSerializableFieldException : if a field is not a SerializableField
  • FieldSerializationError : if there is an error serializing a field
  • AttributeError : if a property is not found on the class
  • FieldLoadingError : if there is an error loading a field
def serializable_field( *_args: Any, default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, doc: str | None = None, metadata: Optional[mappingproxy] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any) -> Any:
202def serializable_field(  # general implementation
203    *_args: Any,
204    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
205    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
206    init: bool = True,
207    repr: bool = True,
208    hash: Optional[bool] = None,
209    compare: bool = True,
210    doc: str | None = None,
211    metadata: Optional[types.MappingProxyType] = None,
212    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
213    serialize: bool = True,
214    serialization_fn: Optional[Callable[[Any], Any]] = None,
215    deserialize_fn: Optional[Callable[[Any], Any]] = None,
216    assert_type: bool = True,
217    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
218    **kwargs: Any,
219) -> Any:
220    """Create a new `SerializableField`
221
222    ```
223    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
224    default_factory: Callable[[], Sfield_T]
225    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
226    init: bool = True,
227    repr: bool = True,
228    hash: Optional[bool] = None,
229    compare: bool = True,
230    doc: str | None = None, # new in python 3.14. can alternately pass `description` to match pydantic, but this is discouraged
231    metadata: types.MappingProxyType | None = None,
232    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
233    # ----------------------------------------------------------------------
234    # new in `SerializableField`, not in `dataclasses.Field`
235    serialize: bool = True,
236    serialization_fn: Optional[Callable[[Any], Any]] = None,
237    loading_fn: Optional[Callable[[Any], Any]] = None,
238    deserialize_fn: Optional[Callable[[Any], Any]] = None,
239    assert_type: bool = True,
240    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
241    ```
242
243    # new Parameters:
244    - `serialize`: whether to serialize this field when serializing the class
245    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
246    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
247    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
248    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
249    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.
250
251    # Gotchas:
252    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:
253
254    ```python
255    class MyClass:
256        my_field: int = serializable_field(
257            serialization_fn=lambda x: str(x),
258            loading_fn=lambda x["my_field"]: int(x)
259        )
260    ```
261
262    using `deserialize_fn` instead:
263
264    ```python
265    class MyClass:
266        my_field: int = serializable_field(
267            serialization_fn=lambda x: str(x),
268            deserialize_fn=lambda x: int(x)
269        )
270    ```
271
272    In the above code, `my_field` is an int but will be serialized as a string.
273
274    note that if not using ZANJ, and you have a class inside a container, you MUST provide
275    `serialization_fn` and `loading_fn` to serialize and load the container.
276    ZANJ will automatically do this for you.
277
278    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
279    """
280    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
281
282    if "description" in kwargs:
283        import warnings
284
285        warnings.warn(
286            "`description` is deprecated, use `doc` instead",
287            DeprecationWarning,
288        )
289        if doc is not None:
290            err_msg: str = f"cannot pass both `doc` and `description`: {doc=}, {kwargs['description']=}"
291            raise ValueError(err_msg)
292        doc = kwargs.pop("description")
293
294    return SerializableField(
295        default=default,
296        default_factory=default_factory,
297        init=init,
298        repr=repr,
299        hash=hash,
300        compare=compare,
301        doc=doc,
302        metadata=metadata,
303        kw_only=kw_only,
304        serialize=serialize,
305        serialization_fn=serialization_fn,
306        deserialize_fn=deserialize_fn,
307        assert_type=assert_type,
308        custom_typecheck_fn=custom_typecheck_fn,
309        **kwargs,
310    )

Create a new SerializableField

default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
doc: str | None = None, # new in python 3.14. can alternately pass `description` to match pydantic, but this is discouraged
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,

new Parameters:

  • serialize: whether to serialize this field when serializing the class
  • serialization_fn: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in muutils.json_serialize.json_serialize
  • loading_fn: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
  • deserialize_fn: new alternative to loading_fn. takes only the field's value, not the whole class. if both loading_fn and deserialize_fn are provided, an error will be raised.
  • assert_type: whether to assert the type of the field when loading. if False, will not check the type of the field.
  • custom_typecheck_fn: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

Gotchas:

  • loading_fn takes the dict of the class, not the field. if you wanted a loading_fn that does nothing, you'd write:
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )

using deserialize_fn instead:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )

In the above code, my_field is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide serialization_fn and loading_fn to serialize and load the container. ZANJ will automatically do this for you.

TODO: custom_value_check_fn: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test

def arr_metadata(arr: Any) -> muutils.json_serialize.array.ArrayMetadata:
103def arr_metadata(arr: Any) -> ArrayMetadata:  # pyright: ignore[reportAny]
104    """get metadata for a numpy array"""
105    return {
106        "shape": list(arr.shape),  # pyright: ignore[reportAny]
107        "dtype": (
108            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)  # pyright: ignore[reportAny]
109        ),
110        "n_elements": array_n_elements(arr),
111    }

get metadata for a numpy array

def load_array( arr: Union[muutils.json_serialize.array.SerializedArrayWithMeta, numpy.ndarray, List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]]], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> numpy.ndarray:
289def load_array(
290    arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList],
291    array_mode: Optional[ArrayMode] = None,
292) -> np.ndarray:
293    """load a json-serialized array, infer the mode if not specified"""
294    # return arr if its already a numpy array
295    if isinstance(arr, np.ndarray):
296        assert array_mode is None, (
297            "array_mode should not be specified when loading a numpy array, since that is a no-op"
298        )
299        return arr
300
301    # try to infer the array_mode
302    array_mode_inferred: ArrayMode = infer_array_mode(arr)
303    if array_mode is None:
304        array_mode = array_mode_inferred
305    elif array_mode != array_mode_inferred:
306        warnings.warn(
307            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
308        )
309
310    # actually load the array
311    if array_mode == "array_list_meta":
312        assert isinstance(arr, typing.Mapping), (
313            f"invalid list format: {type(arr) = }\n{arr = }"
314        )
315        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
316        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
317            raise ValueError(f"invalid shape: {arr}")
318        return data
319
320    elif array_mode == "array_hex_meta":
321        assert isinstance(arr, typing.Mapping), (
322            f"invalid list format: {type(arr) = }\n{arr = }"
323        )
324        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
325        return data.reshape(arr["shape"])  # type: ignore
326
327    elif array_mode == "array_b64_meta":
328        assert isinstance(arr, typing.Mapping), (
329            f"invalid list format: {type(arr) = }\n{arr = }"
330        )
331        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
332        return data.reshape(arr["shape"])  # type: ignore
333
334    elif array_mode == "list":
335        assert isinstance(arr, typing.Sequence), (
336            f"invalid list format: {type(arr) = }\n{arr = }"
337        )
338        return np.array(arr)  # type: ignore
339    elif array_mode == "external":
340        assert isinstance(arr, typing.Mapping)
341        if "data" not in arr:
342            raise KeyError(  # pyright: ignore[reportUnreachable]
343                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
344            )
345        # we can ignore here since we assume ZANJ has taken care of it
346        return arr["data"]  # type: ignore[return-value] # pyright: ignore[reportReturnType]
347    elif array_mode == "zero_dim":
348        assert isinstance(arr, typing.Mapping)
349        data = np.array(arr["data"])  # ty: ignore[invalid-argument-type]
350        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
351            raise ValueError(f"invalid shape: {arr}")
352        return data
353    else:
354        raise ValueError(f"invalid array_mode: {array_mode}")  # pyright: ignore[reportUnreachable]

load a json-serialized array, infer the mode if not specified

BASE_HANDLERS = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'))
JSONitem = typing.Union[bool, int, float, str, NoneType, typing.Sequence[ForwardRef('JSONitem')], typing.Dict[str, ForwardRef('JSONitem')]]
class JsonSerializer:
279class JsonSerializer:
280    """Json serialization class (holds configs)
281
282    # Parameters:
283    - `array_mode : ArrayMode`
284    how to write arrays
285    (defaults to `"array_list_meta"`)
286    - `error_mode : ErrorMode`
287    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
288    (defaults to `"except"`)
289    - `handlers_pre : MonoTuple[SerializerHandler]`
290    handlers to use before the default handlers
291    (defaults to `tuple()`)
292    - `handlers_default : MonoTuple[SerializerHandler]`
293    default handlers to use
294    (defaults to `DEFAULT_HANDLERS`)
295    - `write_only_format : bool`
296    changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
297    (defaults to `False`)
298
299    # Raises:
300    - `ValueError`: on init, if `args` is not empty
301    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
302
303    """
304
305    def __init__(
306        self,
307        *args: None,
308        array_mode: "ArrayMode" = "array_list_meta",
309        error_mode: ErrorMode = ErrorMode.EXCEPT,
310        handlers_pre: MonoTuple[SerializerHandler] = (),
311        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
312        write_only_format: bool = False,
313    ):
314        if len(args) > 0:
315            raise ValueError(
316                f"JsonSerializer takes no positional arguments!\n{args = }"
317            )
318
319        self.array_mode: "ArrayMode" = array_mode
320        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
321        self.write_only_format: bool = write_only_format
322        # join up the handlers
323        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
324            handlers_default
325        )
326
327    @overload
328    def json_serialize(
329        self, obj: Mapping[str, Any], path: ObjectPath = ()
330    ) -> JSONdict: ...
331    @overload
332    def json_serialize(self, obj: list, path: ObjectPath = ()) -> list: ...
333    # @overload  # pyright: ignore[reportOverlappingOverload]
334    # def json_serialize(self, obj: set, path: ObjectPath = ()) -> _SerializedSet: ...
335    # @overload
336    # def json_serialize(
337    #     self, obj: frozenset, path: ObjectPath = ()
338    # ) -> _SerializedFrozenset: ...
339    @overload
340    def json_serialize(self, obj: Any, path: ObjectPath = ()) -> JSONitem: ...
341    def json_serialize(
342        self,
343        obj: Any,  # pyright: ignore[reportAny]
344        path: ObjectPath = (),
345    ) -> JSONitem:
346        handler = None
347        try:
348            for handler in self.handlers:
349                if handler.check(self, obj, path):
350                    output: JSONitem = handler.serialize_func(self, obj, path)
351                    if self.write_only_format:
352                        if isinstance(output, dict) and _FORMAT_KEY in output:
353                            # TYPING: JSONitem has no idea that _FORMAT_KEY is str
354                            new_fmt: str = output.pop(_FORMAT_KEY)  # type: ignore  # pyright: ignore[reportAssignmentType]
355                            output["__write_format__"] = new_fmt  # type: ignore
356                    return output
357
358            raise ValueError(f"no handler found for object with {type(obj) = }")  # pyright: ignore[reportAny]
359
360        except Exception as e:
361            if self.error_mode == ErrorMode.EXCEPT:
362                obj_str: str = repr(obj)  # pyright: ignore[reportAny]
363                if len(obj_str) > 1000:
364                    obj_str = obj_str[:1000] + "..."
365                handler_uid = handler.uid if handler else "no handler matched"
366                raise SerializationException(
367                    f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
368                ) from e
369            elif self.error_mode == ErrorMode.WARN:
370                warnings.warn(
371                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
372                )
373
374            return repr(obj)  # pyright: ignore[reportAny]
375
376    def hashify(
377        self,
378        obj: Any,  # pyright: ignore[reportAny]
379        path: ObjectPath = (),
380        force: bool = True,
381    ) -> Hashableitem:
382        """try to turn any object into something hashable"""
383        data = self.json_serialize(obj, path=path)
384
385        # recursive hashify, turning dicts and lists into tuples
386        return _recursive_hashify(data, force=force)

Json serialization class (holds configs)

Parameters:

  • array_mode : ArrayMode how to write arrays (defaults to "array_list_meta")
  • error_mode : ErrorMode what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to "except")
  • handlers_pre : MonoTuple[SerializerHandler] handlers to use before the default handlers (defaults to tuple())
  • handlers_default : MonoTuple[SerializerHandler] default handlers to use (defaults to DEFAULT_HANDLERS)
  • write_only_format : bool changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading) (defaults to False)

Raises:

  • ValueError: on init, if args is not empty
  • SerializationException: on json_serialize(), if any error occurs when trying to serialize an object and error_mode is set to ErrorMode.EXCEPT
JsonSerializer( *args: None, array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta', error_mode: muutils.errormode.ErrorMode = ErrorMode.Except, handlers_pre: None = (), handlers_default: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid="set -> dict[_FORMAT_KEY: 'set', data: list(...)]", desc='sets as dicts with 
format key'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='Iterable -> list', desc='Iterables (not lists/tuples/strings) as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings')), write_only_format: bool = False)
305    def __init__(
306        self,
307        *args: None,
308        array_mode: "ArrayMode" = "array_list_meta",
309        error_mode: ErrorMode = ErrorMode.EXCEPT,
310        handlers_pre: MonoTuple[SerializerHandler] = (),
311        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
312        write_only_format: bool = False,
313    ):
314        if len(args) > 0:
315            raise ValueError(
316                f"JsonSerializer takes no positional arguments!\n{args = }"
317            )
318
319        self.array_mode: "ArrayMode" = array_mode
320        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
321        self.write_only_format: bool = write_only_format
322        # join up the handlers
323        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
324            handlers_default
325        )
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
write_only_format: bool
handlers: None
def json_serialize( self, obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, NoneType, Sequence[ForwardRef('JSONitem')], Dict[str, ForwardRef('JSONitem')]]:
341    def json_serialize(
342        self,
343        obj: Any,  # pyright: ignore[reportAny]
344        path: ObjectPath = (),
345    ) -> JSONitem:
346        handler = None
347        try:
348            for handler in self.handlers:
349                if handler.check(self, obj, path):
350                    output: JSONitem = handler.serialize_func(self, obj, path)
351                    if self.write_only_format:
352                        if isinstance(output, dict) and _FORMAT_KEY in output:
353                            # TYPING: JSONitem has no idea that _FORMAT_KEY is str
354                            new_fmt: str = output.pop(_FORMAT_KEY)  # type: ignore  # pyright: ignore[reportAssignmentType]
355                            output["__write_format__"] = new_fmt  # type: ignore
356                    return output
357
358            raise ValueError(f"no handler found for object with {type(obj) = }")  # pyright: ignore[reportAny]
359
360        except Exception as e:
361            if self.error_mode == ErrorMode.EXCEPT:
362                obj_str: str = repr(obj)  # pyright: ignore[reportAny]
363                if len(obj_str) > 1000:
364                    obj_str = obj_str[:1000] + "..."
365                handler_uid = handler.uid if handler else "no handler matched"
366                raise SerializationException(
367                    f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
368                ) from e
369            elif self.error_mode == ErrorMode.WARN:
370                warnings.warn(
371                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
372                )
373
374            return repr(obj)  # pyright: ignore[reportAny]
def hashify( self, obj: Any, path: tuple[typing.Union[str, int], ...] = (), force: bool = True) -> Union[bool, int, float, str, NoneType, Tuple[ForwardRef('Hashableitem'), ...]]:
376    def hashify(
377        self,
378        obj: Any,  # pyright: ignore[reportAny]
379        path: ObjectPath = (),
380        force: bool = True,
381    ) -> Hashableitem:
382        """try to turn any object into something hashable"""
383        data = self.json_serialize(obj, path=path)
384
385        # recursive hashify, turning dicts and lists into tuples
386        return _recursive_hashify(data, force=force)

try to turn any object into something hashable

def try_catch( func: Callable[..., ~T_FuncTryCatchReturn]) -> Callable[..., Union[~T_FuncTryCatchReturn, str]]:
116def try_catch(
117    func: Callable[..., T_FuncTryCatchReturn],
118) -> Callable[..., Union[T_FuncTryCatchReturn, str]]:
119    """wraps the function to catch exceptions, returns serialized error message on exception
120
121    returned func will return normal result on success, or error message on exception
122    """
123
124    @functools.wraps(func)
125    def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]:  # pyright: ignore[reportAny]
126        try:
127            return func(*args, **kwargs)
128        except Exception as e:
129            return f"{e.__class__.__name__}: {e}"
130
131    return newfunc

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception

def dc_eq( dc1: Any, dc2: Any, except_when_class_mismatch: bool = False, false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False) -> bool:
214def dc_eq(
215    dc1: Any,  # pyright: ignore[reportAny]
216    dc2: Any,  # pyright: ignore[reportAny]
217    except_when_class_mismatch: bool = False,
218    false_when_class_mismatch: bool = True,
219    except_when_field_mismatch: bool = False,
220) -> bool:
221    """
222    checks if two dataclasses which (might) hold numpy arrays are equal
223
224    # Parameters:
225
226    - `dc1`: the first dataclass
227    - `dc2`: the second dataclass
228    - `except_when_class_mismatch: bool`
229        if `True`, will throw `TypeError` if the classes are different.
230        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
231        (default: `False`)
232    - `false_when_class_mismatch: bool`
233        only relevant if `except_when_class_mismatch` is `False`.
234        if `True`, will return `False` if the classes are different.
235        if `False`, will attempt to compare the fields.
236    - `except_when_field_mismatch: bool`
237        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
238        if `True`, will throw `AttributeError` if the fields are different.
239        (default: `False`)
240
241    # Returns:
242    - `bool`: True if the dataclasses are equal, False otherwise
243
244    # Raises:
245    - `TypeError`: if the dataclasses are of different classes
246    - `AttributeError`: if the dataclasses have different fields
247
248    ```
249    [START]
250       ▼
251    ┌─────────────┐
252    │ dc1 is dc2? │───Yes───► (True)
253    └──────┬──────┘
254           │No
255           ▼
256    ┌───────────────┐
257    │ classes match?│───Yes───► [compare field values] ───► (True/False)
258    └──────┬────────┘
259           │No
260           ▼
261    ┌────────────────────────────┐
262    │ except_when_class_mismatch?│───Yes───► { raise TypeError }
263    └─────────────┬──────────────┘
264                  │No
265                  ▼
266    ┌────────────────────────────┐
267    │ false_when_class_mismatch? │───Yes───► (False)
268    └─────────────┬──────────────┘
269                  │No
270                  ▼
271    ┌────────────────────────────┐
272    │ except_when_field_mismatch?│───No────► [compare field values]
273    └─────────────┬──────────────┘
274                  │Yes
275                  ▼
276    ┌───────────────┐
277    │ fields match? │───Yes───► [compare field values]
278    └──────┬────────┘
279           │No
280           ▼
281    { raise AttributeError }
282    ```
283
284    """
285    if dc1 is dc2:
286        return True
287
288    if dc1.__class__ is not dc2.__class__:  # pyright: ignore[reportAny]
289        if except_when_class_mismatch:
290            # if the classes don't match, raise an error
291            raise TypeError(
292                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"  # pyright: ignore[reportAny]
293            )
294        if false_when_class_mismatch:
295            # return False immediately without attempting field comparison
296            return False
297        # classes don't match but we'll try to compare fields anyway
298        if except_when_field_mismatch:
299            dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)])  # pyright: ignore[reportAny]
300            dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)])  # pyright: ignore[reportAny]
301            fields_match: bool = set(dc1_fields) == set(dc2_fields)
302            if not fields_match:
303                # if the fields don't match, raise an error
304                raise AttributeError(
305                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
306                )
307
308    return all(
309        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))  # pyright: ignore[reportAny]
310        for fld in dataclasses.fields(dc1)  # pyright: ignore[reportAny]
311        if fld.compare
312    )

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

  • dc1: the first dataclass
  • dc2: the second dataclass
  • except_when_class_mismatch: bool if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
  • false_when_class_mismatch: bool only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
  • except_when_field_mismatch: bool only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw AttributeError if the fields are different. (default: False)

Returns:

  • bool: True if the dataclasses are equal, False otherwise

Raises:

  • TypeError: if the dataclasses are of different classes
  • AttributeError: if the dataclasses have different fields
[START]
   ▼
┌─────────────┐
│ dc1 is dc2? │───Yes───► (True)
└──────┬──────┘
       │No
       ▼
┌───────────────┐
│ classes match?│───Yes───► [compare field values] ───► (True/False)
└──────┬────────┘
       │No
       ▼
┌────────────────────────────┐
│ except_when_class_mismatch?│───Yes───► { raise TypeError }
└─────────────┬──────────────┘
              │No
              ▼
┌────────────────────────────┐
│ false_when_class_mismatch? │───Yes───► (False)
└─────────────┬──────────────┘
              │No
              ▼
┌────────────────────────────┐
│ except_when_field_mismatch?│───No────► [compare field values]
└─────────────┬──────────────┘
              │Yes
              ▼
┌───────────────┐
│ fields match? │───Yes───► [compare field values]
└──────┬────────┘
       │No
       ▼
{ raise AttributeError }
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
310@dataclass_transform(
311    field_specifiers=(serializable_field, SerializableField),
312)
313class SerializableDataclass(abc.ABC):
314    """Base class for serializable dataclasses
315
316    only for linting and type checking, still need to call `serializable_dataclass` decorator
317
318    # Usage:
319
320    ```python
321    @serializable_dataclass
322    class MyClass(SerializableDataclass):
323        a: int
324        b: str
325    ```
326
327    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
328
329        >>> my_obj = MyClass(a=1, b="q")
330        >>> s = json.dumps(my_obj.serialize())
331        >>> s
332        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
333        >>> read_obj = MyClass.load(json.loads(s))
334        >>> read_obj == my_obj
335        True
336
337    This isn't too impressive on its own, but it gets more useful when you have nested classes,
338    or fields that are not json-serializable by default:
339
340    ```python
341    @serializable_dataclass
342    class NestedClass(SerializableDataclass):
343        x: str
344        y: MyClass
345        act_fun: torch.nn.Module = serializable_field(
346            default=torch.nn.ReLU(),
347            serialization_fn=lambda x: str(x),
348            deserialize_fn=lambda x: getattr(torch.nn, x)(),
349        )
350    ```
351
352    which gives us:
353
354        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
355        >>> s = json.dumps(nc.serialize())
356        >>> s
357        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
358        >>> read_nc = NestedClass.load(json.loads(s))
359        >>> read_nc == nc
360        True
361    """
362
363    def serialize(self) -> dict[str, Any]:
364        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
365        raise NotImplementedError(
366            f"decorate {self.__class__ = } with `@serializable_dataclass`"
367        )
368
369    @overload
370    @classmethod
371    def load(cls, data: dict[str, Any]) -> Self: ...
372
373    @overload
374    @classmethod
375    def load(cls, data: Self) -> Self: ...
376
377    @classmethod
378    def load(cls, data: dict[str, Any] | Self) -> Self:
379        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
380        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
381
382    def validate_fields_types(
383        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
384    ) -> bool:
385        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
386        return SerializableDataclass__validate_fields_types(
387            self, on_typecheck_error=on_typecheck_error
388        )
389
390    def validate_field_type(
391        self,
392        field: "SerializableField|str",
393        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
394    ) -> bool:
395        """given a dataclass, check the field matches the type hint"""
396        return SerializableDataclass__validate_field_type(
397            self, field, on_typecheck_error=on_typecheck_error
398        )
399
400    def __eq__(self, other: Any) -> bool:
401        return dc_eq(self, other)
402
403    def __hash__(self) -> int:
404        "hashes the json-serialized representation of the class"
405        return hash(json.dumps(self.serialize()))
406
407    def diff(
408        self, other: "SerializableDataclass", of_serialized: bool = False
409    ) -> dict[str, Any]:
410        """get a rich and recursive diff between two instances of a serializable dataclass
411
412        ```python
413        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
414        {'b': {'self': 2, 'other': 3}}
415        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
416        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
417        ```
418
419        # Parameters:
420         - `other : SerializableDataclass`
421           other instance to compare against
422         - `of_serialized : bool`
423           if true, compare serialized data and not raw values
424           (defaults to `False`)
425
426        # Returns:
427         - `dict[str, Any]`
428
429
430        # Raises:
431         - `ValueError` : if the instances are not of the same type
432         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
433        """
434        # match types
435        if type(self) is not type(other):
436            raise ValueError(
437                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
438            )
439
440        # initialize the diff result
441        diff_result: dict = {}
442
443        # if they are the same, return the empty diff
444        try:
445            if self == other:
446                return diff_result
447        except Exception:
448            pass
449
450        # if we are working with serialized data, serialize the instances
451        if of_serialized:
452            ser_self: JSONdict = self.serialize()
453            ser_other: JSONdict = other.serialize()
454
455        # for each field in the class
456        for field in dataclasses.fields(self):  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
457            # skip fields that are not for comparison
458            if not field.compare:
459                continue
460
461            # get values
462            field_name: str = field.name
463            self_value = getattr(self, field_name)
464            other_value = getattr(other, field_name)
465
466            # if the values are both serializable dataclasses, recurse
467            if isinstance(self_value, SerializableDataclass) and isinstance(
468                other_value, SerializableDataclass
469            ):
470                nested_diff: dict = self_value.diff(
471                    other_value, of_serialized=of_serialized
472                )
473                if nested_diff:
474                    diff_result[field_name] = nested_diff
475            # only support serializable dataclasses
476            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
477                other_value
478            ):
479                raise ValueError("Non-serializable dataclass is not supported")
480            else:
481                # get the values of either the serialized or the actual values
482                if of_serialized:
483                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
484                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
485                else:
486                    self_value_s = self_value
487                    other_value_s = other_value
488                # compare the values
489                if not array_safe_eq(self_value_s, other_value_s):
490                    diff_result[field_name] = {"self": self_value, "other": other_value}
491
492        # return the diff result
493        return diff_result
494
495    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
496        """update the instance from a nested dict, useful for configuration from command line args
497
498        # Parameters:
499            - `nested_dict : dict[str, Any]`
500                nested dict to update the instance with
501        """
502        for field in dataclasses.fields(self):  # type: ignore[arg-type]
503            field_name: str = field.name
504            self_value = getattr(self, field_name)
505
506            if field_name in nested_dict:
507                if isinstance(self_value, SerializableDataclass):
508                    self_value.update_from_nested_dict(nested_dict[field_name])
509                else:
510                    setattr(self, field_name, nested_dict[field_name])
511
512    def __copy__(self) -> "SerializableDataclass":
513        "deep copy by serializing and loading the instance to json"
514        return self.__class__.load(json.loads(json.dumps(self.serialize())))
515
516    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
517        "deep copy by serializing and loading the instance to json"
518        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]:
def serialize(self) -> dict[str, Any]:
    """Stub: the real serializer is injected by the `@serializable_dataclass` decorator.

    Exists only so linters/type-checkers see the method on the base class.
    """
    msg: str = f"decorate {self.__class__ = } with `@serializable_dataclass`"
    raise NotImplementedError(msg)

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], Self]) -> Self:
377    @classmethod
378    def load(cls, data: dict[str, Any] | Self) -> Self:
379        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
380        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def validate_fields_types(
    self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
) -> bool:
    """Check that every field's value matches its type hint.

    Thin wrapper delegating to `SerializableDataclass__validate_fields_types`,
    which calls `SerializableDataclass__validate_field_type` per field.
    """
    return SerializableDataclass__validate_fields_types(self, on_typecheck_error=on_typecheck_error)

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def validate_field_type(
    self,
    field: "SerializableField|str",
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """Check that a single field (given as a `SerializableField` or its name) matches its type hint.

    Thin wrapper delegating to `SerializableDataclass__validate_field_type`.
    """
    return SerializableDataclass__validate_field_type(self, field, on_typecheck_error=on_typecheck_error)

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
def diff(
    self, other: "SerializableDataclass", of_serialized: bool = False
) -> dict[str, Any]:
    """get a rich and recursive diff between two instances of a serializable dataclass

    ```python
    >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
    {'b': {'self': 2, 'other': 3}}
    >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
    {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
    ```

    # Parameters:
     - `other : SerializableDataclass`
       other instance to compare against
     - `of_serialized : bool`
       if true, compare serialized data and not raw values
       (defaults to `False`)

    # Returns:
     - `dict[str, Any]`
       mapping of differing field names to `{'self': ..., 'other': ...}`,
       with nested serializable dataclasses diffed recursively

    # Raises:
     - `ValueError` : if the instances are not of the same type
     - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
    """
    # refuse to diff across types
    if type(self) is not type(other):
        raise ValueError(
            f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
        )

    result: dict = {}

    # equal instances have an empty diff; equality itself may raise for
    # exotic field types, in which case fall through to per-field comparison
    try:
        if self == other:
            return result
    except Exception:
        pass

    # serialize up front when comparing serialized forms
    if of_serialized:
        ser_self: JSONdict = self.serialize()
        ser_other: JSONdict = other.serialize()

    for fld in dataclasses.fields(self):  # type: ignore[arg-type]
        # honor `compare=False` fields
        if not fld.compare:
            continue

        name: str = fld.name
        val_self = getattr(self, name)
        val_other = getattr(other, name)

        # recurse when both sides are serializable dataclasses
        if isinstance(val_self, SerializableDataclass) and isinstance(
            val_other, SerializableDataclass
        ):
            sub_diff: dict = val_self.diff(val_other, of_serialized=of_serialized)
            if sub_diff:
                result[name] = sub_diff
            continue

        # plain (non-serializable) dataclasses are explicitly unsupported
        if dataclasses.is_dataclass(val_self) and dataclasses.is_dataclass(val_other):
            raise ValueError("Non-serializable dataclass is not supported")

        # compare serialized or raw values, but always report raw values
        cmp_self = ser_self[name] if of_serialized else val_self  # pyright: ignore[reportPossiblyUnboundVariable]
        cmp_other = ser_other[name] if of_serialized else val_other  # pyright: ignore[reportPossiblyUnboundVariable]
        if not array_safe_eq(cmp_self, cmp_other):
            result[name] = {"self": val_self, "other": val_other}

    return result

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • `other : SerializableDataclass` — other instance to compare against
  • `of_serialized : bool` — if true, compare serialized data rather than raw values (defaults to `False`)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
def update_from_nested_dict(self, nested_dict: dict[str, Any]):
    """update the instance from a nested dict, useful for configuration from command line args

    # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
    """
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        name = field.name
        value = getattr(self, name)
        if name in nested_dict:
            incoming = nested_dict[name]
            # nested serializable dataclasses are merged recursively in place;
            # any other field is simply replaced
            if isinstance(value, SerializableDataclass):
                value.update_from_nested_dict(incoming)
            else:
                setattr(self, name, incoming)

update the instance from a nested dict, useful for configuration from command line args

Parameters:

- `nested_dict : dict[str, Any]`
    nested dict to update the instance with