Coverage for zanj / serializing.py: 100%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 01:52 -0700

1from __future__ import annotations 

2 

3import json 

4import sys 

5from dataclasses import dataclass 

6from typing import IO, Any, Callable, Iterable, Sequence 

7import warnings 

8 

9import numpy as np 

10from muutils.json_serialize.array import arr_metadata 

11from muutils.json_serialize.json_serialize import ( # JsonSerializer, 

12 DEFAULT_HANDLERS, 

13 ObjectPath, 

14 SerializerHandler, 

15) 

16 

17from zanj.consts import JSONdict, JSONitem, MonoTuple, _FORMAT_KEY, _REF_KEY 

18 

19from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre 

20 

# `dataclass(kw_only=...)` only exists on Python >= 3.10; on older versions no
# extra kwargs are passed so `@dataclass(**KW_ONLY_KWARGS)` stays valid
KW_ONLY_KWARGS: dict = (
    {"kw_only": True} if sys.version_info >= (3, 10) else {}
)

24 

25# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg 

26# for some reason pylint complains about kwargs to ZANJSerializerHandler 

27 

28 

def jsonl_metadata(data: list[JSONdict]) -> dict[str, Any]:
    """metadata about a jsonl object

    # Parameters:
     - `data: list[JSONdict]`
        the rows to be stored as jsonl, each a dict

    # Returns:
     - `dict[str, Any]` with keys:
        - `"len(data)"`: number of rows
        - `"columns"`: for every column except `_FORMAT_KEY`, a dict with
          `"types"` (names of the value types seen in that column) and
          `"len"` (number of rows that contain the column)
        - `"data[0]"`: the first row, only present when `data` is non-empty

    columns and types are listed in order of first appearance, so identical
    `data` always yields identically ordered metadata (iterating a `set` of
    strings would depend on hash randomization)
    """
    # dicts keep insertion order, so `dict[str, None]` acts as an ordered set
    col_types: dict[str, dict[str, None]] = {}
    col_counts: dict[str, int] = {}
    for item in data:
        for col, value in item.items():
            col_types.setdefault(col, {})[type(value).__name__] = None
            col_counts[col] = col_counts.get(col, 0) + 1

    output: dict[str, Any] = {
        "len(data)": len(data),
        "columns": {
            col: {
                "types": list(seen_types),
                "len": col_counts[col],
            }
            for col, seen_types in col_types.items()
            # the format marker is bookkeeping, not a data column
            if col != _FORMAT_KEY
        },
    }
    if len(data) > 0:
        output["data[0]"] = data[0]
    return output

48 

49 

def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: np.ndarray) -> None:
    """write `data` to the already-open binary file `fp` in `.npy` format

    pickling is disabled (`allow_pickle=False`), so object arrays cannot be
    written. `self` is unused; it is accepted so this function fits the
    `(zanj, file, data)` signature of `EXTERNAL_STORE_FUNCS`
    """
    # given a file object (not a path), `np.save` applies `np.asanyarray` and
    # hands it to `numpy.lib.format.write_array` without closing the file,
    # producing the same bytes as calling `write_array` directly
    np.save(fp, data, allow_pickle=False)

59 

60 

def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write each element of `data` to the open binary file `fp` as one JSON line

    every line is `json.dumps(element)` encoded as UTF-8 followed by a newline,
    including after the last element. `self` is unused; it is accepted so this
    function fits the `(zanj, file, data)` signature of `EXTERNAL_STORE_FUNCS`
    """
    newline: bytes = "\n".encode("utf-8")
    for record in data:
        encoded: bytes = json.dumps(record).encode("utf-8")
        fp.write(encoded)
        fp.write(newline)

67 

68 

# external item type name -> function writing such an item to an open binary
# file; each writer is called as (zanj instance, file, data)
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = dict(
    npy=store_npy,
    jsonl=store_jsonl,
)

75 

76 

@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    redeclares fields of muutils' `SerializerHandler` so that `check` and
    `serialize_func` are typed as receiving the ZANJ instance. on Python >= 3.10
    the dataclass is keyword-only (via `KW_ONLY_KWARGS`).

    handlers built in this module also pass `uid=` and `desc=`, which are not
    declared here -- NOTE(review): presumably inherited from `SerializerHandler`;
    confirm against muutils.
    """

    # unique identifier for the handler, saved in _FORMAT_KEY field
    # uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # (zanj instance, object, path) -> whether to use this handler
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    # (zanj instance, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
    # optional description of how this serializer works
    # desc: str = "(no description)"

