Coverage for zanj / loading.py: 91%
139 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 19:37 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 19:37 -0700
1from __future__ import annotations
3import json
4import threading
5import typing
6import zipfile
7from dataclasses import dataclass
8from pathlib import Path
9from typing import Any, Callable
11import numpy as np
try:
    import pandas as pd  # type: ignore[import]
except ImportError:

    class pandas_DataFrame:  # type: ignore[no-redef]
        """stand-in used when pandas is missing; constructing it raises `ImportError`"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load pandas DataFrame, pandas is not installed")

else:
    # pandas is available: loaders build real DataFrames
    pandas_DataFrame = pd.DataFrame  # type: ignore[no-redef]


try:
    import polars as pl  # type: ignore[import]
except ImportError:

    class polars_DataFrame:  # type: ignore[no-redef]
        """stand-in used when polars is missing; constructing it raises `ImportError`"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load polars DataFrame, polars is not installed")

else:
    # polars is available: loaders build real DataFrames
    polars_DataFrame = pl.DataFrame  # type: ignore[no-redef]
35from muutils.errormode import ErrorMode
36from muutils.json_serialize.array import load_array
37from muutils.json_serialize.json_serialize import ObjectPath
39from zanj.consts import (
40 JSONdict,
41 JSONitem,
42 _FORMAT_KEY,
43 _REF_KEY,
44 safe_getsource,
45 string_as_lines,
46)
48from zanj.externals import (
49 GET_EXTERNAL_LOAD_FUNC,
50 ZANJ_MAIN,
51 ZANJ_META,
52 ExternalItem,
53 _ZANJ_pre,
54)
56# pylint: disable=protected-access, dangerous-default-value
59def _populate_externals_error_checking(key, item) -> bool:
60 """checks that the key is valid for the item. returns "True" we need to augment the path by accessing the "data" element"""
62 # special case for not fully loaded external item which we still need to populate
63 if isinstance(item, typing.Mapping):
64 if (_FORMAT_KEY in item) and item[_FORMAT_KEY].endswith(":external"):
65 if "data" in item:
66 return True
67 else:
68 raise KeyError(
69 f"expected an external item, but could not find data: {list(item.keys())}",
70 f"{item[_FORMAT_KEY]}, {len(item) = }, {item.get('data', '<EMPTY>') = }",
71 )
73 # if it's a list, make sure the key is an int and that it's in range
74 if isinstance(item, typing.Sequence):
75 if not isinstance(key, int):
76 raise TypeError(f"improper type: '{type(key) = }', expected int")
77 if key >= len(item):
78 raise IndexError(f"index out of range: '{key = }', expected < {len(item)}")
80 # if it's a dict, make sure that the key is a str and that it's in the dict
81 elif isinstance(item, typing.Mapping):
82 if not isinstance(key, str):
83 raise TypeError(f"improper type: '{type(key) = }', expected str")
84 if key not in item:
85 raise KeyError(f"key not in dict: '{key = }', expected in {item.keys()}")
87 # otherwise, raise an error
88 else:
89 raise TypeError(f"improper type: '{type(item) = }', expected dict or list")
91 return False
@dataclass
class LoaderHandler:
    """handler for loading an object from a json file or a ZANJ archive

    a handler pairs a `check` predicate, which decides whether the handler applies to
    a json item, with a `load` function, which reconstructs the object from it.
    """

    # TODO: add a separate "asserts" function?
    # right now, any asserts must happen in `check` or `load` which is annoying with lambdas

    # (json_data, path, zanj) -> whether to use this handler
    check: Callable[[JSONitem, ObjectPath, _ZANJ_pre], bool]
    # function to load the object (json_data, path, zanj) -> loaded_obj
    load: Callable[[JSONitem, ObjectPath, _ZANJ_pre], Any]
    # unique identifier for the handler, saved in __muutils_format__ field
    uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # priority of the handler, defaults are all 0
    priority: int = 0
    # description of the handler
    desc: str = "(no description)"

    def serialize(self) -> JSONdict:
        """serialize the handler info

        records the source code and docstring of `check` and `load`, plus the
        `uid`, `source_pckg`, `priority`, and `desc` fields
        """
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the load function
            "load": {
                "code": safe_getsource(self.load),
                "doc": string_as_lines(self.load.__doc__),
            },
            # get the uid, source_pckg, priority, and desc
            "uid": str(self.uid),
            "source_pckg": str(self.source_pckg),
            "priority": int(self.priority),
            "desc": str(self.desc),
        }

    @classmethod
    def from_formattedclass(cls, fc: type, priority: int = 0):
        """create a loader from a class with `serialize`, `load` methods and `__muutils_format__` attribute

        the returned handler's `check` accepts json items that are mappings whose
        format key equals `fc.__muutils_format__`; its `load` delegates to
        `fc.load(json_item, path, z)`.

        # Raises:
         - `AssertionError` : if `fc` lacks a callable `serialize`/`load` or a string
            `__muutils_format__`
        """
        assert hasattr(fc, "serialize")
        assert callable(fc.serialize)  # type: ignore
        assert hasattr(fc, "load")
        assert callable(fc.load)  # type: ignore
        assert hasattr(fc, _FORMAT_KEY)
        assert isinstance(fc.__muutils_format__, str)  # type: ignore

        return cls(
            # `get_item_loader` calls every handler's `check` when no uid matches
            # exactly, so it may receive lists, scalars, or dicts without a format
            # key -- guard before indexing rather than raising TypeError/KeyError
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY] == fc.__muutils_format__  # type: ignore[attr-defined]
            ),
            load=lambda json_item, path=None, z=None: fc.load(json_item, path, z),  # type: ignore[misc]
            uid=fc.__muutils_format__,  # type: ignore[attr-defined]
            source_pckg=str(fc.__module__),
            priority=priority,
            desc=f"formatted class loader for {fc.__name__}",
        )
