Coverage for zanj / loading.py: 91%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 19:37 -0700

1from __future__ import annotations 

2 

3import json 

4import threading 

5import typing 

6import zipfile 

7from dataclasses import dataclass 

8from pathlib import Path 

9from typing import Any, Callable 

10 

11import numpy as np 

12 

# pandas is optional: expose the real `DataFrame` when the import works,
# otherwise a stand-in class that raises as soon as someone tries to build one
try:
    import pandas as pd  # type: ignore[import]
except ImportError:

    class pandas_DataFrame:  # type: ignore[no-redef]
        """placeholder used when pandas is missing; instantiating it raises `ImportError`"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load pandas DataFrame, pandas is not installed")

else:
    pandas_DataFrame = pd.DataFrame  # type: ignore[no-redef]

22 

23 

# polars is optional: expose the real `DataFrame` when the import works,
# otherwise a stand-in class that raises as soon as someone tries to build one
try:
    import polars as pl  # type: ignore[import]
except ImportError:

    class polars_DataFrame:  # type: ignore[no-redef]
        """placeholder used when polars is missing; instantiating it raises `ImportError`"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load polars DataFrame, polars is not installed")

else:
    polars_DataFrame = pl.DataFrame  # type: ignore[no-redef]

33 

34 

35from muutils.errormode import ErrorMode 

36from muutils.json_serialize.array import load_array 

37from muutils.json_serialize.json_serialize import ObjectPath 

38 

39from zanj.consts import ( 

40 JSONdict, 

41 JSONitem, 

42 _FORMAT_KEY, 

43 _REF_KEY, 

44 safe_getsource, 

45 string_as_lines, 

46) 

47 

48from zanj.externals import ( 

49 GET_EXTERNAL_LOAD_FUNC, 

50 ZANJ_MAIN, 

51 ZANJ_META, 

52 ExternalItem, 

53 _ZANJ_pre, 

54) 

55 

56# pylint: disable=protected-access, dangerous-default-value 

57 

58 

59def _populate_externals_error_checking(key, item) -> bool: 

60 """checks that the key is valid for the item. returns "True" we need to augment the path by accessing the "data" element""" 

61 

62 # special case for not fully loaded external item which we still need to populate 

63 if isinstance(item, typing.Mapping): 

64 if (_FORMAT_KEY in item) and item[_FORMAT_KEY].endswith(":external"): 

65 if "data" in item: 

66 return True 

67 else: 

68 raise KeyError( 

69 f"expected an external item, but could not find data: {list(item.keys())}", 

70 f"{item[_FORMAT_KEY]}, {len(item) = }, {item.get('data', '<EMPTY>') = }", 

71 ) 

72 

73 # if it's a list, make sure the key is an int and that it's in range 

74 if isinstance(item, typing.Sequence): 

75 if not isinstance(key, int): 

76 raise TypeError(f"improper type: '{type(key) = }', expected int") 

77 if key >= len(item): 

78 raise IndexError(f"index out of range: '{key = }', expected < {len(item)}") 

79 

80 # if it's a dict, make sure that the key is a str and that it's in the dict 

81 elif isinstance(item, typing.Mapping): 

82 if not isinstance(key, str): 

83 raise TypeError(f"improper type: '{type(key) = }', expected str") 

84 if key not in item: 

85 raise KeyError(f"key not in dict: '{key = }', expected in {item.keys()}") 

86 

87 # otherwise, raise an error 

88 else: 

89 raise TypeError(f"improper type: '{type(item) = }', expected dict or list") 

90 

91 return False 

92 

93 

