Coverage for zanj/loading.py: 79%

132 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-20 11:17 -0600

1from __future__ import annotations 

2 

3import json 

4import threading 

5import typing 

6import zipfile 

7from dataclasses import dataclass 

8from pathlib import Path 

9from typing import Any, Callable 

10 

11import numpy as np 

12 

# pandas is an optional dependency: `pandas_DataFrame` is the real `DataFrame`
# class when pandas can be imported, and a stand-in that fails on construction otherwise
try:
    import pandas as pd  # type: ignore[import]
except ImportError:

    class pandas_DataFrame:  # type: ignore[no-redef]
        """stand-in used when pandas is missing, instantiating it raises `ImportError`"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load pandas DataFrame, pandas is not installed")

else:
    pandas_DataFrame = pd.DataFrame  # type: ignore[no-redef]

22 

23 

24from muutils.errormode import ErrorMode 

25from muutils.json_serialize.array import load_array 

26from muutils.json_serialize.json_serialize import ObjectPath 

27from muutils.json_serialize.util import ( 

28 _FORMAT_KEY, 

29 _REF_KEY, 

30 JSONdict, 

31 JSONitem, 

32 safe_getsource, 

33 string_as_lines, 

34) 

35 

36from zanj.externals import ( 

37 GET_EXTERNAL_LOAD_FUNC, 

38 ZANJ_MAIN, 

39 ZANJ_META, 

40 ExternalItem, 

41 _ZANJ_pre, 

42) 

43 

44# pylint: disable=protected-access, dangerous-default-value 

45 

46 

def _populate_externals_error_checking(key, item) -> bool:
    """check that `key` is a valid index into `item`

    # Returns:
     - `True` if `item` is a not-yet-fully-loaded external item (a mapping whose
       format ends with ":external" and which has a "data" entry). in that case
       the caller needs to augment the path by accessing the "data" element first
     - `False` if `key` can be used on `item` directly

    # Raises:
     - `KeyError` : `item` is marked external but has no "data", or `key` is not in the mapping
     - `TypeError` : `key` is not an `int` (for sequences) or not a `str` (for mappings),
       or `item` is neither a sequence nor a mapping
     - `IndexError` : `key >= len(item)` for a sequence
    """
    is_mapping: bool = isinstance(item, typing.Mapping)

    # special case for not fully loaded external item which we still need to populate
    if is_mapping and (_FORMAT_KEY in item) and item[_FORMAT_KEY].endswith(":external"):
        if "data" not in item:
            raise KeyError(
                f"expected an external item, but could not find data: {list(item.keys())}",
                f"{item[_FORMAT_KEY]}, {len(item) = }, {item.get('data', '<EMPTY>') = }",
            )
        return True

    # sequences are checked before mappings: the key must be an int below the length
    if isinstance(item, typing.Sequence):
        if not isinstance(key, int):
            raise TypeError(f"improper type: '{type(key) = }', expected int")
        if key >= len(item):
            raise IndexError(f"index out of range: '{key = }', expected < {len(item)}")
        return False

    # mappings: the key must be a str that is present
    if is_mapping:
        if not isinstance(key, str):
            raise TypeError(f"improper type: '{type(key) = }', expected str")
        if key not in item:
            raise KeyError(f"key not in dict: '{key = }', expected in {item.keys()}")
        return False

    # anything else cannot be indexed by a path element
    raise TypeError(f"improper type: '{type(item) = }', expected dict or list")

80 

81 

@dataclass
class LoaderHandler:
    """handler for loading an object from a json file or a ZANJ archive

    # Attributes:
     - `check` : `(json_item, path, zanj) -> bool`, whether to use this handler for the item
     - `load` : `(json_item, path, zanj) -> loaded_obj`, function which loads the object
     - `uid` : unique identifier for the handler, saved in `__muutils_format__` field
     - `source_pckg` : source package of the handler -- note that this might be overridden by ZANJ
     - `priority` : priority of the handler, defaults are all 0
     - `desc` : description of the handler
    """

    # TODO: add a separate "asserts" function?
    # right now, any asserts must happen in `check` or `load` which is annoying with lambdas

    # (json_data, path) -> whether to use this handler
    check: Callable[[JSONitem, ObjectPath, _ZANJ_pre], bool]
    # function to load the object (json_data, path) -> loaded_obj
    load: Callable[[JSONitem, ObjectPath, _ZANJ_pre], Any]
    # unique identifier for the handler, saved in __muutils_format__ field
    uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # priority of the handler, defaults are all 0
    priority: int = 0
    # description of the handler
    desc: str = "(no description)"

    def serialize(self) -> JSONdict:
        """serialize the handler info

        the result holds the source code and docstring of both `check` and `load`
        (as returned by `safe_getsource` and `string_as_lines`), along with
        `uid`, `source_pckg`, `priority`, and `desc`
        """
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the load function
            "load": {
                "code": safe_getsource(self.load),
                "doc": string_as_lines(self.load.__doc__),
            },
            # get the uid, source_pckg, priority, and desc
            "uid": str(self.uid),
            "source_pckg": str(self.source_pckg),
            "priority": int(self.priority),
            "desc": str(self.desc),
        }

    @classmethod
    def from_formattedclass(cls, fc: type, priority: int = 0) -> "LoaderHandler":
        """create a loader from a class with `serialize`, `load` methods and `__muutils_format__` attribute

        # Parameters:
         - `fc : type` class to build the handler for
         - `priority : int` priority stored on the handler (defaults to `0`)

        # Returns:
         - `LoaderHandler` whose `check` accepts mappings whose format key equals
           `fc.__muutils_format__`, and whose `load` calls `fc.load(json_item, path, z)`

        # Raises:
         - `AssertionError` : `fc` is missing `serialize`, `load`, or a string `__muutils_format__`
        """
        assert hasattr(fc, "serialize")
        assert callable(fc.serialize)  # type: ignore
        assert hasattr(fc, "load")
        assert callable(fc.load)  # type: ignore
        assert hasattr(fc, _FORMAT_KEY)
        assert isinstance(fc.__muutils_format__, str)  # type: ignore

        return cls(
            # `check` is called on arbitrary json items (see the fallback loop in `get_item_loader`),
            # so it must return `False` rather than raise for non-mappings
            # and for mappings without a format key, like the built-in handlers do
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY] == fc.__muutils_format__  # type: ignore[attr-defined]
            ),
            load=lambda json_item, path=None, z=None: fc.load(json_item, path, z),  # type: ignore[misc]
            uid=fc.__muutils_format__,  # type: ignore[attr-defined]
            source_pckg=str(fc.__module__),
            priority=priority,
            desc=f"formatted class loader for {fc.__name__}",
        )

142 

143 

144# TODO: how can we type hint this without actually importing torch? 

def _torch_loaderhandler_load(
    json_item: JSONitem,
    path: ObjectPath,
    z: _ZANJ_pre | None = None,
):
    """load a torch tensor from a json item

    the array data is read with `load_array(json_item)` and turned into a tensor
    whose dtype is looked up in `TORCH_DTYPE_MAP` using `json_item["dtype"]`.
    torch is imported lazily, inside the function.

    # Parameters:
     - `json_item` : serialized tensor, must have a "dtype" entry
     - `path` : path of the item, only used in the error message
     - `z` : not used by this loader

    # Raises:
     - `ImportError` : the imports of `torch` / `muutils.tensor_utils` failed
    """
    try:
        import torch  # type: ignore[import-not-found]
        from muutils.tensor_utils import TORCH_DTYPE_MAP
    except ImportError as import_err:
        raise ImportError(
            f"could not import torch, which we need to load the object at {path = }: {json_item = }"
        ) from import_err

    # read the array first, then resolve the dtype
    array_data = load_array(json_item)
    tensor_dtype = TORCH_DTYPE_MAP[json_item["dtype"]]  # type: ignore[index, call-overload]
    return torch.tensor(array_data, dtype=tensor_dtype)

162 

163 

164# NOTE: there are type ignores on the loaders, since the type checking should be the responsibility of the check function 

165 

# lock held by `register_loader_handler` while it adds an entry to `LOADER_MAP`
LOADER_MAP_LOCK = threading.Lock()

# built-in loader handlers, keyed by each handler's `uid`.
# `get_item_loader` first looks up an item's format string directly in this map,
# and otherwise calls each handler's `check` in iteration order.
# NOTE(review): `LoaderHandler.serialize` records the source of these lambdas
# via `safe_getsource` -- presumably rewriting them changes the serialized handler info
LOADER_MAP: dict[str, LoaderHandler] = {
    lh.uid: lh
    for lh in [
        # array external
        # matches mappings whose format starts with "numpy.ndarray",
        # loads via `load_array` and converts to the dtype named in `json_item["dtype"]`
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("numpy.ndarray")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: np.array(  # type: ignore[misc]
                load_array(json_item), dtype=np.dtype(json_item["dtype"])
            ),
            uid="numpy.ndarray",
            source_pckg="zanj",
            desc="numpy.ndarray loader",
        ),
        # matches mappings whose format starts with "torch.Tensor",
        # loading is delegated to `_torch_loaderhandler_load`
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("torch.Tensor")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=_torch_loaderhandler_load,
            uid="torch.Tensor",
            source_pckg="zanj",
            desc="torch.Tensor loader",
        ),
        # pandas
        # requires a sequence under "data", which is passed to `pandas_DataFrame`
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("pandas.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: pandas_DataFrame(  # type: ignore[misc]
                json_item["data"]
            ),
            uid="pandas.DataFrame",
            source_pckg="zanj",
            desc="pandas.DataFrame loader",
        ),
        # list/tuple external
        # each element of "data" is loaded with `load_item_recursive`,
        # note that every element receives the same `path` as the container
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("list")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: [  # type: ignore[misc]
                load_item_recursive(x, path, z) for x in json_item["data"]
            ],
            uid="list",
            source_pckg="zanj",
            desc="list loader, for externals",
        ),
        # same as the list loader, but the loaded elements are returned as a tuple
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("tuple")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: tuple(  # type: ignore[misc]
                [load_item_recursive(x, path, z) for x in json_item["data"]]
            ),
            uid="tuple",
            source_pckg="zanj",
            desc="tuple loader, for externals",
        ),
    ]
}

249 

250 

def register_loader_handler(handler: LoaderHandler):
    """register a custom loader handler

    the handler is stored in the module-level `LOADER_MAP` under `handler.uid`,
    replacing any handler already stored under that uid.
    `LOADER_MAP_LOCK` is held while the map is modified.
    """
    # the map is mutated in place and never rebound, so no `global` statement is required
    with LOADER_MAP_LOCK:
        LOADER_MAP.update({handler.uid: handler})

256 

257 

def get_item_loader(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    # lh_map: dict[str, LoaderHandler] = LOADER_MAP,
) -> LoaderHandler | None:
    """get the loader for a json item

    lookup happens in two stages:
     1. if `json_item` is a mapping with a format key whose value is a key of
        `LOADER_MAP`, that handler is returned without calling its `check`
     2. otherwise, the first handler in `LOADER_MAP` whose
        `check(json_item, path, zanj)` is truthy is returned

    # Parameters:
     - `json_item` : item to find a loader for
     - `path` : path of the item, passed to `check` and used in the error message
     - `zanj` : passed through to `check`
     - `error_mode` : accepted but not used in this function

    # Returns:
     - `LoaderHandler | None` the matching handler, or `None` if there is none

    # Raises:
     - `TypeError` : the format key is present but its value is not a `str`
    """
    # check if we recognize the format
    has_format: bool = isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item
    if has_format:
        if not isinstance(json_item[_FORMAT_KEY], str):
            raise TypeError(
                f"invalid __muutils_format__ type '{type(json_item[_FORMAT_KEY])}' in '{path=}': '{json_item[_FORMAT_KEY] = }'"
            )
        if json_item[_FORMAT_KEY] in LOADER_MAP:
            return LOADER_MAP[json_item[_FORMAT_KEY]]  # type: ignore[index]

    # if we dont recognize the format, try to find a loader that can handle it.
    # if we still dont have a loader, the result is None
    return next(
        (
            candidate
            for candidate in LOADER_MAP.values()
            if candidate.check(json_item, path, zanj)
        ),
        None,
    )

284 

285 

def load_item_recursive(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    allow_not_loading: bool = True,
) -> Any:
    """recursively load `json_item`, using the handler from `get_item_loader` if there is one

    - if a handler is found, its `load` is called on the item. for items whose
      format contains "SerializableDataclass", every value which is not itself
      such an item is recursively loaded first
    - if no handler is found, dicts and lists are loaded element by element
      (extending `path` with the key or index), and `str`, `int`, `float`,
      `bool`, `None` are returned as-is

    # Parameters:
     - `json_item` : item to load
     - `path` : path of the item from the root
     - `zanj` : passed through to the handlers
     - `error_mode` : passed through to `get_item_loader` and recursive calls
     - `allow_not_loading` : whether an item of unknown type is returned unchanged
       instead of raising. applies at every level of the recursion

    # Raises:
     - `ValueError` : an item of unknown type was found and `allow_not_loading` is `False`
    """
    lh: LoaderHandler | None = get_item_loader(
        json_item=json_item,
        path=path,
        zanj=zanj,
        error_mode=error_mode,
        # lh_map=lh_map,
    )

    if lh is not None:
        # special case for serializable dataclasses
        if (
            isinstance(json_item, typing.Mapping)
            and (_FORMAT_KEY in json_item)
            and ("SerializableDataclass" in json_item[_FORMAT_KEY])  # type: ignore[operator]
        ):
            # why this horribleness?
            # SerializableDataclass, if it has a field `x` which is also a SerializableDataclass, will automatically call `x.__class__.load()`
            # However, we need to load things in containers, as well as arrays
            processed_json_item: dict = {
                key: (
                    val
                    if (
                        isinstance(val, typing.Mapping)
                        and (_FORMAT_KEY in val)
                        and ("SerializableDataclass" in val[_FORMAT_KEY])
                    )
                    else load_item_recursive(
                        json_item=val,
                        path=tuple(path) + (key,),
                        zanj=zanj,
                        error_mode=error_mode,
                        # propagate strictness, otherwise it only applies to the top-level item
                        allow_not_loading=allow_not_loading,
                    )
                )
                for key, val in json_item.items()
            }

            return lh.load(processed_json_item, path, zanj)

        else:
            return lh.load(json_item, path, zanj)
    else:
        if isinstance(json_item, dict):
            return {
                key: load_item_recursive(
                    json_item=json_item[key],
                    path=tuple(path) + (key,),
                    zanj=zanj,
                    error_mode=error_mode,
                    allow_not_loading=allow_not_loading,
                    # lh_map=lh_map,
                )
                for key in json_item
            }
        elif isinstance(json_item, list):
            return [
                load_item_recursive(
                    json_item=x,
                    path=tuple(path) + (i,),
                    zanj=zanj,
                    error_mode=error_mode,
                    allow_not_loading=allow_not_loading,
                    # lh_map=lh_map,
                )
                for i, x in enumerate(json_item)
            ]
        elif isinstance(json_item, (str, int, float, bool, type(None))):
            return json_item
        else:
            if allow_not_loading:
                return json_item
            else:
                raise ValueError(
                    f"unknown type {type(json_item)} at {path}\n{json_item}"
                )

365 

366 

def _each_item_in_externals(
    externals: dict[str, ExternalItem],
    json_data: JSONitem,
) -> typing.Iterable[tuple[str, ExternalItem, Any, ObjectPath]]:
    """yield `(ext_path, ext_item, item, path)` for each external, shortest `path` first

    `item` is the element of `json_data` found by following `ext_item.path`.
    whenever an element along the way is a not-yet-populated external
    (see `_populate_externals_error_checking`), its "data" entry is entered first.

    note that you MUST use the raw iterator, dont try to turn into a list or something

    # Raises:
     - `AssertionError` : an external has an empty path, or a path element is not a `str` or `int`
     - `KeyError` : the path could not be followed through `json_data`
    """

    sorted_externals: list[tuple[str, ExternalItem]] = sorted(
        externals.items(), key=lambda x: len(x[1].path)
    )

    for ext_path, ext_item in sorted_externals:
        # get the path to the item
        path: ObjectPath = tuple(ext_item.path)
        assert len(path) > 0
        assert all(isinstance(key, (str, int)) for key in path), (
            f"improper types in path {path=}"
        )
        # get the item
        item = json_data
        for i, key in enumerate(path):
            try:
                # ignores in this block are because we cannot know the type is indexable in static analysis
                # but, we check the types in the line below
                external_unloaded: bool = _populate_externals_error_checking(key, item)
                if external_unloaded:
                    item = item["data"]  # type: ignore
                item = item[key]  # type: ignore[index]

            except (KeyError, IndexError, TypeError) as e:
                # `item` may have no length (this is exactly the "expected dict or list" case),
                # calling `len` on it here would raise a `TypeError` which hides the real error
                item_len: int | None = (
                    len(item) if isinstance(item, typing.Sized) else None
                )
                raise KeyError(
                    f"could not find '{key = }' at path '{ext_path = }', specifically at index '{i = }'",
                    f"'{type(item) =}', 'len(item) = {item_len!r}', '{item.keys() if isinstance(item, dict) else None = }'",  # type: ignore
                    f"From error: {e = }",
                    f"\n\n{item=}\n\n{ext_item=}",
                ) from e

        yield (ext_path, ext_item, item, path)

404 

405 

class LoadedZANJ:
    """for loading a zanj file

    on construction, the metadata (`ZANJ_META`), the main json data (`ZANJ_MAIN`),
    and every external listed under "externals_info" in the metadata are read
    from the archive at `path`. the archive is closed again before `__init__` returns.
    call `populate_externals` to place the loaded externals into the json data.
    """

    def __init__(
        self,
        path: str | Path,
        zanj: _ZANJ_pre,
    ) -> None:
        """read the archive at `path`

        # Parameters:
         - `path : str | Path` location of the zanj archive
         - `zanj : _ZANJ_pre` stored as `self._zanj`
        """
        # path and zanj object
        self._path: str = str(path)
        self._zanj: _ZANJ_pre = zanj

        # load zip file
        # the context managers close the archive and every member handle
        # even if reading or parsing fails part of the way through
        with zipfile.ZipFile(file=self._path, mode="r") as zipf:
            # load data
            with zipf.open(ZANJ_META, "r") as fp_meta:
                self._meta: JSONdict = json.load(fp_meta)
            with zipf.open(ZANJ_MAIN, "r") as fp_main:
                self._json_data: JSONitem = json.load(fp_main)

            # read externals
            self._externals: dict[str, ExternalItem] = dict()
            for fname, ext_item in self._meta["externals_info"].items():  # type: ignore
                item_type: str = ext_item["item_type"]  # type: ignore
                with zipf.open(fname, "r") as fp:
                    self._externals[fname] = ExternalItem(
                        item_type=item_type,  # type: ignore[arg-type]
                        # the load function for this item type receives this object and the open member
                        data=GET_EXTERNAL_LOAD_FUNC(item_type)(self, fp),
                        path=ext_item["path"],  # type: ignore
                    )

    def populate_externals(self) -> None:
        """put all external items into the main json data

        each external's data is stored under the "data" key of the item its path points to.

        # Raises:
         - `AssertionError` : the target item has no reference key, or its reference is not the external's path
        """

        # loop over once, populating the externals only
        for ext_path, ext_item, item, path in _each_item_in_externals(
            self._externals, self._json_data
        ):
            # replace the item with the external item
            assert _REF_KEY in item  # type: ignore
            assert item[_REF_KEY] == ext_path  # type: ignore
            item["data"] = ext_item.data  # type: ignore