# TODO: how can we type hint this without actually importing torch?
def _torch_loaderhandler_load(
    json_item: JSONitem,
    path: ObjectPath,
    z: _ZANJ_pre | None = None,
):
    """rebuild a `torch.Tensor` from a serialized json item

    torch (and `muutils.tensor_utils.TORCH_DTYPE_MAP`) are imported lazily, so that
    torch is only required when a tensor is actually loaded. the array contents come
    from `load_array(json_item)`, and the dtype from looking up `json_item["dtype"]`
    in `TORCH_DTYPE_MAP`.

    # Raises:
     - `ImportError` : if torch or `muutils.tensor_utils` cannot be imported
    """
    try:
        import torch  # type: ignore[import-not-found]
        from muutils.tensor_utils import TORCH_DTYPE_MAP
    except ImportError as e:
        raise ImportError(
            f"could not import torch, which we need to load the object at {path = }: {json_item = }"
        ) from e

    # json_item is JSONitem but load_array expects narrower types; runtime check is in LoaderHandler.check
    # (the array is decoded before the dtype lookup, preserving the original evaluation order)
    array_data = load_array(json_item)  # type: ignore[no-matching-overload, call-overload]
    tensor_dtype = TORCH_DTYPE_MAP[json_item["dtype"]]  # type: ignore[index, call-overload]
    return torch.tensor(array_data, dtype=tensor_dtype)
# NOTE: there are type ignores on the loaders, since the type checking should be the responsibility of the check function

# guards mutation of (and snapshotting of) `LOADER_MAP`
LOADER_MAP_LOCK = threading.Lock()


def _load_sequence_items(json_item, path=None, z=None) -> list:
    """recursively load every element of the `"data"` field of a list/tuple item

    element `i` is loaded at path `tuple(path) + (i,)`, matching how
    `load_item_recursive` extends the path when it descends into a plain list.
    if `path` is `None`, elements get an index-only path instead of crashing on
    `tuple(None)` further down.
    """
    base_path: tuple = tuple(path) if path is not None else ()
    return [
        load_item_recursive(x, base_path + (i,), z)  # type: ignore[arg-type]
        for i, x in enumerate(json_item["data"])
    ]


