Coverage for muutils\json_serialize\json_serialize.py: 28%
64 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-09 01:48 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-09 01:48 -0600
1"""provides the basic framework for json serialization of objects
3notably:
5- `SerializerHandler` defines how to serialize a specific type of object
6- `JsonSerializer` handles configuration for which handlers to use
7- `json_serialize` provides the default configuration if you don't care -- call it on any object!
9"""
11from __future__ import annotations
13import inspect
14import warnings
15from dataclasses import dataclass, is_dataclass
16from pathlib import Path
17from typing import Any, Callable, Iterable, Mapping, Set, Union
19from muutils.errormode import ErrorMode
21try:
22 from muutils.json_serialize.array import ArrayMode, serialize_array
23except ImportError as e:
24 ArrayMode = str # type: ignore[misc]
25 serialize_array = lambda *args, **kwargs: None # noqa: E731
26 warnings.warn(
27 f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}",
28 ImportWarning,
29 )
31from muutils.json_serialize.util import (
32 Hashableitem,
33 JSONitem,
34 MonoTuple,
35 SerializationException,
36 _recursive_hashify,
37 isinstance_namedtuple,
38 safe_getsource,
39 string_as_lines,
40 try_catch,
41)
43# pylint: disable=protected-access
# Attribute names that the fallback handler looks up on an object with
# `getattr(obj, key, None)`; each value is stored as `str(...)` in the output.
SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = (
    # identity of the object
    "__name__",
    "__doc__",
    "__module__",
    "__class__",
    # contents of the object
    "__dict__",
    "__annotations__",
)
# Output key -> one-argument callable applied to the object by the fallback
# handler. `str` and `dir` are used bare; the remaining entries are wrapped
# in `try_catch`.
SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = {
    "str": str,
    "dir": dir,
    # name of the object's type
    "type": try_catch(lambda obj: str(type(obj).__name__)),
    "repr": try_catch(lambda obj: repr(obj)),
    # source text and source file, as reported by `inspect`
    "code": try_catch(lambda obj: inspect.getsource(obj)),
    "sourcefile": try_catch(lambda obj: inspect.getsourcefile(obj)),
}
# `str(type(obj))` values of objects that are serialized simply as `str(obj)`.
# Types are matched by their string form, so torch is never imported here.
SERIALIZE_DIRECT_AS_STR: Set[str] = {
    "<class 'torch.device'>",
    "<class 'torch.dtype'>",
}

# Location of an object inside the structure being serialized:
# a tuple of string keys and integer indices.
ObjectPath = MonoTuple[Union[str, int]]
@dataclass
class SerializerHandler:
    """a handler for a specific type of object

    # Parameters:
     - `check : Callable[[JsonSerializer, Any, ObjectPath], bool]`
        takes a JsonSerializer, an object, and the current path, returns whether to use this handler
     - `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]`
        takes a JsonSerializer, an object, and the current path, returns the serialized object
     - `uid : str`
        unique identifier for the handler
     - `desc : str`
        description of the handler
    """

    # (serializer, object, path) -> whether this handler should be used
    check: Callable[["JsonSerializer", Any, ObjectPath], bool]
    # (serializer, object, path) -> serialized object
    serialize_func: Callable[["JsonSerializer", Any, ObjectPath], JSONitem]
    # unique identifier for the handler
    uid: str
    # description of this serializer
    desc: str

    def serialize(self) -> dict:
        """serialize the handler info

        # Returns:
         - `dict`
            source code and docstring of `check` and `serialize_func`, plus
            `uid`, `source_pckg`, `__module__`, and `desc`
        """
        # code and doc of the check function
        check_info: dict = {
            "code": safe_getsource(self.check),
            "doc": string_as_lines(self.check.__doc__),
        }

        # code and doc of the serialization function
        func = self.serialize_func
        func_info: dict = {
            "code": safe_getsource(func),
            "doc": string_as_lines(func.__doc__),
        }

        return {
            "check": check_info,
            "serialize_func": func_info,
            "uid": str(self.uid),
            # both are looked up on the serialization function; `None` if absent
            "source_pckg": getattr(func, "source_pckg", None),
            "__module__": getattr(func, "__module__", None),
            "desc": str(self.desc),
        }
# Minimal handler set: JSON-native scalars pass through unchanged, mappings
# and list/tuple containers are serialized recursively with the path extended
# by the key or index of each child.
BASE_HANDLERS: MonoTuple[SerializerHandler] = (
    SerializerHandler(
        check=lambda self, obj, path: (
            obj is None or isinstance(obj, (bool, int, float, str))
        ),
        serialize_func=lambda self, obj, path: obj,
        uid="base types",
        desc="base types (bool, int, float, str, None)",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Mapping),
        # keys are stringified; the raw key is what gets appended to the path
        serialize_func=lambda self, obj, path: {
            str(key): self.json_serialize(val, (*path, key))
            for key, val in obj.items()
        },
        uid="dictionaries",
        desc="dictionaries",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (list, tuple)),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(item, (*path, idx)) for idx, item in enumerate(obj)
        ],
        uid="(list, tuple) -> list",
        desc="lists and tuples as lists",
    ),
)
def _serialize_override_serialize_func(
    self: "JsonSerializer", obj: Any, path: ObjectPath
) -> JSONitem:
    """serialize `obj` by calling its own `.serialize()` method

    `self` and `path` are accepted only to match the handler signature;
    neither is used.
    """
    serialized = obj.serialize()
    return serialized