91 

92 

def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    # Parameters:
     - `jser: ZANJ`
        the serializer; the item is registered in its `_externals` dict
     - `data: Any`
        - for `item_type == "npy"`: a `numpy.ndarray` (subclasses such as
          `numpy.memmap` are accepted, masked arrays are not) or a `torch.Tensor`
        - for `item_type` starting with `"jsonl"`: a pandas or polars
          `DataFrame`, or any iterable
     - `path: ObjectPath`
        location of `data` within the serialized object; must be a tuple
     - `item_type: ExternalItemType`
        used as the file extension of the stored item
     - `_format: str`
        stored under `_FORMAT_KEY` in the returned reference

    # Returns:
     - `JSONitem`
        json data with reference: `_FORMAT_KEY`, `_REF_KEY` (the archive path)
        and metadata about the stored data

    # Raises:
     - `ValueError` if the archive path is already registered, or if it is a
       directory-style prefix of an existing external path (or vice versa)
     - `TypeError` if `data` is not a supported type for `item_type`

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure its unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    # TODO: somehow need to control whether a failure here causes a fallback to other handlers, or whether the except should propagate
    # this will probably require changes to the upstream muutils.json_serialize code
    if archive_path in jser._externals:
        err_msg = f"external path {archive_path} already exists!"
        warnings.warn(err_msg)
        raise ValueError(err_msg)
    # Check for true path prefix conflicts (not just string prefix):
    # only flag when one path is a directory ancestor of another ("/" separator)
    for p in jser._externals.keys():
        # strip the file extension to recover the existing joined path
        existing_joined_path = p.rsplit(".", 1)[0]
        if existing_joined_path.startswith(joined_path + "/") or joined_path.startswith(
            existing_joined_path + "/"
        ):
            err_msg = (
                f"external path {joined_path} is a prefix of another path {p}!\n"
                + f"{jser._externals.keys() = }\n{joined_path = }\n{path = }\n{p = }\n{existing_joined_path = }\n{archive_path = }\n{_format = }"
            )
            warnings.warn(err_msg)
            raise ValueError(err_msg)

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # compared by name to avoid importing torch; detach and convert
            data_new = data.detach().cpu().numpy()
        elif isinstance(data, np.ndarray) and not isinstance(
            data, np.ma.MaskedArray
        ):
            # ndarray subclasses (e.g. `np.memmap`) are accepted, matching the
            # `isinstance` check of the numpy handler that routes them here.
            # masked arrays stay rejected: a `.npy` file cannot hold the mask
            pass
        else:
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # metadata is computed from the original `data`, not `data_new`
        # NOTE(review): presumably deliberate so tensor metadata is kept; confirm
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via module and class name to avoid importing pandas (works with pandas 3.0+)
        dataframe_columns = None
        if (
            "pandas" in data.__class__.__module__
            and data.__class__.__name__ == "DataFrame"
        ):
            dataframe_columns = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif (
            "polars" in data.__class__.__module__
            and data.__class__.__name__ == "DataFrame"
        ):
            dataframe_columns = data.columns
            data_new = data.to_dicts()
        elif isinstance(data, Iterable):
            # covers list, tuple and any other sequence/iterable; each element
            # is serialized with its index appended to the path
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        if all(isinstance(item, dict) for item in data_new):
            output.update(jsonl_metadata(data_new))

        # set DataFrame columns after jsonl_metadata to avoid being overwritten
        if dataframe_columns is not None:
            output["columns"] = dataframe_columns

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output