LOADER_MAP: dict[str, LoaderHandler] = {
    lh.uid: lh
    for lh in [
        # array external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("numpy.ndarray")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: np.array(  # type: ignore[misc]
                load_array(json_item), dtype=np.dtype(json_item["dtype"])
            ),
            uid="numpy.ndarray",
            source_pckg="zanj",
            desc="numpy.ndarray loader",
        ),
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("torch.Tensor")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=_torch_loaderhandler_load,
            uid="torch.Tensor",
            source_pckg="zanj",
            desc="torch.Tensor loader",
        ),
        # pandas
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("pandas.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                pandas_DataFrame(json_item["data"])
                # if there is no data, load just the columns (this is for empty dataframes)
                if json_item["data"]
                else pandas_DataFrame(columns=json_item.get("columns"))
            ),
            uid="pandas.DataFrame",
            source_pckg="zanj",
            desc="pandas.DataFrame loader",
        ),
        # polars
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("polars.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                polars_DataFrame(json_item["data"])
                # if there is no data, build an empty frame with the saved column names
                if json_item["data"]
                else polars_DataFrame(
                    schema={col: str for col in json_item.get("columns", [])}
                )
            ),
            uid="polars.DataFrame",
            source_pckg="zanj",
            desc="polars.DataFrame loader",
        ),
        # list/tuple external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("list")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: _load_sequence_items(  # type: ignore[misc]
                json_item, path, z
            ),
            uid="list",
            source_pckg="zanj",
            desc="list loader, for externals",
        ),
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("tuple")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: tuple(  # type: ignore[misc]
                _load_sequence_items(json_item, path, z)
            ),
            uid="tuple",
            source_pckg="zanj",
            desc="tuple loader, for externals",
        ),
    ]
}
def register_loader_handler(handler: LoaderHandler):
    """register a custom loader handler

    stores `handler` in `LOADER_MAP` under `handler.uid` while holding
    `LOADER_MAP_LOCK`; a handler previously registered with the same uid is replaced.
    """
    # LOADER_MAP is mutated in place, never rebound, so no `global` statement is needed
    uid: str = handler.uid
    with LOADER_MAP_LOCK:
        LOADER_MAP[uid] = handler
def get_item_loader(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
) -> LoaderHandler | None:
    """get the loader for a json item

    lookup order:
    1. if `json_item` is a mapping whose format key value is exactly a registered
       uid, that handler is returned without calling any `check`
    2. otherwise each registered handler's `check(json_item, path, zanj)` is called
       in registration order, and the first one returning truthy is used
    3. if nothing matches, `None` is returned

    # Parameters:
     - `json_item : JSONitem` item to find a loader for
     - `path : ObjectPath` location of the item, passed to each `check`
     - `zanj : _ZANJ_pre | None` passed to each `check` (defaults to `None`)
     - `error_mode : ErrorMode` currently unused; kept for interface compatibility

    # Raises:
     - `TypeError` : if the item's format key value is not a `str`
    """
    # check if we recognize the format
    fmt: str | None = None
    if isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item:
        if not isinstance(json_item[_FORMAT_KEY], str):  # type: ignore[index]
            raise TypeError(
                f"invalid __muutils_format__ type '{type(json_item[_FORMAT_KEY])}' in '{path=}': '{json_item[_FORMAT_KEY] = }'"  # type: ignore[index]
            )
        fmt = json_item[_FORMAT_KEY]  # type: ignore[index]

    # take a consistent snapshot under the lock, so a concurrent
    # `register_loader_handler` cannot change the dict while we iterate it.
    # `check` functions are run outside the lock so they may safely register handlers.
    with LOADER_MAP_LOCK:
        exact_match: LoaderHandler | None = (
            LOADER_MAP.get(fmt) if fmt is not None else None
        )
        handlers: list[LoaderHandler] = list(LOADER_MAP.values())

    if exact_match is not None:
        return exact_match

    # if we dont recognize the format, try to find a loader that can handle it
    for lh in handlers:
        if lh.check(json_item, path, zanj):
            return lh

    # if we still dont have a loader, return None
    return None
