Coverage for zanj/serializing.py: 88%
65 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-11-06 16:15 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-11-06 16:15 +0000
1from __future__ import annotations
3import json
4import sys
5from dataclasses import dataclass
6from typing import IO, Any, Callable, Iterable, Sequence
7import warnings
9import numpy as np
10from muutils.json_serialize.array import arr_metadata
11from muutils.json_serialize.json_serialize import ( # JsonSerializer,
12 DEFAULT_HANDLERS,
13 ObjectPath,
14 SerializerHandler,
15)
16from muutils.json_serialize.util import (
17 JSONdict,
18 JSONitem,
19 MonoTuple,
20 _FORMAT_KEY,
21 _REF_KEY,
22)
24from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre
# extra keyword arguments for `@dataclass(...)`: on python 3.10+ this is
# `{"kw_only": True}` (all dataclass fields become keyword-only), on older
# interpreters it is empty so the decorator call stays valid
KW_ONLY_KWARGS: dict = (
    {"kw_only": True} if sys.version_info >= (3, 10) else {}
)

# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg
# for some reason pylint complains about kwargs to ZANJSerializerHandler
def jsonl_metadata(data: list[JSONdict]) -> dict:
    """metadata about a jsonl object

    summarizes a list of dicts (the rows of a jsonl file).

    # Parameters:
     - `data : list[JSONdict]`
        rows to summarize

    # Returns:
     - `dict` with keys:
        - `"data[0]"`: the first row, or `None` if `data` is empty
        - `"len(data)"`: number of rows
        - `"columns"`: for every key seen in any row (except `_FORMAT_KEY`),
          a dict with `"types"` (sorted names of the value types found in that
          column) and `"len"` (number of rows containing that key)
    """
    # dict.fromkeys keeps first-seen order. a plain `set` of strings would
    # order columns by hash, which differs between interpreter runs, making
    # the saved metadata non-reproducible
    all_cols: dict[str, None] = dict.fromkeys(
        col for item in data for col in item.keys()
    )

    columns: dict[str, dict] = {}
    for col in all_cols:
        if col == _FORMAT_KEY:
            continue
        values: list = [item[col] for item in data if col in item]
        columns[col] = {
            # sorted for the same reproducibility reason as above
            "types": sorted({type(v).__name__ for v in values}),
            "len": len(values),
        }

    return {
        # guard: indexing an empty list would raise IndexError
        "data[0]": data[0] if data else None,
        "len(data)": len(data),
        "columns": columns,
    }
def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: np.ndarray) -> None:
    """write `data` to the binary stream `fp` in numpy's `.npy` format

    `data` is first passed through `np.asanyarray`, so array-likes are
    converted and ndarray subclasses are kept as-is. pickling is disabled
    (`allow_pickle=False`). `self` is not used; it is part of the common
    signature of the external store functions.
    """
    arr = np.asanyarray(data)
    # TODO: Type `<module 'numpy.lib'>` has no attribute `format` --> zanj/serializing.py:54:5
    # info: rule `unresolved-attribute` is enabled by default
    npy_format = np.lib.format  # ty: ignore[unresolved-attribute]
    npy_format.write_array(fp=fp, array=arr, allow_pickle=False)
def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write `data` to the binary stream `fp` as jsonl

    each element is encoded with `json.dumps` (default settings), written as
    utf-8, and followed by a newline. `self` is not used; it is part of the
    common signature of the external store functions.
    """
    newline: bytes = "\n".encode("utf-8")
    for record in data:
        encoded: bytes = json.dumps(record).encode("utf-8")
        fp.write(encoded)
        fp.write(newline)
# dispatch table: external item type -> function that writes that item's data
# to an open binary file. every entry takes (zanj, fp, data) and returns None
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = dict(
    npy=store_npy,
    jsonl=store_jsonl,
)
@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    a `SerializerHandler` whose `check` and `serialize_func` receive the ZANJ
    instance (typed `_ZANJ_pre`) as their first argument. the handlers built
    in this module use it to read `external_array_threshold` /
    `external_list_threshold` and to register external items.

    `uid` and `desc` are not declared here (see the commented-out fields
    below) but are passed as keyword arguments when handlers are built in
    this module -- presumably they are declared on `SerializerHandler`;
    confirm in muutils. on python 3.10+ all fields are keyword-only
    (`KW_ONLY_KWARGS`).
    """

    # unique identifier for the handler, saved in _FORMAT_KEY field
    # uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # (zanj_instance, object, path) -> whether to use this handler
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    # (zanj_instance, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
    # optional description of how this serializer works
    # desc: str = "(no description)"