@dataclass
class LoaderHandler:
    """handler for loading an object from a json file or a ZANJ archive

    a handler pairs a `check` predicate (should this handler be used for the item?)
    with a `load` function (turn the item into an object). both are called as
    `(json_item, path, zanj)`.
    """

    # TODO: add a separate "asserts" function?
    # right now, any asserts must happen in `check` or `load` which is annoying with lambdas

    # (json_data, path, zanj) -> whether to use this handler
    check: Callable[[JSONitem, ObjectPath, _ZANJ_pre], bool]
    # function to load the object (json_data, path, zanj) -> loaded_obj
    load: Callable[[JSONitem, ObjectPath, _ZANJ_pre], Any]
    # unique identifier for the handler, saved in __muutils_format__ field
    uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # priority of the handler, defaults are all 0
    priority: int = 0
    # description of the handler
    desc: str = "(no description)"

    def serialize(self) -> JSONdict:
        """serialize the handler info: source and docstring of `check`/`load`, plus the plain fields"""
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the load function
            "load": {
                "code": safe_getsource(self.load),
                "doc": string_as_lines(self.load.__doc__),
            },
            # get the uid, source_pckg, priority, and desc
            "uid": str(self.uid),
            "source_pckg": str(self.source_pckg),
            "priority": int(self.priority),
            "desc": str(self.desc),
        }

    @classmethod
    def from_formattedclass(cls, fc: type, priority: int = 0):
        """create a loader from a class with `serialize`, `load` methods and `__muutils_format__` attribute

        the resulting `check` returns `True` only for mappings whose `__muutils_format__`
        entry equals `fc.__muutils_format__`; for anything else (no format key, lists,
        primitives) it returns `False` instead of raising.
        """
        assert hasattr(fc, "serialize")
        assert callable(fc.serialize)  # type: ignore
        assert hasattr(fc, "load")
        assert callable(fc.load)  # type: ignore
        assert hasattr(fc, _FORMAT_KEY)
        assert isinstance(fc.__muutils_format__, str)  # type: ignore

        return cls(
            # guard before indexing: `check` may be called on items that are not
            # mappings or that carry no format key at all
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY] == fc.__muutils_format__  # type: ignore[attr-defined]
            ),
            load=lambda json_item, path=None, z=None: fc.load(json_item, path, z),  # type: ignore[misc]
            uid=fc.__muutils_format__,  # type: ignore[attr-defined]
            source_pckg=str(fc.__module__),
            priority=priority,
            desc=f"formatted class loader for {fc.__name__}",
        )

154 

155 

156# TODO: how can we type hint this without actually importing torch? 

def _torch_loaderhandler_load(
    json_item: JSONitem,
    path: ObjectPath,
    z: _ZANJ_pre | None = None,
):
    """load a torch tensor from a json item

    torch is imported lazily here so that it is only required when a tensor is
    actually being loaded; a failed import is re-raised as an `ImportError` which
    names the `path` and the item that needed it.
    """
    try:
        import torch  # type: ignore[import-not-found]
        from muutils.tensor_utils import TORCH_DTYPE_MAP
    except ImportError as import_err:
        raise ImportError(
            f"could not import torch, which we need to load the object at {path = }: {json_item = }"
        ) from import_err

    # json_item is JSONitem but load_array expects narrower types; runtime check is in LoaderHandler.check
    array_data = load_array(json_item)  # type: ignore[no-matching-overload, call-overload]
    tensor_dtype = TORCH_DTYPE_MAP[json_item["dtype"]]  # type: ignore[index, call-overload]
    return torch.tensor(array_data, dtype=tensor_dtype)

175 

176 

177# NOTE: there are type ignores on the loaders, since the type checking should be the responsibility of the check function 

178 

# held while a handler is written into `LOADER_MAP`
LOADER_MAP_LOCK = threading.Lock()