def load_item_recursive(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    allow_not_loading: bool = True,
) -> Any:
    """recursively load a json item into python objects

    if `get_item_loader` finds a handler for `json_item`, that handler's `load` is
    used. otherwise dicts and lists are descended into element by element, and
    `str`/`int`/`float`/`bool`/`None` are returned unchanged.

    # Parameters:
     - `json_item : JSONitem` the json data to load
     - `path : ObjectPath` location of `json_item`; children are loaded at
        `tuple(path) + (key_or_index,)`
     - `zanj : _ZANJ_pre | None` passed through to loader handlers (defaults to `None`)
     - `error_mode : ErrorMode` passed through to `get_item_loader`
        (defaults to `ErrorMode.WARN`)
     - `allow_not_loading : bool` if `True`, items of an unrecognized type are returned
        unchanged; if `False`, they raise `ValueError`. applies at every recursion level.
        (defaults to `True`)

    # Returns:
     - `Any` the loaded object

    # Raises:
     - `ValueError` : if `allow_not_loading` is `False` and an unrecognized item is found
     - `TypeError` : propagated from `get_item_loader` for a non-string format key
    """

    def _load_child(child: JSONitem, key: str | int) -> Any:
        # children inherit every setting, including `allow_not_loading`
        return load_item_recursive(
            json_item=child,
            path=tuple(path) + (key,),  # type: ignore[arg-type]
            zanj=zanj,
            error_mode=error_mode,
            allow_not_loading=allow_not_loading,
        )

    def _is_serializable_dataclass(item: Any) -> bool:
        return (
            isinstance(item, typing.Mapping)
            and (_FORMAT_KEY in item)
            and ("SerializableDataclass" in item[_FORMAT_KEY])  # type: ignore[operator, index]
        )

    lh: LoaderHandler | None = get_item_loader(
        json_item=json_item,
        path=path,
        zanj=zanj,
        error_mode=error_mode,
    )

    if lh is not None:
        # special case for serializable dataclasses
        if _is_serializable_dataclass(json_item):
            # why this horribleness?
            # SerializableDataclass, if it has a field `x` which is also a SerializableDataclass, will automatically call `x.__class__.load()`
            # However, we need to load things in containers, as well as arrays
            processed_json_item: dict = {
                key: (val if _is_serializable_dataclass(val) else _load_child(val, key))
                for key, val in json_item.items()  # type: ignore[union-attr]
            }
            return lh.load(processed_json_item, path, zanj)

        return lh.load(json_item, path, zanj)

    # no handler: descend into plain containers, pass primitives through
    if isinstance(json_item, dict):
        return {key: _load_child(val, key) for key, val in json_item.items()}
    if isinstance(json_item, list):
        return [_load_child(x, i) for i, x in enumerate(json_item)]
    if isinstance(json_item, (str, int, float, bool, type(None))):
        return json_item
    if allow_not_loading:
        return json_item
    raise ValueError(f"unknown type {type(json_item)} at {path}\n{json_item}")