# Default handler chain: `BASE_HANDLERS` first, then handlers for objects with
# a `.serialize()` method, namedtuples, dataclasses, paths, torch/numpy/pandas
# objects (matched by the string form of their type, so none of those
# libraries is imported here), generic iterables, and a catch-all fallback.
# `JsonSerializer.json_serialize` uses the first handler whose `check` passes,
# so order matters.
#
# NOTE(review): because `BASE_HANDLERS` come first and their
# "(list, tuple) -> list" handler accepts any `tuple`, namedtuples (which are
# tuple subclasses) look like they are claimed there before the
# "namedtuple -> dict" handler is reached with this ordering -- confirm
# whether that is intended.
DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + (
    SerializerHandler(
        # TODO: allow for custom serialization handler name
        check=lambda self, obj, path: hasattr(obj, "serialize")
        and callable(obj.serialize),
        serialize_func=_serialize_override_serialize_func,
        uid=".serialize override",
        desc="objects with .serialize method",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance_namedtuple(obj),
        # pass `path` along so nested errors still report where they happened
        # (the dict is an in-place replacement for `obj`, so the path is unchanged)
        serialize_func=lambda self, obj, path: self.json_serialize(
            dict(obj._asdict()), path
        ),
        uid="namedtuple -> dict",
        desc="namedtuples as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: is_dataclass(obj),
        serialize_func=lambda self, obj, path: {
            k: self.json_serialize(getattr(obj, k), tuple(path) + (k,))
            for k in obj.__dataclass_fields__
        },
        uid="dataclass -> dict",
        desc="dataclasses as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Path),
        serialize_func=lambda self, obj, path: obj.as_posix(),
        uid="path -> str",
        desc="Path objects as posix strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR,
        serialize_func=lambda self, obj, path: str(obj),
        uid="obj -> str(obj)",
        desc="directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
        serialize_func=lambda self, obj, path: serialize_array(self, obj, path=path),
        uid="numpy.ndarray",
        desc="numpy arrays",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
        # detach and move to cpu before handing the tensor to `serialize_array`
        serialize_func=lambda self, obj, path: serialize_array(
            self, obj.detach().cpu(), path=path
        ),
        uid="torch.Tensor",
        desc="pytorch tensors",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj))
        == "<class 'pandas.core.frame.DataFrame'>",
        serialize_func=lambda self, obj, path: dict(
            __format__="pandas.DataFrame",
            columns=obj.columns.tolist(),
            data=obj.to_dict(orient="records"),
            path=path,
        ),
        uid="pandas.DataFrame",
        desc="pandas DataFrames",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (set, list, tuple))
        or isinstance(obj, Iterable),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(set, list, tuple, Iterable) -> list",
        desc="sets, lists, tuples, and Iterables as lists",
    ),
    SerializerHandler(
        # always matches: last resort for anything not handled above
        check=lambda self, obj, path: True,
        serialize_func=lambda self, obj, path: {
            **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS},
            **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()},
        },
        uid="fallback",
        desc="fallback handler -- serialize object attributes and special functions as strings",
    ),
)
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
    how to write arrays
    (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
    (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
    handlers to use before the default handlers
    (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
    default handlers to use
    (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
    changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
    (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        # `*args` exists only so that positional use fails with a clear message
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers -- `handlers_pre` are tried first
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        """serialize `obj` using the first handler whose `check` accepts it

        # Parameters:
        - `obj : Any`
        object to serialize
        - `path : ObjectPath`
        location of `obj` within the structure being serialized, used in error messages
        (defaults to `tuple()`)

        # Returns:
        - `JSONitem`
        output of the matching handler, or `repr(obj)` if an error occurred and `error_mode` did not call for raising

        # Raises:
        - `SerializationException`: if anything fails (including no handler matching) and `error_mode` compares equal to `"except"`
        """
        # bound before the loop so the error message below can always be built,
        # even when `self.handlers` is empty and the loop never assigns it
        handler: Union[SerializerHandler, None] = None
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        # rename the key in place so loaders keyed on
                        # "__format__" will not pick this object up
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            # NOTE(review): these comparisons rely on `ErrorMode` members
            # comparing equal to their string values -- confirm against
            # `muutils.errormode`
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                # keep the message readable for huge objects
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                last_uid = (
                    handler.uid if handler is not None else "<no handlers registered>"
                )
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{last_uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            # any non-raising error mode falls back to the repr of the object
            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable

        serializes `obj` with `json_serialize`, then passes the result (and
        `force`) to `_recursive_hashify`
        """
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
# module-level serializer built with every default setting; backs `json_serialize` below
GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer()


def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config

    delegates to `GLOBAL_JSON_SERIALIZER.json_serialize`, forwarding `path`
    """
    serialized: JSONitem = GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)
    return serialized