# handlers shipped with this module; each `check` tests the format string prefix
# (and, where the loader reads it, that a sequence-valued "data" entry exists)
_DEFAULT_LOADER_HANDLERS: tuple[LoaderHandler, ...] = (
    # array external
    LoaderHandler(
        uid="numpy.ndarray",
        source_pckg="zanj",
        desc="numpy.ndarray loader",
        # NOTE: dtype/shape of the data are not verified against the stored metadata here
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("numpy.ndarray")
        ),
        load=lambda json_item, path=None, z=None: np.array(  # type: ignore[misc]
            load_array(json_item), dtype=np.dtype(json_item["dtype"])
        ),
    ),
    LoaderHandler(
        uid="torch.Tensor",
        source_pckg="zanj",
        desc="torch.Tensor loader",
        # NOTE: dtype/shape of the data are not verified against the stored metadata here
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("torch.Tensor")
        ),
        load=_torch_loaderhandler_load,
    ),
    # pandas
    LoaderHandler(
        uid="pandas.DataFrame",
        source_pckg="zanj",
        desc="pandas.DataFrame loader",
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("pandas.DataFrame")
            and "data" in json_item
            and isinstance(json_item["data"], typing.Sequence)
        ),
        load=lambda json_item, path=None, z=None: (  # type: ignore[misc]
            pandas_DataFrame(json_item["data"])
            # if there is no data, load just the columns (this is for empty dataframes)
            if json_item["data"]
            else pandas_DataFrame(columns=json_item.get("columns"))
        ),
    ),
    # polars
    LoaderHandler(
        uid="polars.DataFrame",
        source_pckg="zanj",
        desc="polars.DataFrame loader",
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("polars.DataFrame")
            and "data" in json_item
            and isinstance(json_item["data"], typing.Sequence)
        ),
        load=lambda json_item, path=None, z=None: (  # type: ignore[misc]
            polars_DataFrame(json_item["data"])
            # empty data: build an empty frame with one `str`-typed column per stored name
            if json_item["data"]
            else polars_DataFrame(
                schema={col: str for col in json_item.get("columns", [])}
            )
        ),
    ),
    # list/tuple external
    LoaderHandler(
        uid="list",
        source_pckg="zanj",
        desc="list loader, for externals",
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("list")
            and "data" in json_item
            and isinstance(json_item["data"], typing.Sequence)
        ),
        load=lambda json_item, path=None, z=None: [  # type: ignore[misc, arg-type]
            load_item_recursive(elem, path, z)  # type: ignore[arg-type]
            for elem in json_item["data"]
        ],
    ),
    LoaderHandler(
        uid="tuple",
        source_pckg="zanj",
        desc="tuple loader, for externals",
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("tuple")
            and "data" in json_item
            and isinstance(json_item["data"], typing.Sequence)
        ),
        load=lambda json_item, path=None, z=None: tuple(  # type: ignore[misc, arg-type]
            [load_item_recursive(elem, path, z) for elem in json_item["data"]]  # type: ignore[arg-type]
        ),
    ),
)

# registry of handlers keyed by their `uid`
LOADER_MAP: dict[str, LoaderHandler] = {
    handler.uid: handler for handler in _DEFAULT_LOADER_HANDLERS
}

286 

287 

def register_loader_handler(handler: LoaderHandler):
    """register a custom loader handler

    the handler is stored in `LOADER_MAP` under `handler.uid`, replacing whatever
    was stored under that uid before. the write happens while holding `LOADER_MAP_LOCK`.
    """
    # item assignment mutates the module-level dict in place, so no `global` statement is required
    uid = handler.uid
    with LOADER_MAP_LOCK:
        LOADER_MAP[uid] = handler

293 

294 

def get_item_loader(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
) -> LoaderHandler | None:
    """get the loader for a json item

    lookup order:
     1. if `json_item` is a mapping with a `__muutils_format__` entry that exactly
        matches a uid in `LOADER_MAP`, that handler is returned without calling `check`
     2. otherwise every registered handler's `check(json_item, path, zanj)` is tried in
        registration order, and the first one that accepts the item is returned
     3. `None` if nothing matches

    `error_mode` is accepted for interface compatibility but is not used in this function.

    raises `TypeError` if the `__muutils_format__` entry is present but is not a `str`.
    """
    # check if we recognize the format
    if isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item:
        if not isinstance(json_item[_FORMAT_KEY], str):  # type: ignore[index]
            raise TypeError(
                f"invalid __muutils_format__ type '{type(json_item[_FORMAT_KEY])}' in '{path=}': '{json_item[_FORMAT_KEY] = }'"  # type: ignore[index]
            )
        exact_match: LoaderHandler | None = LOADER_MAP.get(json_item[_FORMAT_KEY])  # type: ignore[index]
        if exact_match is not None:
            return exact_match

    # if we dont recognize the format, try to find a loader that can handle it.
    # iterate over a snapshot taken under the lock: handlers may be registered while we
    # are looping (from another thread or from inside a `check`), and iterating the live
    # dict would then raise "dictionary changed size during iteration"
    with LOADER_MAP_LOCK:
        candidates: tuple[LoaderHandler, ...] = tuple(LOADER_MAP.values())

    for lh in candidates:
        if lh.check(json_item, path, zanj):
            return lh

    # if we still dont have a loader, return None
    return None