def _each_item_in_externals(
    externals: dict[str, ExternalItem],
    json_data: JSONitem,
) -> typing.Iterable[tuple[str, ExternalItem, Any, ObjectPath]]:
    """yield each external item together with the json node it belongs to

    for every entry of `externals`, walks `json_data` along `ext_item.path` and yields
    `(ext_path, ext_item, item, path)`, where `item` is the node at the end of the
    path. externals are visited in order of increasing path length. when the walk
    reaches an external item that has a `"data"` field, it descends into `"data"`
    before applying the next key (see `_populate_externals_error_checking`).

    note that you MUST use the raw iterator, dont try to turn into a list or something:
    each path is walked lazily, only when the next item is requested, so changes the
    consumer makes to `json_data` between steps (e.g. `populate_externals` writing a
    shallower external's `"data"`) are visible when deeper paths are walked.

    # Raises:
     - `KeyError` : if a key along a path cannot be resolved (wraps the underlying
        `KeyError`/`IndexError`/`TypeError`)
     - `AssertionError` : if a path is empty or contains keys that are not `str`/`int`
    """
    sorted_externals: list[tuple[str, ExternalItem]] = sorted(
        externals.items(), key=lambda x: len(x[1].path)
    )

    for ext_path, ext_item in sorted_externals:
        # get the path to the item
        path: ObjectPath = tuple(ext_item.path)
        assert len(path) > 0
        assert all(isinstance(key, (str, int)) for key in path), (
            f"improper types in path {path=}"
        )
        # get the item
        item = json_data
        for i, key in enumerate(path):
            try:
                # ignores in this block are because we cannot know the type is indexable in static analysis
                # but, `_populate_externals_error_checking` checks the types first
                external_unloaded: bool = _populate_externals_error_checking(key, item)
                if external_unloaded:
                    item = item["data"]  # type: ignore
                item = item[key]  # type: ignore[index]

            except (KeyError, IndexError, TypeError) as e:
                # `item` may be a scalar (often the very reason the lookup failed), so
                # calling `len()` on it unconditionally would raise a new TypeError and
                # hide this error
                item_len = len(item) if isinstance(item, typing.Sized) else None
                raise KeyError(
                    f"could not find '{key = }' at path '{ext_path = }', specifically at index '{i = }'",
                    f"'{type(item) =}', 'len(item) = {item_len!r}', '{item.keys() if isinstance(item, dict) else None = }'",  # type: ignore
                    f"From error: {e = }",
                    f"\n\n{item=}\n\n{ext_item=}",
                ) from e

        yield (ext_path, ext_item, item, path)
class LoadedZANJ:
    """for loading a zanj file

    on construction, reads the metadata (`ZANJ_META`), the main json data
    (`ZANJ_MAIN`), and every external file listed in the metadata's
    `"externals_info"` out of the zip archive. externals are kept in `_externals`
    until `populate_externals` writes them into the main json data.
    """

    def __init__(
        self,
        path: str | Path,
        zanj: _ZANJ_pre,
    ) -> None:
        # path and zanj object
        self._path: str = str(path)
        self._zanj: _ZANJ_pre = zanj

        # the archive and every member handle are opened in context managers, so they
        # are closed even if reading, json decoding, or an external loader raises
        with zipfile.ZipFile(file=self._path, mode="r") as _zipf:
            # load data
            with _zipf.open(ZANJ_META, "r") as meta_fp:
                self._meta: JSONdict = json.load(meta_fp)
            with _zipf.open(ZANJ_MAIN, "r") as main_fp:
                self._json_data: JSONitem = json.load(main_fp)

            # read externals
            self._externals: dict[str, ExternalItem] = dict()
            for fname, ext_item in self._meta["externals_info"].items():  # type: ignore
                item_type: str = ext_item["item_type"]  # type: ignore
                with _zipf.open(fname, "r") as fp:
                    self._externals[fname] = ExternalItem(
                        item_type=item_type,  # type: ignore[arg-type]
                        data=GET_EXTERNAL_LOAD_FUNC(item_type)(self, fp),
                        path=ext_item["path"],  # type: ignore
                    )

    def populate_externals(self) -> None:
        """put all external items into the main json data

        for each external, the json node at its path must contain `_REF_KEY` equal to
        the external's file name; the loaded data is then stored under that node's
        `"data"` key. the json data is modified in place.

        # Raises:
         - `AssertionError` : if a node does not reference the expected external
         - `KeyError` : propagated from `_each_item_in_externals` for unresolvable paths
        """

        # loop over once, populating the externals only
        # (the generator must be consumed lazily: deeper paths are walked only after
        # shallower externals have had their "data" filled in by this loop)
        for ext_path, ext_item, item, path in _each_item_in_externals(
            self._externals, self._json_data
        ):
            # replace the item with the external item
            assert _REF_KEY in item  # type: ignore
            assert item[_REF_KEY] == ext_path  # type: ignore
            item["data"] = ext_item.data  # type: ignore