def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    instead of embedding `data` in the json tree, it is registered in
    `jser._externals` under an archive path derived from `path`. a small json
    object pointing at that archive path (plus metadata) is returned in its place.

    # Parameters:
     - `jser: ZANJ`
        the serializer. its `_externals` dict is checked and updated, and its
        `json_serialize` is applied to the elements of list-like jsonl data
     - `data: Any`
        for `item_type == "npy"`: a `numpy.ndarray` (or subclass) or a
        `torch.Tensor`. for `item_type` starting with `"jsonl"`: a
        `pandas.DataFrame` or any iterable
     - `path: ObjectPath`
        location of `data` in the object tree. must be a tuple; its parts are
        joined with `"/"` to build the archive path
     - `item_type: ExternalItemType`
        selects the processing branch and is used as the archive path's extension
     - `_format: str`
        written to the `_FORMAT_KEY` field of the returned json

    # Returns:
     - `JSONitem`
        json data with reference: `{_FORMAT_KEY: _format, _REF_KEY: archive_path}`
        plus `arr_metadata` output (npy), or a DataFrame's column list and/or
        `jsonl_metadata` output (jsonl; the latter only when there is at least
        one row and every row is a dict)

    # Raises:
     - `ValueError` if the archive path is already registered, or if it is a
       directory ancestor/descendant of an already registered external path
       (a warning with the same message is emitted before raising)
     - `TypeError` if `data` has an unsupported type for `item_type`
     - `AssertionError` if `path` is not a tuple

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure its unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    # TODO: somehow need to control whether a failure here causes a fallback to other handlers, or whether the except should propagate
    # this will probably require changes to the upstream muutils.json_serialize code
    if archive_path in jser._externals:
        err_msg = f"external path {archive_path} already exists!"
        warnings.warn(err_msg)
        raise ValueError(err_msg)
    # Check for true path prefix conflicts (not just string prefix)
    # Only flag when one path is a directory ancestor of another (contains "/" separator)
    for p in jser._externals:
        # Remove the file extension to get the joined_path
        existing_joined_path = p.rsplit(".", 1)[0]
        # Check if one is a true path prefix with "/" separator
        if existing_joined_path.startswith(joined_path + "/") or joined_path.startswith(
            existing_joined_path + "/"
        ):
            err_msg = (
                f"external path {joined_path} is a prefix of another path {p}!\n"
                + f"{jser._externals.keys() = }\n{joined_path = }\n{path = }\n{p = }\n{existing_joined_path = }\n{archive_path = }\n{_format = }"
            )
            warnings.warn(err_msg)
            raise ValueError(err_msg)

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # check type
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # compared by type string to avoid importing torch; detach and convert
            data_new = data.detach().cpu().numpy()
        elif isinstance(data, np.ndarray):
            # isinstance instead of an exact type-string match: ndarray
            # subclasses (e.g. `np.memmap`) pass the "numpy.ndarray:external"
            # handler's isinstance check and must not be rejected here
            pass
        else:
            # if not a numpy array, except
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # get metadata
        # NOTE(review): metadata is taken from the original `data`, not the
        # converted `data_new` -- presumably so torch tensors report their own
        # dtype; confirm against the loader
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via mro to avoid importing pandas
        if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__):
            output["columns"] = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif isinstance(data, Iterable):
            # list, tuple and Sequence are all Iterable, so this one check covers them
            data_new = [
                jser.json_serialize(item, (*path, i))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        # `all()` of an empty sequence is True, but `jsonl_metadata` reads
        # `data[0]`, so skip column metadata when there are no rows
        if len(data_new) > 0 and all(isinstance(item, dict) for item in data_new):
            # NOTE(review): for DataFrames, the "columns" key returned here
            # replaces the column list stored above -- confirm this is intended
            output.update(jsonl_metadata(data_new))

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output
def _external_handler(
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool],
    item_type: ExternalItemType,
    uid: str,
    desc: str,
) -> ZANJSerializerHandler:
    """build a zanj handler that stores matching objects externally

    objects accepted by `check` are passed to `zanj_external_serialize` with
    the given `item_type`. the handler's `uid` doubles as the `_format`
    string, and `source_pckg` is always `"zanj"`.
    """
    return ZANJSerializerHandler(
        check=check,
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type=item_type, _format=uid
        ),
        uid=uid,
        source_pckg="zanj",
        desc=desc,
    )


# zanj's external-storage handlers (large arrays / tensors / lists / tuples /
# DataFrames, above the configured thresholds), followed by muutils'
# DEFAULT_HANDLERS. presumably handlers are tried in order, so these take
# precedence over the defaults -- confirm in muutils' serializer
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = (
    _external_handler(
        check=lambda self, obj, path: (
            isinstance(obj, np.ndarray)
            and obj.size >= self.external_array_threshold
        ),
        item_type="npy",
        uid="numpy.ndarray:external",
        desc="external numpy array",
    ),
    _external_handler(
        check=lambda self, obj, path: (
            str(type(obj)) == "<class 'torch.Tensor'>"
            and int(obj.nelement()) >= self.external_array_threshold
        ),
        item_type="npy",
        uid="torch.Tensor:external",
        desc="external torch tensor",
    ),
    _external_handler(
        check=lambda self, obj, path: (
            isinstance(obj, list) and len(obj) >= self.external_list_threshold
        ),
        item_type="jsonl",
        uid="list:external",
        desc="external list",
    ),
    _external_handler(
        check=lambda self, obj, path: (
            isinstance(obj, tuple) and len(obj) >= self.external_list_threshold
        ),
        item_type="jsonl",
        uid="tuple:external",
        desc="external tuple",
    ),
    _external_handler(
        check=lambda self, obj, path: (
            # mro string check so pandas does not need to be imported
            any(
                "pandas.core.frame.DataFrame" in str(t)
                for t in obj.__class__.__mro__
            )
            and len(obj) >= self.external_list_threshold
        ),
        item_type="jsonl",
        uid="pandas.DataFrame:external",
        desc="external pandas DataFrame",
    ),
    # TODO: a disabled sketch existed for a fallback torch.nn.Module handler:
    # check whether "<class 'torch.nn.modules.module.Module'>" is in the mro,
    # serialize via `zanj_serialize_torchmodule(self, obj, path)`,
    # uid="torch.nn.Module", desc="fallback torch serialization"
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)

# the complaint above is:
# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]" [arg-type]