Coverage for zanj/serializing.py: 88%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-11-06 16:15 +0000

1from __future__ import annotations 

2 

3import json 

4import sys 

5from dataclasses import dataclass 

6from typing import IO, Any, Callable, Iterable, Sequence 

7import warnings 

8 

9import numpy as np 

10from muutils.json_serialize.array import arr_metadata 

11from muutils.json_serialize.json_serialize import ( # JsonSerializer, 

12 DEFAULT_HANDLERS, 

13 ObjectPath, 

14 SerializerHandler, 

15) 

16from muutils.json_serialize.util import ( 

17 JSONdict, 

18 JSONitem, 

19 MonoTuple, 

20 _FORMAT_KEY, 

21 _REF_KEY, 

22) 

23 

24from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre 

25 

# extra keyword arguments for `@dataclass`: `kw_only=True` is requested only
# when running on Python 3.10 or newer, otherwise nothing extra is passed
KW_ONLY_KWARGS: dict = {"kw_only": True} if sys.version_info >= (3, 10) else {}

29 

30# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg 

31# for some reason pylint complains about kwargs to ZANJSerializerHandler 

32 

33 

def jsonl_metadata(data: list[JSONdict]) -> dict:
    """metadata about a jsonl object

    # Parameters:
     - `data: list[JSONdict]`
       rows of the jsonl object, each one a dict. May be empty.

    # Returns:
     - `dict` with the keys
       - `"data[0]"`: the first row, or `None` when `data` is empty
       - `"len(data)"`: number of rows
       - `"columns"`: for every key seen in any row (except `_FORMAT_KEY`),
         a dict with `"types"` (sorted names of the value types seen under
         that key) and `"len"` (number of rows containing the key)
    """
    # gather per-column statistics in a single pass over the rows
    col_types: dict[str, set[str]] = {}
    col_counts: dict[str, int] = {}
    for item in data:
        for col, value in item.items():
            if col == _FORMAT_KEY:
                continue
            col_types.setdefault(col, set()).add(type(value).__name__)
            col_counts[col] = col_counts.get(col, 0) + 1

    return {
        # an empty sequence has no first row; do not raise IndexError for it
        "data[0]": data[0] if data else None,
        "len(data)": len(data),
        "columns": {
            col: {
                # sorted so the output does not depend on set iteration order
                "types": sorted(col_types[col]),
                "len": col_counts[col],
            }
            for col in col_types
        },
    }

51 

52 

def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: np.ndarray) -> None:
    """write `data` to the binary stream `fp` in numpy's `.npy` format

    `data` is passed through `np.asanyarray` before writing, and the array is
    written with `allow_pickle=False`. `self` is accepted but not used.
    """
    # TODO: Type `<module 'numpy.lib'>` has no attribute `format` --> zanj/serializing.py:54:5
    # info: rule `unresolved-attribute` is enabled by default
    write_array = np.lib.format.write_array  # ty: ignore[unresolved-attribute]
    as_array = np.asanyarray(data)
    write_array(fp=fp, array=as_array, allow_pickle=False)

62 

63 

