Coverage for zanj / serializing.py: 100%
74 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 01:52 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 01:52 -0700
1from __future__ import annotations
3import json
4import sys
5from dataclasses import dataclass
6from typing import IO, Any, Callable, Iterable, Sequence
7import warnings
9import numpy as np
10from muutils.json_serialize.array import arr_metadata
11from muutils.json_serialize.json_serialize import ( # JsonSerializer,
12 DEFAULT_HANDLERS,
13 ObjectPath,
14 SerializerHandler,
15)
17from zanj.consts import JSONdict, JSONitem, MonoTuple, _FORMAT_KEY, _REF_KEY
19from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre
# extra keyword arguments for `@dataclass`: `kw_only=True` is requested
# only when running on Python 3.10 or newer, otherwise nothing is passed
KW_ONLY_KWARGS: dict = {"kw_only": True} if sys.version_info >= (3, 10) else {}
25# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg
26# for some reason pylint complains about kwargs to ZANJSerializerHandler
def jsonl_metadata(data: list[JSONdict]) -> dict[str, Any]:
    """metadata about a jsonl object

    # Parameters:
     - `data: list[JSONdict]`
       rows of the jsonl object; every element is read via `.items()`

    # Returns:
     - `dict[str, Any]`
       - `"len(data)"`: number of rows
       - `"columns"`: for every key seen in any row (except `_FORMAT_KEY`),
         a dict with `"types"` (sorted names of the value types seen for
         that key) and `"len"` (number of rows containing the key).
         columns appear in the order they are first encountered
       - `"data[0]"`: the first row, only present when `data` is non-empty
    """
    # a single pass over the rows, instead of rescanning `data` for every column
    col_types: dict[str, set[str]] = {}
    col_counts: dict[str, int] = {}
    for item in data:
        for col, value in item.items():
            if col == _FORMAT_KEY:
                continue
            col_types.setdefault(col, set()).add(type(value).__name__)
            col_counts[col] = col_counts.get(col, 0) + 1

    output: dict[str, Any] = {
        "len(data)": len(data),
        "columns": {
            col: {
                # sorted so the metadata does not depend on set iteration order
                "types": sorted(type_names),
                "len": col_counts[col],
            }
            for col, type_names in col_types.items()
        },
    }
    if len(data) > 0:
        output["data[0]"] = data[0]
    return output
def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: np.ndarray) -> None:
    """write `data` to the open binary file `fp` in numpy's `.npy` format

    `data` is first passed through `np.asanyarray`, and the array is written
    with `allow_pickle=False`. `self` is accepted but not used.
    """
    as_array = np.asanyarray(data)
    # the `ty` type checker reports `np.lib.format` as an unresolved attribute,
    # so that rule is silenced on this call
    np.lib.format.write_array(  # ty: ignore[unresolved-attribute]
        fp=fp,
        array=as_array,
        allow_pickle=False,
    )