321 

322 

def load_item_recursive(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    allow_not_loading: bool = True,
) -> Any:
    """recursively turn `json_item` into loaded objects

    - if `get_item_loader` finds a handler, the handler's `load` produces the result
      (for items whose format string contains `"SerializableDataclass"`, the values
      are loaded recursively first, see the comment in the body)
    - otherwise dicts and lists are walked element by element, extending `path` with
      the key or index
    - `str`, `int`, `float`, `bool` and `None` are returned unchanged
    - anything else is returned unchanged when `allow_not_loading` is true, and raises
      `ValueError` when it is false

    NOTE(review): `allow_not_loading` is not forwarded to the recursive calls, so nested
    items are always loaded with the default (`True`) -- confirm whether that is intended.
    """
    handler: LoaderHandler | None = get_item_loader(
        json_item=json_item,
        path=path,
        zanj=zanj,
        error_mode=error_mode,
    )

    # --- no handler: walk containers, pass primitives through ---
    if handler is None:
        if isinstance(json_item, dict):
            # ty doesn't narrow JSONitem to dict after isinstance check; string key indexing is safe here
            return {
                name: load_item_recursive(
                    json_item=json_item[name],  # type: ignore[invalid-argument-type, call-overload]
                    path=tuple(path) + (name,),
                    zanj=zanj,
                    error_mode=error_mode,
                )
                for name in json_item
            }

        if isinstance(json_item, list):
            return [
                load_item_recursive(
                    json_item=element,
                    path=tuple(path) + (index,),
                    zanj=zanj,
                    error_mode=error_mode,
                )
                for index, element in enumerate(json_item)
            ]

        if isinstance(json_item, (str, int, float, bool, type(None))):
            return json_item

        if allow_not_loading:
            return json_item

        raise ValueError(f"unknown type {type(json_item)} at {path}\n{json_item}")

    # --- handler found ---
    is_serializable_dataclass: bool = (
        isinstance(json_item, typing.Mapping)
        and (_FORMAT_KEY in json_item)
        and ("SerializableDataclass" in json_item[_FORMAT_KEY])  # type: ignore[operator]
    )
    if not is_serializable_dataclass:
        return handler.load(json_item, path, zanj)

    # special case for serializable dataclasses. why this horribleness?
    # SerializableDataclass, if it has a field `x` which is also a SerializableDataclass,
    # will automatically call `x.__class__.load()`.
    # However, we need to load things in containers, as well as arrays.
    # So: nested SerializableDataclass values are handed over untouched, everything
    # else is loaded recursively before calling the handler.
    processed_json_item: dict = {}
    for field_name, field_val in json_item.items():
        nested_dataclass: bool = (
            isinstance(field_val, typing.Mapping)
            and (_FORMAT_KEY in field_val)
            and ("SerializableDataclass" in field_val[_FORMAT_KEY])  # type: ignore[operator, index]
        )
        if nested_dataclass:
            processed_json_item[field_name] = field_val
        else:
            processed_json_item[field_name] = load_item_recursive(
                json_item=field_val,  # type: ignore[arg-type]
                path=tuple(path) + (field_name,),  # type: ignore[arg-type]
                zanj=zanj,
                error_mode=error_mode,
            )

    return handler.load(processed_json_item, path, zanj)

