Coverage for zanj / zanj.py: 99%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-21 22:18 -0700

1""" 

2an HDF5/exdir file alternative, which uses json for attributes, allows serialization of arbitrary data 

3 

4for large arrays, the output is a .tar.gz file with most data in a json file, but with sufficiently large arrays stored in binary .npy files 

5 

6 

7"ZANJ" is an acronym that the AI tool [Elicit](https://elicit.org) came up with for me. not to be confused with: 

8 

9- https://en.wikipedia.org/wiki/Zanj 

10- https://www.plutojournals.com/zanj/ 

11 

12""" 

13 

14from __future__ import annotations 

15 

16import json 

17import os 

18import time 

19import zipfile 

20from dataclasses import dataclass 

21from pathlib import Path 

22from typing import Any, Union 

23 

24import numpy as np 

25from muutils.errormode import ErrorMode 

26from muutils.json_serialize.array import ArrayMode, arr_metadata 

27from muutils.json_serialize.json_serialize import ( 

28 JsonSerializer, 

29 SerializerHandler, 

30 json_serialize, 

31) 

32from muutils.sysinfo import SysInfo 

33 

34from zanj.consts import JSONitem, MonoTuple 

35 

36from zanj.externals import ZANJ_MAIN, ZANJ_META, ExternalItem 

37import zanj.externals 

38from zanj.loading import LOADER_MAP, LoadedZANJ, load_item_recursive 

39from zanj.serializing import ( 

40 DEFAULT_SERIALIZER_HANDLERS_ZANJ, 

41 EXTERNAL_STORE_FUNCS, 

42 KW_ONLY_KWARGS, 

43) 

44 

45# pylint: disable=protected-access, unused-import, dangerous-default-value, line-too-long 

46 

47ZANJitem = Union[ 

48 JSONitem, 

49 np.ndarray, 

50 "pd.DataFrame", # type: ignore # noqa: F821 

51] 

52 

53 

54@dataclass(**KW_ONLY_KWARGS) 

55class _ZANJ_GLOBAL_DEFAULTS_CLASS: 

56 error_mode: ErrorMode = ErrorMode.EXCEPT 

57 internal_array_mode: ArrayMode = "array_list_meta" 

58 external_array_threshold: int = 256 

59 external_list_threshold: int = 256 

60 compress: bool | int = True 

61 custom_settings: dict[str, Any] | None = None 

62 

63 

64ZANJ_GLOBAL_DEFAULTS: _ZANJ_GLOBAL_DEFAULTS_CLASS = _ZANJ_GLOBAL_DEFAULTS_CLASS() 

65 

66 

67class ZANJ(JsonSerializer): 

68 """Zip up: Arrays in Numpy, JSON for everything else 

69 

70 given an arbitrary object, throw into a zip file, with arrays stored in .npy files, and everything else stored in a json file 

71 

72 (basically npz file with json) 

73 

74 - numpy (or pytorch) arrays are stored in paths according to their name and structure in the object 

75 - everything else about the object is stored in a json file `zanj.json` in the root of the archive, via `muutils.json_serialize.JsonSerializer` 

76 - metadata about ZANJ configuration, and optionally packages and versions, is stored in a `__zanj_meta__.json` file in the root of the archive 

77 

78 create a ZANJ-class via `z_cls = ZANJ().create(obj)`, and save/read instances of the object via `z_cls.save(obj, path)`, `z_cls.load(path)`. be sure to pass an **instance** of the object, to make sure that the attributes of the class can be correctly recognized 

79 

80 """ 

81 

82 def __init__( 

83 self, 

84 error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode, 

85 internal_array_mode: ArrayMode = ZANJ_GLOBAL_DEFAULTS.internal_array_mode, 

86 external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold, 

87 external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold, 

88 compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress, 

89 custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings, 

90 handlers_pre: MonoTuple[SerializerHandler] = tuple(), 

91 handlers_default: MonoTuple[ 

92 SerializerHandler 

93 ] = DEFAULT_SERIALIZER_HANDLERS_ZANJ, 

94 ) -> None: 

95 super().__init__( 

96 array_mode=internal_array_mode, 

97 error_mode=error_mode, 

98 handlers_pre=handlers_pre, 

99 handlers_default=handlers_default, 

100 ) 

101 

102 self.external_array_threshold: int = external_array_threshold 

103 self.external_list_threshold: int = external_list_threshold 

104 self.custom_settings: dict = ( 

105 custom_settings if custom_settings is not None else dict() 

106 ) 

107 

108 # process compression to int if bool given 

109 self.compress = compress 

110 if isinstance(compress, bool): 

111 if compress: 

112 self.compress = zipfile.ZIP_DEFLATED 

113 else: 

114 self.compress = zipfile.ZIP_STORED 