def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write every element of `data` to `fp` as one utf-8 encoded JSON line

    each element is dumped with `json.dumps` and followed by a newline.
    `self` is accepted but not used.
    """
    line_end: bytes = "\n".encode("utf-8")
    for record in data:
        encoded: bytes = json.dumps(record).encode("utf-8")
        fp.write(encoded)
        fp.write(line_end)

70 

71 

# maps an external item type to the function that writes an item of that type;
# each function takes `(zanj, fp, data)` and returns nothing
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = dict(
    npy=store_npy,
    jsonl=store_jsonl,
)

78 

79 

@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """serializer handler used by ZANJ

    a `SerializerHandler` whose `check` and `serialize_func` callables take
    the ZANJ instance as their first argument, followed by the object and
    its path.

    NOTE(review): `uid` and `desc` are commented out below, yet both are
    passed when the default handlers are built in this module -- presumably
    they are inherited from `SerializerHandler`; confirm against muutils.
    """

    # uid: str
    #     unique identifier for the handler, saved in _FORMAT_KEY field

    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str

    # (zanj, object, path) -> whether this handler should be used for the object
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]

    # (zanj, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]

    # desc: str = "(no description)"
    #     optional description of how this serializer works

94 

95 

def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    # Parameters:
     - `jser: ZANJ`
       serializer whose `_externals` mapping receives the stored item
     - `data: Any`
       object to store. For `item_type == "npy"`: a `numpy.ndarray` (or a
       subclass of it) or a `torch.Tensor`. For item types starting with
       `"jsonl"`: a pandas DataFrame or any iterable
     - `path: ObjectPath`
       tuple path of the object; its elements are joined with `"/"` to
       build the path inside the archive
     - `item_type: ExternalItemType`
       kind of external item, also used as the file extension
     - `_format: str`
       value placed under `_FORMAT_KEY` in the returned reference

    # Returns:
     - `JSONitem`
       json data with reference

    # Raises:
     - `AssertionError` : if `path` is not a tuple
     - `ValueError` : if the archive path already exists, or if the path is
       a directory ancestor/descendant of an already stored path (a warning
       with the same message is emitted first)
     - `TypeError` : if `data` has an unsupported type for `item_type`

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure it's unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join(str(p) for p in path)
    archive_path: str = f"{joined_path}.{item_type}"

    # TODO: somehow need to control whether a failure here causes a fallback to other handlers, or whether the except should propagate
    # this will probably require changes to the upstream muutils.json_serialize code
    if archive_path in jser._externals:
        err_msg = f"external path {archive_path} already exists!"
        warnings.warn(err_msg)
        raise ValueError(err_msg)
    # Check for true path prefix conflicts (not just string prefix)
    # Only flag when one path is a directory ancestor of another (contains "/" separator)
    for p in jser._externals:
        # Remove the file extension to get the joined_path
        existing_joined_path = p.rsplit(".", 1)[0]
        # Check if one is a true path prefix with "/" separator
        if existing_joined_path.startswith(joined_path + "/") or joined_path.startswith(
            existing_joined_path + "/"
        ):
            err_msg = (
                f"external path {joined_path} is a prefix of another path {p}!\n"
                + f"{jser._externals.keys() = }\n{joined_path = }\n{path = }\n{p = }\n{existing_joined_path = }\n{archive_path = }\n{_format = }"
            )
            warnings.warn(err_msg)
            raise ValueError(err_msg)

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # torch is detected by type name so that it never has to be imported
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach and convert
            data_new = data.detach().cpu().numpy()
        elif not isinstance(data, np.ndarray):
            # isinstance (rather than an exact type-name match) so that
            # ndarray subclasses, which the numpy handler's check accepts,
            # are not rejected here
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # get metadata (from the original object, not the converted copy)
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via mro to avoid importing pandas
        if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__):
            output["columns"] = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif isinstance(data, Iterable):
            # lists, tuples and sequences are all iterables; every element is
            # serialized with its index appended to the path
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        # column metadata only makes sense when every row is a dict
        if all(isinstance(item, dict) for item in data_new):
            output.update(jsonl_metadata(data_new))

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output

192 

193 

# ZANJ-specific handlers come first, followed by the muutils default handlers.
# Each ZANJ handler hands objects at or above a size threshold to
# `zanj_external_serialize`.
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = (
    # numpy arrays with at least `external_array_threshold` elements -> npy
    ZANJSerializerHandler(
        uid="numpy.ndarray:external",
        source_pckg="zanj",
        desc="external numpy array",
        check=lambda self, obj, path: (
            isinstance(obj, np.ndarray)
            and obj.size >= self.external_array_threshold
        ),
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type="npy", _format="numpy.ndarray:external"
        ),
    ),
    # torch tensors (matched by type name, torch is never imported) -> npy
    ZANJSerializerHandler(
        uid="torch.Tensor:external",
        source_pckg="zanj",
        desc="external torch tensor",
        check=lambda self, obj, path: (
            str(type(obj)) == "<class 'torch.Tensor'>"
            and int(obj.nelement()) >= self.external_array_threshold
        ),
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type="npy", _format="torch.Tensor:external"
        ),
    ),
    # lists with at least `external_list_threshold` elements -> jsonl
    ZANJSerializerHandler(
        uid="list:external",
        source_pckg="zanj",
        desc="external list",
        check=lambda self, obj, path: (
            isinstance(obj, list) and len(obj) >= self.external_list_threshold
        ),
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type="jsonl", _format="list:external"
        ),
    ),
    # tuples with at least `external_list_threshold` elements -> jsonl
    ZANJSerializerHandler(
        uid="tuple:external",
        source_pckg="zanj",
        desc="external tuple",
        check=lambda self, obj, path: (
            isinstance(obj, tuple) and len(obj) >= self.external_list_threshold
        ),
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type="jsonl", _format="tuple:external"
        ),
    ),
    # pandas DataFrames (matched via the mro, pandas is never imported) -> jsonl
    ZANJSerializerHandler(
        uid="pandas.DataFrame:external",
        source_pckg="zanj",
        desc="external pandas DataFrame",
        check=lambda self, obj, path: (
            any(
                "pandas.core.frame.DataFrame" in str(t)
                for t in obj.__class__.__mro__
            )
            and len(obj) >= self.external_list_threshold
        ),
        serialize_func=lambda self, obj, path: zanj_external_serialize(
            self, obj, path, item_type="jsonl", _format="pandas.DataFrame:external"
        ),
    ),
    # ZANJSerializerHandler(
    #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
    #     in [str(t) for t in obj.__class__.__mro__],
    #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
    #         self, obj, path,
    #     ),
    #     uid="torch.nn.Module",
    #     source_pckg="zanj",
    #     desc="fallback torch serialization",
    # ),
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)

269 

270# the complaint above is: 

271# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]" [arg-type]