403 

404 

405def _each_item_in_externals( 

406 externals: dict[str, ExternalItem], 

407 json_data: JSONitem, 

408) -> typing.Iterable[tuple[str, ExternalItem, Any, ObjectPath]]: 

409 """note that you MUST use the raw iterator, dont try to turn into a list or something""" 

410 

411 sorted_externals: list[tuple[str, ExternalItem]] = sorted( 

412 externals.items(), key=lambda x: len(x[1].path) 

413 ) 

414 

415 for ext_path, ext_item in sorted_externals: 

416 # get the path to the item 

417 path: ObjectPath = tuple(ext_item.path) 

418 assert len(path) > 0 

419 assert all(isinstance(key, (str, int)) for key in path), ( 

420 f"improper types in path {path=}" 

421 ) 

422 # get the item 

423 item = json_data 

424 for i, key in enumerate(path): 

425 try: 

426 # ignores in this block are because we cannot know the type is indexable in static analysis 

427 # but, we check the types in the line below 

428 external_unloaded: bool = _populate_externals_error_checking(key, item) 

429 if external_unloaded: 

430 item = item["data"] # type: ignore 

431 item = item[key] # type: ignore[index] 

432 

433 except (KeyError, IndexError, TypeError) as e: 

434 raise KeyError( 

435 f"could not find '{key = }' at path '{ext_path = }', specifically at index '{i = }'", 

436 f"'{type(item) =}', '{len(item) = }', '{item.keys() if isinstance(item, dict) else None = }'", # type: ignore 

437 f"From error: {e = }", 

438 f"\n\n{item=}\n\n{ext_item=}", 

439 ) from e 

440 

441 yield (ext_path, ext_item, item, path) 

442 

443 

class LoadedZANJ:
    """for loading a zanj file"""

    def __init__(
        self,
        path: str | Path,
        zanj: _ZANJ_pre,
    ) -> None:
        """open the archive at `path` and read its contents into memory

        reads the `ZANJ_META` and `ZANJ_MAIN` members as json, then reads every member
        listed under `"externals_info"` in the metadata, using the load function that
        `GET_EXTERNAL_LOAD_FUNC` returns for its `item_type`. the archive and every
        member handle are closed before returning, including when reading fails.
        """
        # path and zanj object
        self._path: str = str(path)
        self._zanj: _ZANJ_pre = zanj

        # load zip file; the context managers guarantee the archive and each member
        # are closed even if parsing or an external loader raises
        with zipfile.ZipFile(file=self._path, mode="r") as zipf:
            # load data
            with zipf.open(ZANJ_META, "r") as fp_meta:
                self._meta: JSONdict = json.load(fp_meta)
            with zipf.open(ZANJ_MAIN, "r") as fp_main:
                self._json_data: JSONitem = json.load(fp_main)

            # read externals
            self._externals: dict[str, ExternalItem] = dict()
            for fname, ext_item in self._meta["externals_info"].items():  # type: ignore
                item_type: str = ext_item["item_type"]  # type: ignore
                with zipf.open(fname, "r") as fp:
                    self._externals[fname] = ExternalItem(
                        item_type=item_type,  # type: ignore[arg-type]
                        data=GET_EXTERNAL_LOAD_FUNC(item_type)(self, fp),
                        path=ext_item["path"],  # type: ignore
                    )

    def populate_externals(self) -> None:
        """put all external items into the main json data

        for each external, the object at its path in the main json data must contain
        the reference key and point back at that external's file name (both checked
        with `assert`); its `"data"` entry is then set to the loaded external data
        """

        # loop over once, populating the externals only
        for ext_path, ext_item, item, _path in _each_item_in_externals(
            self._externals, self._json_data
        ):
            # replace the item with the external item
            assert _REF_KEY in item  # type: ignore
            assert item[_REF_KEY] == ext_path  # type: ignore
            item["data"] = ext_item.data  # type: ignore

488 item["data"] = ext_item.data # type: ignore