115 

116 # create the externals, leave it empty 

117 self._externals: dict[str, ExternalItem] = dict() 

118 

119 def externals_info(self) -> dict[str, dict[str, str | int | list[int]]]: 

120 """return information about the current externals""" 

121 output: dict[str, dict] = dict() 

122 

123 key: str 

124 item: ExternalItem 

125 for key, item in self._externals.items(): 

126 data = item.data 

127 output[key] = { 

128 "item_type": item.item_type, 

129 "path": item.path, 

130 "type(data)": str(type(data)), 

131 "len(data)": len(data), 

132 } 

133 

134 if item.item_type == "ndarray": 

135 output[key].update(arr_metadata(data)) 

136 elif item.item_type.startswith("jsonl") and len(data) > 0: 

137 output[key]["data[0]"] = data[0] 

138 

139 return { 

140 key: val 

141 for key, val in sorted(output.items(), key=lambda x: len(x[1]["path"])) 

142 } 

143 

144 def meta(self) -> JSONitem: 

145 """return the metadata of the ZANJ archive""" 

146 

147 serialization_handlers = {h.uid: h.serialize() for h in self.handlers} 

148 load_handlers = {h.uid: h.serialize() for h in LOADER_MAP.values()} 

149 

150 return dict( 

151 # configuration of this ZANJ instance 

152 zanj_cfg=dict( 

153 error_mode=str(self.error_mode), 

154 array_mode=str(self.array_mode), 

155 external_array_threshold=self.external_array_threshold, 

156 external_list_threshold=self.external_list_threshold, 

157 compress=self.compress, 

158 serialization_handlers=serialization_handlers, 

159 load_handlers=load_handlers, 

160 ), 

161 # system info (python, pip packages, torch & cuda, platform info, git info) 

162 sysinfo=json_serialize(SysInfo.get_all(include=("python", "pytorch"))), 

163 externals_info=self.externals_info(), 

164 timestamp=time.time(), 

165 ) 

166 

167 def save(self, obj: Any, file_path: str | Path) -> str: 

168 """save the object to a ZANJ archive. returns the path to the archive""" 

169 

170 # adjust extension 

171 file_path = str(file_path) 

172 if not file_path.endswith(".zanj"): 

173 file_path += ".zanj" 

174 

175 # make directory 

176 dir_path: str = os.path.dirname(file_path) 

177 if dir_path != "": 

178 if not os.path.exists(dir_path): 

179 os.makedirs(dir_path, exist_ok=False) 

180 

181 # clear the externals! 

182 self._externals = dict() 

183 

184 # serialize the object -- this will populate self._externals 

185 # TODO: calling self.json_serialize again here might be slow 

186 json_data: JSONitem = self.json_serialize(self.json_serialize(obj)) 

187 

188 # open the zip file 

189 zipf: zipfile.ZipFile = zipfile.ZipFile( 

190 file=file_path, mode="w", compression=self.compress 

191 ) 

192 

193 # store base json data and metadata 

194 zipf.writestr( 

195 ZANJ_META, 

196 json.dumps( 

197 self.json_serialize(self.meta()), 

198 indent="\t", 

199 ), 

200 ) 

201 zipf.writestr( 

202 ZANJ_MAIN, 

203 json.dumps( 

204 json_data, 

205 indent="\t", 

206 ), 

207 ) 

208 

209 # store externals 

210 for key, (ext_type, ext_data, ext_path) in self._externals.items(): 

211 # why force zip64? numpy.savez does it 

212 with zipf.open(key, "w", force_zip64=True) as fp: 

213 EXTERNAL_STORE_FUNCS[ext_type](self, fp, ext_data) 

214 

215 zipf.close() 

216 

217 # clear the externals, again 

218 self._externals = dict() 

219 

220 return file_path 

221 

222 def read( 

223 self, 

224 file_path: Union[str, Path], 

225 ) -> Any: 

226 """load the object from a ZANJ archive 

227 # TODO: load only some part of the zanj file by passing an ObjectPath 

228 """ 

229 file_path = Path(file_path) 

230 if not file_path.exists(): 

231 raise FileNotFoundError(f"file not found: {file_path}") 

232 if not file_path.is_file(): 

233 raise FileNotFoundError(f"not a file: {file_path}") 

234 

235 loaded_zanj: LoadedZANJ = LoadedZANJ( 

236 path=file_path, 

237 zanj=self, 

238 ) 

239 

240 loaded_zanj.populate_externals() 

241 

242 return load_item_recursive( 

243 loaded_zanj._json_data, 

244 path=tuple(), 

245 zanj=self, 

246 error_mode=self.error_mode, 

247 # lh_map=loader_handlers, 

248 ) 

249 

250 

251zanj.externals._ZANJ_pre = ZANJ # type: ignore