Coverage for zanj / zanj.py: 99%
83 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
"""
an HDF5/exdir file alternative, which uses json for attributes, allows serialization of arbitrary data

the output is always a zip archive (saved with a `.zanj` extension): most data goes into a json file, while sufficiently large arrays are stored in binary .npy files inside the archive

"ZANJ" is an acronym that the AI tool [Elicit](https://elicit.org) came up with for me. not to be confused with:

- https://en.wikipedia.org/wiki/Zanj
- https://www.plutojournals.com/zanj/

"""
14from __future__ import annotations
16import json
17import os
18import time
19import zipfile
20from dataclasses import dataclass
21from pathlib import Path
22from typing import Any, Union
24import numpy as np
25from muutils.errormode import ErrorMode
26from muutils.json_serialize.array import ArrayMode, arr_metadata
27from muutils.json_serialize.json_serialize import (
28 JsonSerializer,
29 SerializerHandler,
30 json_serialize,
31)
32from muutils.sysinfo import SysInfo
34from zanj.consts import JSONitem, MonoTuple
36from zanj.externals import ZANJ_MAIN, ZANJ_META, ExternalItem
37import zanj.externals
38from zanj.loading import LOADER_MAP, LoadedZANJ, load_item_recursive
39from zanj.serializing import (
40 DEFAULT_SERIALIZER_HANDLERS_ZANJ,
41 EXTERNAL_STORE_FUNCS,
42 KW_ONLY_KWARGS,
43)
45# pylint: disable=protected-access, unused-import, dangerous-default-value, line-too-long
# anything ZANJ can store: plain JSON data, numpy arrays, or pandas DataFrames.
# pandas is referenced by a string forward reference, so it does not have to be
# imported by this module (hence the F821 "undefined name" suppression).
ZANJitem = Union[JSONitem, np.ndarray, "pd.DataFrame"]  # type: ignore # noqa: F821
# NOTE(review): KW_ONLY_KWARGS comes from zanj.serializing; presumably it expands to
# `kw_only=True` where the running Python supports it -- confirm there.
@dataclass(**KW_ONLY_KWARGS)
class _ZANJ_GLOBAL_DEFAULTS_CLASS:
    """default configuration values for `ZANJ`

    the module-level instance `ZANJ_GLOBAL_DEFAULTS` supplies the default
    argument values of `ZANJ.__init__`
    """

    # forwarded to `JsonSerializer` as `error_mode`
    error_mode: ErrorMode = ErrorMode.EXCEPT
    # forwarded to `JsonSerializer` as `array_mode` (format of arrays kept inside the main json)
    internal_array_mode: ArrayMode = "array_list_meta"
    # presumably the size above which arrays are stored as separate archive members -- TODO confirm in zanj.serializing
    external_array_threshold: int = 256
    # presumably the length above which lists are stored as separate archive members -- TODO confirm in zanj.serializing
    external_list_threshold: int = 256
    # True -> zipfile.ZIP_DEFLATED, False -> zipfile.ZIP_STORED, an int is used as the zipfile compression method directly
    compress: bool | int = True
    # extra user settings; `ZANJ.__init__` replaces None with an empty dict
    custom_settings: dict[str, Any] | None = None