203 

204 

205DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple( 

206 [ 

207 ZANJSerializerHandler( 

208 check=lambda self, obj, path: ( 

209 isinstance(obj, np.ndarray) 

210 and obj.size >= self.external_array_threshold 

211 ), 

212 serialize_func=lambda self, obj, path: zanj_external_serialize( 

213 self, obj, path, item_type="npy", _format="numpy.ndarray:external" 

214 ), 

215 uid="numpy.ndarray:external", 

216 source_pckg="zanj", 

217 desc="external numpy array", 

218 ), 

219 ZANJSerializerHandler( 

220 check=lambda self, obj, path: ( 

221 str(type(obj)) == "<class 'torch.Tensor'>" 

222 and int(obj.nelement()) >= self.external_array_threshold 

223 ), 

224 serialize_func=lambda self, obj, path: zanj_external_serialize( 

225 self, obj, path, item_type="npy", _format="torch.Tensor:external" 

226 ), 

227 uid="torch.Tensor:external", 

228 source_pckg="zanj", 

229 desc="external torch tensor", 

230 ), 

231 ZANJSerializerHandler( 

232 check=lambda self, obj, path: ( 

233 isinstance(obj, list) and len(obj) >= self.external_list_threshold 

234 ), 

235 serialize_func=lambda self, obj, path: zanj_external_serialize( 

236 self, obj, path, item_type="jsonl", _format="list:external" 

237 ), 

238 uid="list:external", 

239 source_pckg="zanj", 

240 desc="external list", 

241 ), 

242 ZANJSerializerHandler( 

243 check=lambda self, obj, path: ( 

244 isinstance(obj, tuple) and len(obj) >= self.external_list_threshold 

245 ), 

246 serialize_func=lambda self, obj, path: zanj_external_serialize( 

247 self, obj, path, item_type="jsonl", _format="tuple:external" 

248 ), 

249 uid="tuple:external", 

250 source_pckg="zanj", 

251 desc="external tuple", 

252 ), 

253 ZANJSerializerHandler( 

254 check=lambda self, obj, path: ( 

255 "pandas" in obj.__class__.__module__ 

256 and obj.__class__.__name__ == "DataFrame" 

257 ), 

258 serialize_func=lambda self, obj, path: zanj_external_serialize( 

259 self, obj, path, item_type="jsonl", _format="pandas.DataFrame:external" 

260 ), 

261 uid="pandas.DataFrame:external", 

262 source_pckg="zanj", 

263 desc="external pandas DataFrame", 

264 ), 

265 ZANJSerializerHandler( 

266 check=lambda self, obj, path: ( 

267 "polars" in obj.__class__.__module__ 

268 and obj.__class__.__name__ == "DataFrame" 

269 ), 

270 serialize_func=lambda self, obj, path: zanj_external_serialize( 

271 self, obj, path, item_type="jsonl", _format="polars.DataFrame:external" 

272 ), 

273 uid="polars.DataFrame:external", 

274 source_pckg="zanj", 

275 desc="external polars DataFrame", 

276 ), 

277 # ZANJSerializerHandler( 

278 # check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>" 

279 # in [str(t) for t in obj.__class__.__mro__], 

280 # serialize_func=lambda self, obj, path: zanj_serialize_torchmodule( 

281 # self, obj, path, 

282 # ), 

283 # uid="torch.nn.Module", 

284 # source_pckg="zanj", 

285 # desc="fallback torch serialization", 

286 # ), 

287 ] 

288) + tuple( 

289 DEFAULT_HANDLERS # type: ignore[arg-type] 

290) 

291 

292# the complaint above is: 

293# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]" [arg-type]