def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write every element of `data` to `fp` as one UTF-8 encoded JSON line

    each element is dumped with `json.dumps` and followed by a newline.
    `self` is accepted but not used.
    """
    newline: bytes = "\n".encode("utf-8")
    for entry in data:
        encoded: bytes = json.dumps(entry).encode("utf-8")
        fp.write(encoded)
        fp.write(newline)
# dispatch table: external item type -> function that writes an item of that
# type to an open binary file; each entry has the shape `func(zanj, fp, data)`
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType,
    Callable[[_ZANJ_pre, IO[bytes], Any], None],
] = {"npy": store_npy, "jsonl": store_jsonl}
@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    # Fields:
     - `source_pckg: str`
       source package of the handler -- note that this might be overridden by ZANJ
     - `check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]`
       `(zanj, object, path) -> bool`, whether to use this handler for the object
     - `serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]`
       `(zanj, object, path) -> JSONitem`, the serialized object

    `uid` (unique identifier, saved in the `_FORMAT_KEY` field) and `desc`
    (optional description) are not redeclared here -- presumably they come
    from `SerializerHandler`; confirm against muutils.
    """

    source_pckg: str
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """register `data` for external storage in a ZANJ object, return a reference to it

    # Parameters:
     - `jser: ZANJ`
       serializer whose `_externals` mapping receives the item
     - `data: Any`
       for `item_type == "npy"`: a `numpy.ndarray` or `torch.Tensor`
       for `item_type` starting with `"jsonl"`: a pandas or polars `DataFrame`,
       or any iterable (each element is passed to `jser.json_serialize`)
     - `path: ObjectPath`
       location of `data` in the object being serialized, must be a tuple
     - `item_type: ExternalItemType`
       used as the file extension of the archive path
     - `_format: str`
       value stored under `_FORMAT_KEY` in the returned dict

    # Returns:
     - `JSONitem`
       dict holding `_FORMAT_KEY`, `_REF_KEY` (the archive path) and metadata
       about the stored item

    # Raises:
     - `AssertionError` : `path` is not a tuple
     - `ValueError` : the archive path already exists in `jser._externals`, or
       it is a directory-style ancestor/descendant of an existing one
       (a warning with the same message is emitted first)
     - `TypeError` : `data` is of a type not supported for `item_type`

    # Modifies:
     - modifies `jser._externals`
    """
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join(map(str, path))
    archive_path: str = f"{joined_path}.{item_type}"

    # TODO: somehow need to control whether a failure here causes a fallback to other handlers, or whether the except should propagate
    # this will probably require changes to the upstream muutils.json_serialize code

    # exact collision with an item that is already registered
    if archive_path in jser._externals:
        err_msg = f"external path {archive_path} already exists!"
        warnings.warn(err_msg)
        raise ValueError(err_msg)

    # directory-style collision: compare against "<other>/" so that a shared
    # string prefix alone (without the "/" separator) is not flagged
    # NOTE: the local names below are echoed by the `{name = }` f-string
    # specifiers, so renaming them would change the error message
    for p in jser._externals.keys():
        # drop the file extension to recover the joined path of the existing item
        existing_joined_path = p.rsplit(".", 1)[0]
        existing_below_new: bool = existing_joined_path.startswith(joined_path + "/")
        new_below_existing: bool = joined_path.startswith(existing_joined_path + "/")
        if existing_below_new or new_below_existing:
            err_msg = (
                f"external path {joined_path} is a prefix of another path {p}!\n"
                + f"{jser._externals.keys() = }\n{joined_path = }\n{path = }\n{p = }\n{existing_joined_path = }\n{archive_path = }\n{_format = }"
            )
            warnings.warn(err_msg)
            raise ValueError(err_msg)

    # `data_new` is what actually gets stored, `output` is the returned reference
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }

    if item_type == "npy":
        # types are compared by name -- presumably so torch need not be imported
        # NOTE(review): the exact-name match sends ndarray subclasses to the
        # TypeError below, while the numpy handler's `check` uses `isinstance`;
        # confirm this is intended
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach and convert
            data_new = data.detach().cpu().numpy()
        elif data_type_str != "<class 'numpy.ndarray'>":
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # metadata is computed from the original `data`, not from `data_new`
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via module and class name to avoid importing pandas (works with pandas 3.0+)
        dataframe_columns = None
        cls_module = data.__class__.__module__
        cls_name = data.__class__.__name__
        if "pandas" in cls_module and cls_name == "DataFrame":
            dataframe_columns = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif "polars" in cls_module and cls_name == "DataFrame":
            dataframe_columns = data.columns
            data_new = data.to_dicts()
        elif isinstance(data, (list, tuple, Iterable, Sequence)):
            # serialize every element under its own index-extended path
            data_new = [
                jser.json_serialize(element, (*path, idx))
                for idx, element in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        # column metadata only applies when every row is a dict
        if all(isinstance(row, dict) for row in data_new):
            output.update(jsonl_metadata(data_new))

        # set DataFrame columns after jsonl_metadata to avoid being overwritten
        if dataframe_columns is not None:
            output["columns"] = dataframe_columns

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output
def _external_handler(
    uid: str,
    item_type: ExternalItemType,
    desc: str,
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool],
) -> ZANJSerializerHandler:
    """build a `zanj` handler storing matching objects externally as `item_type`, with `uid` as format"""
    return ZANJSerializerHandler(
        check=check,
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type=item_type, _format=uid
        ),
        uid=uid,
        source_pckg="zanj",
        desc=desc,
    )


# handlers that move large arrays, long lists/tuples and dataframes to external
# storage, followed by the default handlers imported from muutils
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple(
    [
        # numpy arrays with at least `external_array_threshold` elements
        _external_handler(
            uid="numpy.ndarray:external",
            item_type="npy",
            desc="external numpy array",
            check=lambda self, obj, path: (
                isinstance(obj, np.ndarray)
                and obj.size >= self.external_array_threshold
            ),
        ),
        # torch tensors, matched by type name, with at least
        # `external_array_threshold` elements
        _external_handler(
            uid="torch.Tensor:external",
            item_type="npy",
            desc="external torch tensor",
            check=lambda self, obj, path: (
                str(type(obj)) == "<class 'torch.Tensor'>"
                and int(obj.nelement()) >= self.external_array_threshold
            ),
        ),
        # lists with at least `external_list_threshold` elements
        _external_handler(
            uid="list:external",
            item_type="jsonl",
            desc="external list",
            check=lambda self, obj, path: (
                isinstance(obj, list) and len(obj) >= self.external_list_threshold
            ),
        ),
        # tuples with at least `external_list_threshold` elements
        _external_handler(
            uid="tuple:external",
            item_type="jsonl",
            desc="external tuple",
            check=lambda self, obj, path: (
                isinstance(obj, tuple) and len(obj) >= self.external_list_threshold
            ),
        ),
        # pandas dataframes of any size, matched by module and class name
        _external_handler(
            uid="pandas.DataFrame:external",
            item_type="jsonl",
            desc="external pandas DataFrame",
            check=lambda self, obj, path: (
                "pandas" in obj.__class__.__module__
                and obj.__class__.__name__ == "DataFrame"
            ),
        ),
        # polars dataframes of any size, matched by module and class name
        _external_handler(
            uid="polars.DataFrame:external",
            item_type="jsonl",
            desc="external polars DataFrame",
            check=lambda self, obj, path: (
                "polars" in obj.__class__.__module__
                and obj.__class__.__name__ == "DataFrame"
            ),
        ),
        # ZANJSerializerHandler(
        #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
        #     in [str(t) for t in obj.__class__.__mro__],
        #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
        #         self, obj, path,
        #     ),
        #     uid="torch.nn.Module",
        #     source_pckg="zanj",
        #     desc="fallback torch serialization",
        # ),
    ]
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)
292# the complaint above is:
293# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]" [arg-type]