# module-level default configuration for `ZANJ`.
# NOTE: `ZANJ.__init__` uses these attributes as default argument values, and Python
# evaluates default arguments once, when the class body runs. Changing attributes of
# this object after this module is imported does NOT change the defaults of `ZANJ()`;
# pass the desired values explicitly instead.
ZANJ_GLOBAL_DEFAULTS: _ZANJ_GLOBAL_DEFAULTS_CLASS = _ZANJ_GLOBAL_DEFAULTS_CLASS()
class ZANJ(JsonSerializer):
    """Zip up: Arrays in Numpy, JSON for everything else

    given an arbitrary object, write it into a zip archive, with sufficiently large arrays stored in .npy files and everything else stored in a json file

    (basically an npz file with json)

    - numpy (or pytorch) arrays are stored in paths according to their name and structure in the object
    - everything else about the object is stored in the main json file (`ZANJ_MAIN`, `zanj.json`) in the root of the archive, via `muutils.json_serialize.JsonSerializer`
    - metadata about the ZANJ configuration, system info (python / pytorch), and the stored externals is written to `ZANJ_META` (`__zanj_meta__.json`) in the root of the archive

    save an object via `ZANJ().save(obj, path)` and load it back via `ZANJ().read(path)`. be sure to pass an **instance** of the object, to make sure that the attributes of the class can be correctly recognized
    """

    def __init__(
        self,
        error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode,
        internal_array_mode: ArrayMode = ZANJ_GLOBAL_DEFAULTS.internal_array_mode,
        external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold,
        external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold,
        compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress,
        custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[
            SerializerHandler
        ] = DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    ) -> None:
        """create a ZANJ serializer

        # Parameters:
         - `error_mode`: forwarded to `JsonSerializer`
         - `internal_array_mode`: forwarded to `JsonSerializer` as `array_mode`
         - `external_array_threshold`, `external_list_threshold`: stored on the instance for use during serialization (presumably size cutoffs above which data is stored externally)
         - `compress`: `True` -> `zipfile.ZIP_DEFLATED`, `False` -> `zipfile.ZIP_STORED`; an int is used directly as the zipfile compression method
         - `custom_settings`: extra settings; `None` becomes an empty dict
         - `handlers_pre`, `handlers_default`: forwarded to `JsonSerializer`

        NOTE: the defaults are read from `ZANJ_GLOBAL_DEFAULTS` once, when this class is
        defined, so modifying `ZANJ_GLOBAL_DEFAULTS` afterwards does not change them.
        """
        super().__init__(
            array_mode=internal_array_mode,
            error_mode=error_mode,
            handlers_pre=handlers_pre,
            handlers_default=handlers_default,
        )

        self.external_array_threshold: int = external_array_threshold
        self.external_list_threshold: int = external_list_threshold
        self.custom_settings: dict = (
            custom_settings if custom_settings is not None else dict()
        )

        # normalize a bool `compress` to a zipfile compression constant
        # (the bool check must come first: `bool` is a subclass of `int`)
        if isinstance(compress, bool):
            compress = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
        self.compress: int = compress

        # externals collected during serialization; empty until `save` runs
        self._externals: dict[str, ExternalItem] = dict()

    def externals_info(self) -> dict[str, dict[str, str | int | list[int]]]:
        """return information about the current externals

        for every key in `self._externals`, records the item type, its path, the
        type and length of its data, plus `arr_metadata` for `ndarray` items or the
        first element for non-empty `jsonl*` items. entries are ordered by
        `len(path)`, shortest first.
        """
        output: dict[str, dict] = dict()

        for key, item in self._externals.items():
            data = item.data
            info: dict = {
                "item_type": item.item_type,
                "path": item.path,
                "type(data)": str(type(data)),
                "len(data)": len(data),
            }

            if item.item_type == "ndarray":
                info.update(arr_metadata(data))
            elif item.item_type.startswith("jsonl") and len(data) > 0:
                info["data[0]"] = data[0]

            output[key] = info

        return dict(sorted(output.items(), key=lambda kv: len(kv[1]["path"])))

    def meta(self) -> JSONitem:
        """return the metadata of the ZANJ archive

        contains this instance's configuration (including the uids of all
        serialization and loading handlers), system info, info about the current
        externals, and a timestamp
        """
        serialization_handlers = {h.uid: h.serialize() for h in self.handlers}
        load_handlers = {h.uid: h.serialize() for h in LOADER_MAP.values()}

        return dict(
            # configuration of this ZANJ instance
            zanj_cfg=dict(
                error_mode=str(self.error_mode),
                array_mode=str(self.array_mode),
                external_array_threshold=self.external_array_threshold,
                external_list_threshold=self.external_list_threshold,
                compress=self.compress,
                serialization_handlers=serialization_handlers,
                load_handlers=load_handlers,
            ),
            # system info (python, pip packages, torch & cuda, platform info, git info)
            sysinfo=json_serialize(SysInfo.get_all(include=("python", "pytorch"))),
            externals_info=self.externals_info(),
            timestamp=time.time(),
        )

    def save(self, obj: Any, file_path: str | Path) -> str:
        """save the object to a ZANJ archive. returns the path to the archive

        `.zanj` is appended to `file_path` if missing, and missing parent
        directories are created. the archive contains `ZANJ_META`, `ZANJ_MAIN`,
        and one member per external collected during serialization.
        """
        # adjust extension
        file_path = str(file_path)
        if not file_path.endswith(".zanj"):
            file_path += ".zanj"

        # make directory; `exist_ok` avoids a race between an existence check and creation
        dir_path: str = os.path.dirname(file_path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        # clear the externals!
        self._externals = dict()
        try:
            # serialize the object -- this will populate self._externals
            # TODO: calling self.json_serialize again here might be slow
            json_data: JSONitem = self.json_serialize(self.json_serialize(obj))

            # the context manager guarantees the archive is finalized and the file
            # handle released even if storing an external fails
            with zipfile.ZipFile(
                file=file_path, mode="w", compression=self.compress
            ) as zipf:
                # store metadata (computed after serialization, so it sees the externals)
                zipf.writestr(
                    ZANJ_META,
                    json.dumps(self.json_serialize(self.meta()), indent="\t"),
                )
                # store base json data
                zipf.writestr(ZANJ_MAIN, json.dumps(json_data, indent="\t"))

                # store externals
                for key, item in self._externals.items():
                    # why force zip64? numpy.savez does it
                    with zipf.open(key, "w", force_zip64=True) as fp:
                        EXTERNAL_STORE_FUNCS[item.item_type](self, fp, item.data)
        finally:
            # clear the externals, again -- also on failure, so stale data is not kept
            self._externals = dict()

        return file_path

    def read(
        self,
        file_path: Union[str, Path],
    ) -> Any:
        """load the object from a ZANJ archive

        # Raises:
         - `FileNotFoundError`: if `file_path` does not exist or is not a regular file
        """
        # TODO: load only some part of the zanj file by passing an ObjectPath
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"file not found: {file_path}")
        if not file_path.is_file():
            raise FileNotFoundError(f"not a file: {file_path}")

        loaded_zanj: LoadedZANJ = LoadedZANJ(
            path=file_path,
            zanj=self,
        )

        loaded_zanj.populate_externals()

        return load_item_recursive(
            loaded_zanj._json_data,
            path=tuple(),
            zanj=self,
            error_mode=self.error_mode,
        )
# hand the `ZANJ` class to `zanj.externals` once it exists. presumably a back-reference
# that lets that module use `ZANJ` without importing this module (which would be a
# circular import, since this module imports `zanj.externals`) -- TODO confirm.
zanj.externals._ZANJ_pre = ZANJ  # type: ignore