muutils.json_serialize
submodule for serializing things to json in a recoverable way
you can throw *any* object into `muutils.json_serialize.json_serialize` and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`.

The goal of this is: if you just want to store something as relatively human-readable JSON and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).
It does so by looking through `DEFAULT_HANDLERS`, which keep the object as-is if it's already valid JSON, then try to find a `.serialize()` method on the object, and then fall back to a number of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`.
Additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.
This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
```python
"""submodule for serializing things to json in a recoverable way

you can throw *any* object into `muutils.json_serialize.json_serialize`
and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`.

The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).

it will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if it's already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`

additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields.

This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
"""

from __future__ import annotations

from muutils.json_serialize.array import arr_metadata, load_array
from muutils.json_serialize.json_serialize import (
    BASE_HANDLERS,
    JsonSerializer,
    json_serialize,
)
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import try_catch, JSONitem, dc_eq

__all__ = [
    # submodules
    "array",
    "json_serialize",
    "serializable_dataclass",
    "serializable_field",
    "util",
    # imports
    "arr_metadata",
    "load_array",
    "BASE_HANDLERS",
    "JSONitem",
    "JsonSerializer",
    "json_serialize",
    "try_catch",
    "dc_eq",
    "serializable_dataclass",
    "serializable_field",
    "SerializableDataclass",
]
```
```python
def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)
```
serialize object to json-serializable object with default config
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
    - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
    - `init : bool`
        (defaults to `True`)
    - `repr : bool`
        (defaults to `True`)
    - `order : bool`
        (defaults to `False`)
    - `unsafe_hash : bool`
        (defaults to `False`)
    - `frozen : bool`
        (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
        **SerializableDataclass only:** which properties to add to the serialized data dict
        (defaults to `None`)
    - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
    - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
    - `_type_`
        the decorated class

    # Raises:
    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is < 3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python < 3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python < 3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                output.validate_fields_types(on_typecheck_error=on_typecheck_error)

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
````
decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 `__annotations__` to determine fields.

If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.
```python
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
```
```python
>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
```
# Parameters:

- `_cls : _type_`
    class to decorate. don't pass this arg, just use this as a decorator
    (defaults to `None`)
- `init : bool`
    (defaults to `True`)
- `repr : bool`
    (defaults to `True`)
- `order : bool`
    (defaults to `False`)
- `unsafe_hash : bool`
    (defaults to `False`)
- `frozen : bool`
    (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
    **SerializableDataclass only:** which properties to add to the serialized data dict
    (defaults to `None`)
- `register_handler : bool`
    **SerializableDataclass only:** if true, register the class with ZANJ for loading
    (defaults to `True`)
- `on_typecheck_error : ErrorMode`
    **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
- `on_typecheck_mismatch : ErrorMode`
    **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

# Returns:

- `_type_`
    the decorated class

# Raises:

- `KWOnlyError` : only raised if `kw_only` is `True` and python version is < 3.10, since `dataclasses.dataclass` does not support this
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
````python
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that simply recovers the field, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"])
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x)
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )
````
Create a new `SerializableField`

```
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
```
# new Parameters:

- `serialize`: whether to serialize this field when serializing the class
- `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
- `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
- `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
- `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.
# Gotchas:

- `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that simply recovers the field, you'd write:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )
```

using `deserialize_fn` instead:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )
```

In the above code, `my_field` is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide `serialization_fn` and `loading_fn` to serialize and load the container. ZANJ will automatically do this for you.
```python
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
        ),
        "n_elements": array_n_elements(arr),
    }
```
get metadata for a numpy array
```python
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")
```
load a json-serialized array, infer the mode if not specified
```python
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
        changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

        return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
```
Json serialization class (holds configs)

# Parameters:

- `array_mode : ArrayMode`
    how to write arrays
    (defaults to `"array_list_meta"`)
- `error_mode : ErrorMode`
    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
    (defaults to `"except"`)
- `handlers_pre : MonoTuple[SerializerHandler]`
    handlers to use before the default handlers
    (defaults to `tuple()`)
- `handlers_default : MonoTuple[SerializerHandler]`
    default handlers to use
    (defaults to `DEFAULT_HANDLERS`)
- `write_only_format : bool`
    changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
    (defaults to `False`)
# Raises:

- `ValueError`: on init, if `args` is not empty
- `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
```python
    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )
```
```python
    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

        return repr(obj)
```
```python
def hashify(
    self,
    obj: Any,
    path: ObjectPath = tuple(),
    force: bool = True,
) -> Hashableitem:
    """try to turn any object into something hashable"""
    data = self.json_serialize(obj, path=path)

    # recursive hashify, turning dicts and lists into tuples
    return _recursive_hashify(data, force=force)
```
try to turn any object into something hashable
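The `_recursive_hashify` step can be sketched as follows. This is an illustrative reimplementation, not the actual muutils helper, whose exact behavior may differ:

```python
from typing import Any

def recursive_hashify(data: Any, force: bool = True) -> Any:
    # dicts become tuples of (key, value) pairs
    if isinstance(data, dict):
        return tuple((k, recursive_hashify(v, force)) for k, v in data.items())
    # lists and tuples become tuples of hashified elements
    if isinstance(data, (list, tuple)):
        return tuple(recursive_hashify(v, force) for v in data)
    # primitives are already hashable
    if data is None or isinstance(data, (bool, int, float, str)):
        return data
    if force:
        # fall back to the string representation for anything else
        return repr(data)
    raise ValueError(f"cannot hashify {data = }")
```

Since every container is converted to a tuple, the result can be passed to the built-in `hash()`.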
```python
def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc
```
wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
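For example, wrapping a fallible function with `try_catch` (the decorator body is reproduced from the source above; `divide` is a made-up example function):

```python
import functools
from typing import Callable

def try_catch(func: Callable):
    """wraps the function to catch exceptions, returning the error message on failure"""

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # serialize the exception as "<ClassName>: <message>"
            return f"{e.__class__.__name__}: {e}"

    return newfunc

@try_catch
def divide(a: float, b: float) -> float:
    return a / b
```

`divide(6, 3)` returns `2.0` as normal, while `divide(1, 0)` returns the string `"ZeroDivisionError: division by zero"` instead of raising.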
````python
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return `False` by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    ```
    [START]
    ▼
    ┌───────────┐ ┌─────────┐
    │dc1 is dc2?├─►│ classes │
    └──┬────────┘No│ match? │
    ──── │ ├─────────┤
    (True)◄──┘Yes │No │Yes
    ──── ▼ ▼
    ┌────────────────┐ ┌────────────┐
    │ except when │ │ fields keys│
    │ class mismatch?│ │ match? │
    ├───────────┬────┘ ├───────┬────┘
    │Yes │No │No │Yes
    ▼ ▼ ▼ ▼
    ─────────── ┌──────────┐ ┌────────┐
    { raise } │ except │ │ field │
    { TypeError } │ when │ │ values │
    ─────────── │ field │ │ match? │
    │ mismatch?│ ├────┬───┘
    ├───────┬──┘ │ │Yes
    │Yes │No │No ▼
    ▼ ▼ │ ────
    ─────────────── ───── │ (True)
    { raise } (False)◄┘ ────
    { AttributeError} ─────
    ───────────────
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if except_when_field_mismatch:
            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
            fields_match: bool = set(dc1_fields) == set(dc2_fields)
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
        return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
````
checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

- `dc1`: the first dataclass
- `dc2`: the second dataclass
- `except_when_class_mismatch: bool`
  if `True`, will throw `TypeError` if the classes are different.
  if not, will return `False` by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
  (default: `False`)
- `false_when_class_mismatch: bool`
  only relevant if `except_when_class_mismatch` is `False`.
  if `True`, will return `False` if the classes are different.
  if `False`, will attempt to compare the fields.
- `except_when_field_mismatch: bool`
  only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
  if `True`, will throw `AttributeError` if the fields are different.
  (default: `False`)

Returns:

- `bool`: `True` if the dataclasses are equal, `False` otherwise

Raises:

- `TypeError`: if the dataclasses are of different classes
- `AttributeError`: if the dataclasses have different fields
```
[START]
▼
┌───────────┐ ┌─────────┐
│dc1 is dc2?├─►│ classes │
└──┬────────┘No│ match? │
──── │ ├─────────┤
(True)◄──┘Yes │No │Yes
──── ▼ ▼
┌────────────────┐ ┌────────────┐
│ except when │ │ fields keys│
│ class mismatch?│ │ match? │
├───────────┬────┘ ├───────┬────┘
│Yes │No │No │Yes
▼ ▼ ▼ ▼
─────────── ┌──────────┐ ┌────────┐
{ raise } │ except │ │ field │
{ TypeError } │ when │ │ values │
─────────── │ field │ │ match? │
│ mismatch?│ ├────┬───┘
├───────┬──┘ │ │Yes
│Yes │No │No ▼
▼ ▼ │ ────
─────────────── ───── │ (True)
{ raise } (False)◄┘ ────
{ AttributeError} ─────
───────────────
```
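The decision flow above can be exercised with a simplified sketch that uses plain `==` in place of `array_safe_eq` and omits the `false_when_class_mismatch` branch; the names here are illustrative, not the muutils implementation:

```python
import dataclasses

def dc_eq_sketch(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    except_when_field_mismatch: bool = False,
) -> bool:
    # identity short-circuit
    if dc1 is dc2:
        return True
    # class mismatch: raise, or check field names, or just return False
    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            raise TypeError("cannot compare dataclasses of different classes")
        if except_when_field_mismatch:
            fields1 = {f.name for f in dataclasses.fields(dc1)}
            fields2 = {f.name for f in dataclasses.fields(dc2)}
            if fields1 != fields2:
                raise AttributeError("dataclasses have different fields")
        return False
    # same class: compare all fields marked for comparison
    return all(
        getattr(dc1, f.name) == getattr(dc2, f.name)
        for f in dataclasses.fields(dc1)
        if f.compare
    )

@dataclasses.dataclass
class Point:
    x: int
    y: int

@dataclasses.dataclass
class Pair:
    x: int
    y: int
```

`Point` and `Pair` have identical fields but are different classes, so `dc_eq_sketch(Point(1, 2), Pair(1, 2))` returns `False` by default, matching the "classes match? No" branch of the flowchart.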
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`

        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
````
Base class for serializable dataclasses
only for linting and type checking, still need to call the `serializable_dataclass` decorator
Usage:
```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```
and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
```python
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
```
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```
which gives us:
```python
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
```
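The round trip in the doctest above can be mimicked by hand on a plain dataclass, to show what the decorator generates for you. This is a sketch that hard-codes the `__format__` string and field list that `serializable_dataclass` would derive automatically:

```python
import dataclasses
import json

@dataclasses.dataclass
class MyClass:
    a: int
    b: str

    def serialize(self) -> dict:
        # emit the fields plus a __format__ key identifying the class
        return {
            "__format__": "MyClass(SerializableDataclass)",
            "a": self.a,
            "b": self.b,
        }

    @classmethod
    def load(cls, data: dict) -> "MyClass":
        # ignore the __format__ marker and rebuild from the fields
        return cls(a=data["a"], b=data["b"])

my_obj = MyClass(a=1, b="q")
s = json.dumps(my_obj.serialize())
read_obj = MyClass.load(json.loads(s))
assert read_obj == my_obj
```

The decorator's value is exactly that this boilerplate, plus type validation and nested-class handling, is generated from the field definitions.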
```python
def serialize(self) -> dict[str, Any]:
    "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
    raise NotImplementedError(
        f"decorate {self.__class__ = } with `@serializable_dataclass`"
    )
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
```python
@classmethod
def load(cls: Type[T], data: dict[str, Any] | T) -> T:
    "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
    raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
```python
def validate_fields_types(
    self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return SerializableDataclass__validate_fields_types(
        self, on_typecheck_error=on_typecheck_error
    )
```
validate the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field
```python
def validate_field_type(
    self,
    field: "SerializableField|str",
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint"""
    return SerializableDataclass__validate_field_type(
        self, field, on_typecheck_error=on_typecheck_error
    )
```
given a dataclass, check the field matches the type hint
````python
def diff(
    self, other: "SerializableDataclass", of_serialized: bool = False
) -> dict[str, Any]:
    """get a rich and recursive diff between two instances of a serializable dataclass

    ```python
    >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
    {'b': {'self': 2, 'other': 3}}
    >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
    {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
    ```

    # Parameters:
    - `other : SerializableDataclass`
        other instance to compare against
    - `of_serialized : bool`
        if true, compare serialized data and not raw values
        (defaults to `False`)

    # Returns:
    - `dict[str, Any]`

    # Raises:
    - `ValueError` : if the instances are not of the same type
    - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
    """
    # match types
    if type(self) is not type(other):
        raise ValueError(
            f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
        )

    # initialize the diff result
    diff_result: dict = {}

    # if they are the same, return the empty diff
    if self == other:
        return diff_result

    # if we are working with serialized data, serialize the instances
    if of_serialized:
        ser_self: dict = self.serialize()
        ser_other: dict = other.serialize()

    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # skip fields that are not for comparison
        if not field.compare:
            continue

        # get values
        field_name: str = field.name
        self_value = getattr(self, field_name)
        other_value = getattr(other, field_name)

        # if the values are both serializable dataclasses, recurse
        if isinstance(self_value, SerializableDataclass) and isinstance(
            other_value, SerializableDataclass
        ):
            nested_diff: dict = self_value.diff(
                other_value, of_serialized=of_serialized
            )
            if nested_diff:
                diff_result[field_name] = nested_diff
        # only support serializable dataclasses
        elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
            other_value
        ):
            raise ValueError("Non-serializable dataclass is not supported")
        else:
            # get the values of either the serialized or the actual values
            self_value_s = ser_self[field_name] if of_serialized else self_value
            other_value_s = ser_other[field_name] if of_serialized else other_value
            # compare the values
            if not array_safe_eq(self_value_s, other_value_s):
                diff_result[field_name] = {"self": self_value, "other": other_value}

    # return the diff result
    return diff_result
````
get a rich and recursive diff between two instances of a serializable dataclass
```python
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```
Parameters:

- `other : SerializableDataclass`
  other instance to compare against
- `of_serialized : bool`
  if true, compare serialized data and not raw values
  (defaults to `False`)
Returns:

- `dict[str, Any]`

Raises:

- `ValueError`: if the instances are not of the same type
- `ValueError`: if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
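The shape of the diff output can be reproduced on plain nested dicts. A minimal sketch, assuming both dicts share the same keys (unlike the field-aware method above, which walks dataclass fields):

```python
from typing import Any

def dict_diff(self_d: dict, other_d: dict) -> dict[str, Any]:
    """recursive diff producing the same {'self': ..., 'other': ...} leaf shape"""
    out: dict[str, Any] = {}
    for key in self_d:
        a, b = self_d[key], other_d[key]
        if isinstance(a, dict) and isinstance(b, dict):
            # recurse into nested mappings, keeping only non-empty sub-diffs
            nested = dict_diff(a, b)
            if nested:
                out[key] = nested
        elif a != b:
            out[key] = {"self": a, "other": b}
    return out
```

Equal branches disappear from the result entirely, so an empty dict means the two structures match.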
```python
def update_from_nested_dict(self, nested_dict: dict[str, Any]):
    """update the instance from a nested dict, useful for configuration from command line args

    # Parameters:
    - `nested_dict : dict[str, Any]`
        nested dict to update the instance with
    """
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        field_name: str = field.name
        self_value = getattr(self, field_name)

        if field_name in nested_dict:
            if isinstance(self_value, SerializableDataclass):
                self_value.update_from_nested_dict(nested_dict[field_name])
            else:
                setattr(self, field_name, nested_dict[field_name])
```
update the instance from a nested dict, useful for configuration from command line args
Parameters:
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
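The same recursion can be sketched on plain dataclasses, with a free function standing in for the method; `Inner` and `Config` are invented example classes:

```python
import dataclasses

def update_from_nested_dict(obj, nested: dict) -> None:
    # recurse into dataclass-valued fields, overwrite leaf fields in place
    for f in dataclasses.fields(obj):
        if f.name in nested:
            value = getattr(obj, f.name)
            if dataclasses.is_dataclass(value):
                update_from_nested_dict(value, nested[f.name])
            else:
                setattr(obj, f.name, nested[f.name])

@dataclasses.dataclass
class Inner:
    lr: float = 0.1

@dataclasses.dataclass
class Config:
    name: str = "run"
    inner: Inner = dataclasses.field(default_factory=Inner)

cfg = Config()
update_from_nested_dict(cfg, {"name": "exp1", "inner": {"lr": 0.01}})
```

Keys absent from the nested dict are left untouched, which is what makes this useful for partial overrides from command line arguments.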