muutils.json_serialize
submodule for serializing things to json in a recoverable way
you can throw any object into muutils.json_serialize.json_serialize
and it will return a JSONitem, meaning a bool, int, float, str, None, a list of JSONitems, or a dict mapping to JSONitem.
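As a sketch, the recursive `JSONitem` type described above can be written as a type alias (the real alias lives in `muutils.json_serialize.util`):

```python
from typing import Dict, List, Union, get_args

# a sketch of the recursive JSON type described above; the real alias
# lives in muutils.json_serialize.util
JSONitem = Union[bool, int, float, str, None, List["JSONitem"], Dict[str, "JSONitem"]]

# every JSON document is built from these seven cases
assert str in get_args(JSONitem)
```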
The goal of this is that if you just want to store something as relatively human-readable JSON and don't care as much about recovering it, you can throw it into json_serialize and it will just work. If you want to do so in a recoverable way, check out ZANJ.
it does so by looking through DEFAULT_HANDLERS, which will keep the object as-is if it's already valid JSON, then try to find a .serialize() method on the object, and then fall back to a number of special cases. You can add your own handlers by initializing a JsonSerializer object and passing a sequence of them to handlers_pre.
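That lookup order can be sketched in a few lines (this is NOT muutils' actual handler implementation, just the same idea):

```python
# minimal sketch of the handler-chain idea: primitives pass through,
# then an object's own .serialize() is preferred, then special cases,
# with repr() as the last resort
def sketch_serialize(obj):
    # 1) already-valid JSON primitives pass through unchanged
    if isinstance(obj, (bool, int, float, str, type(None))):
        return obj
    # 2) prefer the object's own .serialize() method if it has one
    if hasattr(obj, "serialize") and callable(obj.serialize):
        return sketch_serialize(obj.serialize())
    # 3) special cases: containers recurse (dict keys become strings,
    #    sequences become lists, since JSON has no tuples or sets)
    if isinstance(obj, dict):
        return {str(k): sketch_serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [sketch_serialize(x) for x in obj]
    # 4) last resort: a repr string (the real library has many more handlers)
    return repr(obj)


class _Point:
    def serialize(self):
        return {"x": 1, "y": 2}


assert sketch_serialize([_Point(), True]) == [{"x": 1, "y": 2}, True]
```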
additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and using serializable_field in place of dataclasses.field when defining non-standard fields.
This module plays nicely with and is a dependency of the ZANJ library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
```python
"""submodule for serializing things to json in a recoverable way

you can throw *any* object into `muutils.json_serialize.json_serialize`
and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`.

The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).

it will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if it's already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`

additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields.

This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.

"""

from __future__ import annotations

from muutils.json_serialize.array import arr_metadata, load_array
from muutils.json_serialize.json_serialize import (
    BASE_HANDLERS,
    JsonSerializer,
    json_serialize,
)
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import try_catch, JSONitem, dc_eq

__all__ = [
    # submodules
    "array",
    "json_serialize",
    "serializable_dataclass",
    "serializable_field",
    "util",
    # imports
    "arr_metadata",
    "load_array",
    "BASE_HANDLERS",
    "JSONitem",
    "JsonSerializer",
    "json_serialize",
    "try_catch",
    "dc_eq",
    "serializable_dataclass",
    "serializable_field",
    "SerializableDataclass",
]
```
```python
def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem:  # pyright: ignore[reportAny]
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)
```
serialize object to json-serializable object with default config
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs: Any,
) -> Any:
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
    - `init : bool`
        whether to add an `__init__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `repr : bool`
        whether to add a `__repr__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `order : bool`
        whether to add rich comparison methods
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `unsafe_hash : bool`
        whether to add a `__hash__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `frozen : bool`
        whether to make the class frozen
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
        which properties to add to the serialized data dict
        **SerializableDataclass only**
        (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
        the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self: Any) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    value: Any = None  # init before try in case getattr raises
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):  # pyright: ignore[reportAttributeAccessIssue]
                            value = value.serialize()  # pyright: ignore[reportAttributeAccessIssue]
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value or '<unavailable>' = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(
            cls: type[T_SerializeableDataclass],
            data: dict[str, Any] | T_SerializeableDataclass,
        ) -> T_SerializeableDataclass:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(data, typing.Mapping), (
                f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
            )

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            # mypy doesn't recognize @dataclass_transform for dataclasses.fields()
            # https://github.com/python/mypy/issues/16241
            for field in dataclasses.fields(cls):  # type: ignore[arg-type]
                # check if the field is a SerializableField
                assert isinstance(field, SerializableField), (
                    f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
                )

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: T_SerializeableDataclass = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined, method-assign]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined, method-assign, assignment]

        if "validate_field_type" not in _methods_no_override:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined, method-assign]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
````
decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!
types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE
behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine fields.
If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation.
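Since these kwargs are forwarded to `dataclasses.dataclass`, the standard-library behavior applies. Illustrating with plain `dataclasses.dataclass` (no muutils involved):

```python
import dataclasses

@dataclasses.dataclass(frozen=True, order=True)
class Point:
    x: int
    y: int

p = Point(1, 2)
assert Point(1, 2) == p      # eq=True (the default) adds __eq__
assert Point(0, 0) < p       # order=True adds rich comparison methods
try:
    p.x = 5                  # frozen=True forbids assignment after creation
except dataclasses.FrozenInstanceError:
    pass
```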
```python
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
```
```python
>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
```
Parameters:
- `_cls : _type_`
    class to decorate. don't pass this arg, just use this as a decorator
    (defaults to `None`)
- `init : bool`
    whether to add an `__init__` method
    *(passed to dataclasses.dataclass)*
    (defaults to `True`)
- `repr : bool`
    whether to add a `__repr__` method
    *(passed to dataclasses.dataclass)*
    (defaults to `True`)
- `order : bool`
    whether to add rich comparison methods
    *(passed to dataclasses.dataclass)*
    (defaults to `False`)
- `unsafe_hash : bool`
    whether to add a `__hash__` method
    *(passed to dataclasses.dataclass)*
    (defaults to `False`)
- `frozen : bool`
    whether to make the class frozen
    *(passed to dataclasses.dataclass)*
    (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
    which properties to add to the serialized data dict
    **SerializableDataclass only**
    (defaults to `None`)
- `register_handler : bool`
    if true, register the class with ZANJ for loading
    **SerializableDataclass only**
    (defaults to `True`)
- `on_typecheck_error : ErrorMode`
    what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
    **SerializableDataclass only**
- `on_typecheck_mismatch : ErrorMode`
    what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
    **SerializableDataclass only**
- `methods_no_override : list[str]|None`
    list of methods that should not be overridden by the decorator
    by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
    but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
    **SerializableDataclass only**
    (defaults to `None`)
- `**kwargs`
    *(passed to dataclasses.dataclass)*
Returns:
- `_type_`
    the decorated class
Raises:
- `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
````python
def serializable_field(  # general implementation
    *_args: Any,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    doc: str | None = None,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    doc: str | None = None,  # new in python 3.14. can alternately pass `description` to match pydantic, but this is discouraged
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"]),
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"

    if "description" in kwargs:
        import warnings

        warnings.warn(
            "`description` is deprecated, use `doc` instead",
            DeprecationWarning,
        )
        if doc is not None:
            err_msg: str = f"cannot pass both `doc` and `description`: {doc=}, {kwargs['description']=}"
            raise ValueError(err_msg)
        doc = kwargs.pop("description")

    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        doc=doc,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )
````
Create a new SerializableField
```
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
doc: str | None = None,  # new in python 3.14. can alternately pass `description` to match pydantic, but this is discouraged
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
```
new Parameters:
- `serialize`: whether to serialize this field when serializing the class
- `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
- `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
- `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
- `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.
Gotchas:
`loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:
```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"]),
    )
```
using deserialize_fn instead:
```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x),
    )
```
In the above code, my_field is an int but will be serialized as a string.
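The string round-trip can be checked in isolation by treating the two hooks as plain functions (hypothetical stand-ins, detached from the field machinery):

```python
# the two hooks from the example above, as plain functions
serialization_fn = lambda x: str(x)
deserialize_fn = lambda x: int(x)

serialized = serialization_fn(42)
assert serialized == "42"                # stored as a string in the JSON
assert deserialize_fn(serialized) == 42  # recovered as an int on load
```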
note that if not using ZANJ, and you have a class inside a container, you MUST provide
serialization_fn and loading_fn to serialize and load the container.
ZANJ will automatically do this for you.
TODO: custom_value_check_fn: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
```python
def arr_metadata(arr: Any) -> ArrayMetadata:  # pyright: ignore[reportAny]
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),  # pyright: ignore[reportAny]
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)  # pyright: ignore[reportAny]
        ),
        "n_elements": array_n_elements(arr),
    }
```
get metadata for a numpy array
```python
def load_array(
    arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList],
    array_mode: Optional[ArrayMode] = None,
) -> np.ndarray:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray):
        assert array_mode is None, (
            "array_mode should not be specified when loading a numpy array, since that is a no-op"
        )
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "list":
        assert isinstance(arr, typing.Sequence), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        return np.array(arr)  # type: ignore
    elif array_mode == "external":
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(  # pyright: ignore[reportUnreachable]
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        # we can ignore here since we assume ZANJ has taken care of it
        return arr["data"]  # type: ignore[return-value] # pyright: ignore[reportReturnType]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])  # ty: ignore[invalid-argument-type]
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")  # pyright: ignore[reportUnreachable]
```
load a json-serialized array, infer the mode if not specified
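For instance, the `array_list_meta` branch above expects a mapping with `data`, `dtype`, and `shape` keys. A sketch of that format (assuming numpy is installed; this mirrors what the branch does, it is not `load_array` itself):

```python
import numpy as np

# a serialized payload in the "array_list_meta" format described above
payload = {
    "data": [[1, 2, 3], [4, 5, 6]],
    "dtype": "int64",
    "shape": [2, 3],
}

# what the array_list_meta branch does: build the array, then verify the shape
data = np.array(payload["data"], dtype=payload["dtype"])
assert tuple(payload["shape"]) == tuple(data.shape)
```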
```python
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
        changes `_FORMAT_KEY` keys in output to `"__write_format__"` (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
    """

    def __init__(
        self,
        *args: None,
        array_mode: "ArrayMode" = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = (),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: "ArrayMode" = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    @overload
    def json_serialize(
        self, obj: Mapping[str, Any], path: ObjectPath = ()
    ) -> JSONdict: ...
    @overload
    def json_serialize(self, obj: list, path: ObjectPath = ()) -> list: ...
    @overload
    def json_serialize(self, obj: Any, path: ObjectPath = ()) -> JSONitem: ...
    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = (),
    ) -> JSONitem:
        handler = None
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and _FORMAT_KEY in output:
                            new_fmt: str = output.pop(_FORMAT_KEY)
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == ErrorMode.EXCEPT:
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                handler_uid = handler.uid if handler else "no handler matched"
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == ErrorMode.WARN:
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

        return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = (),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
```
Json serialization class (holds configs)
Parameters:
- `array_mode : ArrayMode`
  how to write arrays
  (defaults to `"array_list_meta"`)
- `error_mode : ErrorMode`
  what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
  (defaults to `"except"`)
- `handlers_pre : MonoTuple[SerializerHandler]`
  handlers to use before the default handlers
  (defaults to `tuple()`)
- `handlers_default : MonoTuple[SerializerHandler]`
  default handlers to use
  (defaults to `DEFAULT_HANDLERS`)
- `write_only_format : bool`
  changes `_FORMAT_KEY` keys in output to `"__write_format__"` (when you want to serialize something in a way that zanj won't try to recover the object when loading)
  (defaults to `False`)
Raises:
- `ValueError`: on init, if `args` is not empty
- `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
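The handler loop at the heart of `json_serialize` can be illustrated with a minimal standalone sketch. This is not the real muutils API; `Handler`, `sketch_serialize`, and `DEFAULT_SKETCH_HANDLERS` are illustrative names, and the real `SerializerHandler` takes the serializer and path as extra arguments:

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence

@dataclass
class Handler:
    # illustrative stand-in for muutils' SerializerHandler
    check: Callable[[Any], bool]
    serialize: Callable[[Any], Any]

DEFAULT_SKETCH_HANDLERS = (
    # keep already-valid JSON values as-is
    Handler(lambda o: isinstance(o, (bool, int, float, str, type(None))), lambda o: o),
    # prefer an object's own .serialize() method
    Handler(lambda o: hasattr(o, "serialize"), lambda o: o.serialize()),
    # recurse into containers
    Handler(lambda o: isinstance(o, dict),
            lambda o: {str(k): sketch_serialize(v) for k, v in o.items()}),
    Handler(lambda o: isinstance(o, (list, tuple)),
            lambda o: [sketch_serialize(x) for x in o]),
)

def sketch_serialize(obj: Any, handlers_pre: Sequence[Handler] = ()) -> Any:
    # pre-handlers run before the defaults, mirroring JsonSerializer(handlers_pre=...)
    for h in (*handlers_pre, *DEFAULT_SKETCH_HANDLERS):
        if h.check(obj):
            return h.serialize(obj)
    # no handler matched: fall back to repr, like error_mode "warn"/"ignore"
    return repr(obj)
```

Because handlers are tried in order, a handler passed via `handlers_pre` can override any default behavior for objects it matches.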
```python
def __init__(
    self,
    *args: None,
    array_mode: "ArrayMode" = "array_list_meta",
    error_mode: ErrorMode = ErrorMode.EXCEPT,
    handlers_pre: MonoTuple[SerializerHandler] = (),
    handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
    write_only_format: bool = False,
):
    if len(args) > 0:
        raise ValueError(
            f"JsonSerializer takes no positional arguments!\n{args = }"
        )

    self.array_mode: "ArrayMode" = array_mode
    self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
    self.write_only_format: bool = write_only_format
    # join up the handlers
    self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
        handlers_default
    )
```
```python
def json_serialize(
    self,
    obj: Any,
    path: ObjectPath = (),
) -> JSONitem:
    handler = None
    try:
        for handler in self.handlers:
            if handler.check(self, obj, path):
                output: JSONitem = handler.serialize_func(self, obj, path)
                if self.write_only_format:
                    if isinstance(output, dict) and _FORMAT_KEY in output:
                        new_fmt: str = output.pop(_FORMAT_KEY)
                        output["__write_format__"] = new_fmt
                return output

        raise ValueError(f"no handler found for object with {type(obj) = }")

    except Exception as e:
        if self.error_mode == ErrorMode.EXCEPT:
            obj_str: str = repr(obj)
            if len(obj_str) > 1000:
                obj_str = obj_str[:1000] + "..."
            handler_uid = handler.uid if handler else "no handler matched"
            raise SerializationException(
                f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
            ) from e
        elif self.error_mode == ErrorMode.WARN:
            warnings.warn(
                f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
            )

    return repr(obj)
```
```python
def hashify(
    self,
    obj: Any,
    path: ObjectPath = (),
    force: bool = True,
) -> Hashableitem:
    """try to turn any object into something hashable"""
    data = self.json_serialize(obj, path=path)

    # recursive hashify, turning dicts and lists into tuples
    return _recursive_hashify(data, force=force)
```
try to turn any object into something hashable
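The source of `_recursive_hashify` is not shown above, but based on the comment ("turning dicts and lists into tuples") its behavior is roughly this. The fallback for unhashable leaves under `force=True` is a guess:

```python
from typing import Any

def recursive_hashify(data: Any, force: bool = True) -> Any:
    # dicts become tuples of (key, value) pairs; lists and tuples become tuples
    if isinstance(data, dict):
        return tuple((k, recursive_hashify(v, force)) for k, v in data.items())
    if isinstance(data, (list, tuple)):
        return tuple(recursive_hashify(x, force) for x in data)
    try:
        hash(data)
        return data
    except TypeError:
        if force:
            # assumed fallback: coerce unhashable leaves to a string
            return str(data)
        raise
```

Since `hashify` first runs the object through `json_serialize`, the leaves reaching this step are normally already JSON primitives, which are all hashable.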
```python
def try_catch(
    func: Callable[..., T_FuncTryCatchReturn],
) -> Callable[..., Union[T_FuncTryCatchReturn, str]]:
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc
```
wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
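Based on the source above, a wrapped function returns its normal result on success and an `"ExceptionName: message"` string on failure. A self-contained demonstration (the decorator is reimplemented here so the snippet runs without muutils; `divide` is a made-up example function):

```python
import functools
from typing import Any, Callable, TypeVar, Union

T = TypeVar("T")

def try_catch(func: Callable[..., T]) -> Callable[..., Union[T, str]]:
    # mirrors the muutils helper: catch any exception, return it as a string
    @functools.wraps(func)
    def newfunc(*args: Any, **kwargs: Any) -> Union[T, str]:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"
    return newfunc

@try_catch
def divide(a: float, b: float) -> float:
    return a / b
```

This is useful when serializing objects whose attributes may raise on access: the error message itself becomes the serialized value instead of aborting the whole serialization.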
````python
def dc_eq(
    dc1: Any,
    dc2: Any,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    ```
    [START]
       ▼
    ┌─────────────┐
    │ dc1 is dc2? │───Yes───► (True)
    └──────┬──────┘
           │No
           ▼
    ┌───────────────┐
    │ classes match?│───Yes───► [compare field values] ───► (True/False)
    └──────┬────────┘
           │No
           ▼
    ┌────────────────────────────┐
    │ except_when_class_mismatch?│───Yes───► { raise TypeError }
    └─────────────┬──────────────┘
                  │No
                  ▼
    ┌────────────────────────────┐
    │ false_when_class_mismatch? │───Yes───► (False)
    └─────────────┬──────────────┘
                  │No
                  ▼
    ┌────────────────────────────┐
    │ except_when_field_mismatch?│───No────► [compare field values]
    └─────────────┬──────────────┘
                  │Yes
                  ▼
    ┌───────────────┐
    │ fields match? │───Yes───► [compare field values]
    └──────┬────────┘
           │No
           ▼
    { raise AttributeError }
    ```
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if false_when_class_mismatch:
            # return False immediately without attempting field comparison
            return False
        # classes don't match but we'll try to compare fields anyway
        if except_when_field_mismatch:
            dc1_fields: set[str] = {fld.name for fld in dataclasses.fields(dc1)}
            dc2_fields: set[str] = {fld.name for fld in dataclasses.fields(dc2)}
            fields_match: bool = dc1_fields == dc2_fields
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
````
checks if two dataclasses which (might) hold numpy arrays are equal
Parameters:
- `dc1`: the first dataclass
- `dc2`: the second dataclass
- `except_when_class_mismatch: bool`
  if `True`, will throw `TypeError` if the classes are different.
  if not, will return `False` by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
  (default: `False`)
- `false_when_class_mismatch: bool`
  only relevant if `except_when_class_mismatch` is `False`.
  if `True`, will return `False` if the classes are different.
  if `False`, will attempt to compare the fields.
- `except_when_field_mismatch: bool`
  only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
  if `True`, will throw `AttributeError` if the fields are different.
  (default: `False`)
Returns:
bool: True if the dataclasses are equal, False otherwise
Raises:
- `TypeError`: if the dataclasses are of different classes
- `AttributeError`: if the dataclasses have different fields
```
[START]
   ▼
┌─────────────┐
│ dc1 is dc2? │───Yes───► (True)
└──────┬──────┘
       │No
       ▼
┌───────────────┐
│ classes match?│───Yes───► [compare field values] ───► (True/False)
└──────┬────────┘
       │No
       ▼
┌────────────────────────────┐
│ except_when_class_mismatch?│───Yes───► { raise TypeError }
└─────────────┬──────────────┘
              │No
              ▼
┌────────────────────────────┐
│ false_when_class_mismatch? │───Yes───► (False)
└─────────────┬──────────────┘
              │No
              ▼
┌────────────────────────────┐
│ except_when_field_mismatch?│───No────► [compare field values]
└─────────────┬──────────────┘
              │Yes
              ▼
┌───────────────┐
│ fields match? │───Yes───► [compare field values]
└──────┬────────┘
       │No
       ▼
{ raise AttributeError }
```
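The main branches of this flow can be exercised with a simplified stand-in. `dc_eq_sketch` and `Point` below are made up for illustration; plain `==` replaces `array_safe_eq`, and field-mismatch handling is omitted:

```python
import dataclasses

def dc_eq_sketch(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
) -> bool:
    # simplified version of dc_eq: identity, class check, then field comparison
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            raise TypeError("class mismatch")
        if false_when_class_mismatch:
            return False
    return all(
        getattr(dc1, f.name) == getattr(dc2, f.name)
        for f in dataclasses.fields(dc1)
        if f.compare
    )

@dataclasses.dataclass
class Point:
    x: int
    y: int
```

Note that fields declared with `compare=False` are skipped entirely, matching the behavior of `dataclasses` equality.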
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @overload
    @classmethod
    def load(cls, data: dict[str, Any]) -> Self: ...

    @overload
    @classmethod
    def load(cls, data: Self) -> Self: ...

    @classmethod
    def load(cls, data: dict[str, Any] | Self) -> Self:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`

        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: JSONdict = self.serialize()
            ser_other: JSONdict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                if of_serialized:
                    self_value_s = ser_self[field_name]
                    other_value_s = ser_other[field_name]
                else:
                    self_value_s = self_value
                    other_value_s = other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
````
Base class for serializable dataclasses
only for linting and type checking, still need to call serializable_dataclass decorator
Usage:
```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```
and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:
```python
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
```
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```
which gives us:
```python
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
```
```python
def serialize(self) -> dict[str, Any]:
    "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
    raise NotImplementedError(
        f"decorate {self.__class__ = } with `@serializable_dataclass`"
    )
```
returns the class as a dict, implemented by using @serializable_dataclass decorator
```python
@classmethod
def load(cls, data: dict[str, Any] | Self) -> Self:
    "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
    raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
```python
def validate_fields_types(
    self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return SerializableDataclass__validate_fields_types(
        self, on_typecheck_error=on_typecheck_error
    )
```
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
```python
def validate_field_type(
    self,
    field: "SerializableField|str",
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint"""
    return SerializableDataclass__validate_field_type(
        self, field, on_typecheck_error=on_typecheck_error
    )
```
given a dataclass, check the field matches the type hint
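The idea behind field-type validation can be sketched with a naive standalone check. This only handles plain (non-generic) annotations via `isinstance`; the real muutils validator also handles generics and the configurable `ErrorMode`. `Sample` and `validate_field_types` are illustrative names:

```python
import dataclasses
from typing import get_type_hints

@dataclasses.dataclass
class Sample:
    a: int
    b: str

def validate_field_types(obj) -> bool:
    # naive per-field isinstance check against the resolved annotations
    hints = get_type_hints(type(obj))
    return all(
        isinstance(getattr(obj, f.name), hints[f.name])
        for f in dataclasses.fields(obj)
    )
```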
````python
def diff(
    self, other: "SerializableDataclass", of_serialized: bool = False
) -> dict[str, Any]:
    """get a rich and recursive diff between two instances of a serializable dataclass

    ```python
    >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
    {'b': {'self': 2, 'other': 3}}
    >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
    {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
    ```

    # Parameters:
    - `other : SerializableDataclass`
        other instance to compare against
    - `of_serialized : bool`
        if true, compare serialized data and not raw values
        (defaults to `False`)

    # Returns:
    - `dict[str, Any]`

    # Raises:
    - `ValueError` : if the instances are not of the same type
    - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
    """
    # match types
    if type(self) is not type(other):
        raise ValueError(
            f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
        )

    # initialize the diff result
    diff_result: dict = {}

    # if they are the same, return the empty diff
    try:
        if self == other:
            return diff_result
    except Exception:
        pass

    # if we are working with serialized data, serialize the instances
    if of_serialized:
        ser_self: JSONdict = self.serialize()
        ser_other: JSONdict = other.serialize()

    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # skip fields that are not for comparison
        if not field.compare:
            continue

        # get values
        field_name: str = field.name
        self_value = getattr(self, field_name)
        other_value = getattr(other, field_name)

        # if the values are both serializable dataclasses, recurse
        if isinstance(self_value, SerializableDataclass) and isinstance(
            other_value, SerializableDataclass
        ):
            nested_diff: dict = self_value.diff(
                other_value, of_serialized=of_serialized
            )
            if nested_diff:
                diff_result[field_name] = nested_diff
        # only support serializable dataclasses
        elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
            other_value
        ):
            raise ValueError("Non-serializable dataclass is not supported")
        else:
            # get the values of either the serialized or the actual values
            if of_serialized:
                self_value_s = ser_self[field_name]
                other_value_s = ser_other[field_name]
            else:
                self_value_s = self_value
                other_value_s = other_value
            # compare the values
            if not array_safe_eq(self_value_s, other_value_s):
                diff_result[field_name] = {"self": self_value, "other": other_value}

    # return the diff result
    return diff_result
````
get a rich and recursive diff between two instances of a serializable dataclass
```python
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```
Parameters:
- `other : SerializableDataclass`
  other instance to compare against
- `of_serialized : bool`
  if true, compare serialized data and not raw values
  (defaults to `False`)
Returns:
dict[str, Any]
Raises:
- `ValueError`: if the instances are not of the same type
- `ValueError`: if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
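The recursive shape of the result (nested dicts with `{'self': ..., 'other': ...}` leaves) can be shown with a standalone sketch over plain dicts. `dict_diff` is a made-up helper, not part of muutils; the real `diff` walks dataclass fields and uses `array_safe_eq`:

```python
def dict_diff(a: dict, b: dict) -> dict:
    # mirrors the {'self': ..., 'other': ...} leaf format used by diff()
    result: dict = {}
    for key in a:
        if isinstance(a[key], dict) and isinstance(b.get(key), dict):
            # recurse into nested structures, keep only non-empty sub-diffs
            nested = dict_diff(a[key], b[key])
            if nested:
                result[key] = nested
        elif a[key] != b.get(key):
            result[key] = {"self": a[key], "other": b[key]}
    return result
```

Equal fields are simply absent from the result, so an empty dict means the two structures agree.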
```python
def update_from_nested_dict(self, nested_dict: dict[str, Any]):
    """update the instance from a nested dict, useful for configuration from command line args

    # Parameters:
    - `nested_dict : dict[str, Any]`
        nested dict to update the instance with
    """
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        field_name: str = field.name
        self_value = getattr(self, field_name)

        if field_name in nested_dict:
            if isinstance(self_value, SerializableDataclass):
                self_value.update_from_nested_dict(nested_dict[field_name])
            else:
                setattr(self, field_name, nested_dict[field_name])
```
update the instance from a nested dict, useful for configuration from command line args
Parameters:
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
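The update pattern can be demonstrated with ordinary dataclasses standing in for `SerializableDataclass`. `Config`, `Inner`, and the free function `update_from_nested_dict` below are made-up names for illustration; the real method recurses only into `SerializableDataclass` fields:

```python
from dataclasses import dataclass, field, fields

@dataclass
class Inner:
    lr: float = 0.1

@dataclass
class Config:
    name: str = "default"
    inner: Inner = field(default_factory=Inner)

def update_from_nested_dict(obj, nested: dict) -> None:
    # recurse into dataclass-valued fields, assign everything else directly
    for f in fields(obj):
        if f.name in nested:
            current = getattr(obj, f.name)
            if hasattr(current, "__dataclass_fields__") and isinstance(nested[f.name], dict):
                update_from_nested_dict(current, nested[f.name])
            else:
                setattr(obj, f.name, nested[f.name])

cfg = Config()
update_from_nested_dict(cfg, {"name": "run1", "inner": {"lr": 0.01}})
```

Keys absent from the nested dict leave the corresponding fields untouched, which is what makes this useful for partial overrides from command-line arguments.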