Coverage for zanj/loading.py: 79%
132 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 11:17 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 11:17 -0600
1from __future__ import annotations
3import json
4import threading
5import typing
6import zipfile
7from dataclasses import dataclass
8from pathlib import Path
9from typing import Any, Callable
11import numpy as np
# pandas is an optional dependency: bind `pandas_DataFrame` to the real
# `pd.DataFrame` when pandas can be imported, otherwise to a stub class.
try:
    import pandas as pd  # type: ignore[import]

    pandas_DataFrame = pd.DataFrame  # type: ignore[no-redef]
except ImportError:

    class pandas_DataFrame:  # type: ignore[no-redef]
        """stand-in used when pandas is not installed; constructing it always raises `ImportError`"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load pandas DataFrame, pandas is not installed")
24from muutils.errormode import ErrorMode
25from muutils.json_serialize.array import load_array
26from muutils.json_serialize.json_serialize import ObjectPath
27from muutils.json_serialize.util import (
28 _FORMAT_KEY,
29 _REF_KEY,
30 JSONdict,
31 JSONitem,
32 safe_getsource,
33 string_as_lines,
34)
36from zanj.externals import (
37 GET_EXTERNAL_LOAD_FUNC,
38 ZANJ_MAIN,
39 ZANJ_META,
40 ExternalItem,
41 _ZANJ_pre,
42)
44# pylint: disable=protected-access, dangerous-default-value
47def _populate_externals_error_checking(key, item) -> bool:
48 """checks that the key is valid for the item. returns "True" we need to augment the path by accessing the "data" element"""
50 # special case for not fully loaded external item which we still need to populate
51 if isinstance(item, typing.Mapping):
52 if (_FORMAT_KEY in item) and item[_FORMAT_KEY].endswith(":external"):
53 if "data" in item:
54 return True
55 else:
56 raise KeyError(
57 f"expected an external item, but could not find data: {list(item.keys())}",
58 f"{item[_FORMAT_KEY]}, {len(item) = }, {item.get('data', '<EMPTY>') = }",
59 )
61 # if it's a list, make sure the key is an int and that it's in range
62 if isinstance(item, typing.Sequence):
63 if not isinstance(key, int):
64 raise TypeError(f"improper type: '{type(key) = }', expected int")
65 if key >= len(item):
66 raise IndexError(f"index out of range: '{key = }', expected < {len(item)}")
68 # if it's a dict, make sure that the key is a str and that it's in the dict
69 elif isinstance(item, typing.Mapping):
70 if not isinstance(key, str):
71 raise TypeError(f"improper type: '{type(key) = }', expected str")
72 if key not in item:
73 raise KeyError(f"key not in dict: '{key = }', expected in {item.keys()}")
75 # otherwise, raise an error
76 else:
77 raise TypeError(f"improper type: '{type(item) = }', expected dict or list")
79 return False
@dataclass
class LoaderHandler:
    """handler for loading an object from a json file or a ZANJ archive

    a handler pairs a `check` predicate (does this handler apply to the item?)
    with a `load` function (turn the item into an object), plus some metadata
    which is written out by `serialize`.
    """

    # TODO: add a separate "asserts" function?
    # right now, any asserts must happen in `check` or `load` which is annoying with lambdas

    # (json_data, path, zanj) -> whether to use this handler
    check: Callable[[JSONitem, ObjectPath, _ZANJ_pre], bool]
    # function to load the object (json_data, path, zanj) -> loaded_obj
    load: Callable[[JSONitem, ObjectPath, _ZANJ_pre], Any]
    # unique identifier for the handler, saved in __muutils_format__ field
    uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # priority of the handler, defaults are all 0
    priority: int = 0
    # description of the handler
    desc: str = "(no description)"

    def serialize(self) -> JSONdict:
        """serialize the handler info

        the `check` and `load` callables are represented by their source code
        and docstring, the remaining fields are coerced to plain `str`/`int`.
        """
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the load function
            "load": {
                "code": safe_getsource(self.load),
                "doc": string_as_lines(self.load.__doc__),
            },
            # get the uid, source_pckg, priority, and desc
            "uid": str(self.uid),
            "source_pckg": str(self.source_pckg),
            "priority": int(self.priority),
            "desc": str(self.desc),
        }

    @classmethod
    def from_formattedclass(cls, fc: type, priority: int = 0):
        """create a loader from a class with `serialize`, `load` methods and `__muutils_format__` attribute

        the resulting handler matches any mapping whose format key is equal to
        `fc.__muutils_format__`, and loads it by calling `fc.load(json_item, path, z)`.
        """
        assert hasattr(fc, "serialize"), f"{fc} has no `serialize`"
        assert callable(fc.serialize), f"{fc}.serialize is not callable"  # type: ignore
        assert hasattr(fc, "load"), f"{fc} has no `load`"
        assert callable(fc.load), f"{fc}.load is not callable"  # type: ignore
        assert hasattr(fc, _FORMAT_KEY), f"{fc} has no `{_FORMAT_KEY}`"
        assert isinstance(fc.__muutils_format__, str), (  # type: ignore
            f"{fc}.__muutils_format__ is not a str"
        )

        return cls(
            # `check` may be called on arbitrary json items (lists, strings, dicts
            # without a format key), so it must return False instead of raising
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY] == fc.__muutils_format__  # type: ignore[attr-defined]
            ),
            load=lambda json_item, path=None, z=None: fc.load(json_item, path, z),  # type: ignore[misc]
            uid=fc.__muutils_format__,  # type: ignore[attr-defined]
            source_pckg=str(fc.__module__),
            priority=priority,
            desc=f"formatted class loader for {fc.__name__}",
        )
144# TODO: how can we type hint this without actually importing torch?
def _torch_loaderhandler_load(
    json_item: JSONitem,
    path: ObjectPath,
    z: _ZANJ_pre | None = None,
):
    """build a torch tensor out of a serialized array item

    torch is imported lazily, so it is only required when a tensor is actually
    loaded. `path` is only used for the error message, `z` is not used.

    # Raises:
     - `ImportError` : if torch (or the muutils dtype map) cannot be imported
    """
    try:
        import torch  # type: ignore[import-not-found]
        from muutils.tensor_utils import TORCH_DTYPE_MAP
    except ImportError as import_err:
        raise ImportError(
            f"could not import torch, which we need to load the object at {path = }: {json_item = }"
        ) from import_err

    # decode the raw array first, then look up the torch dtype for it
    array_data = load_array(json_item)
    torch_dtype = TORCH_DTYPE_MAP[json_item["dtype"]]  # type: ignore[index, call-overload]
    return torch.tensor(array_data, dtype=torch_dtype)
164# NOTE: there are type ignores on the loaders, since the type checking should be the responsibility of the check function
# lock held by `register_loader_handler` while it writes to `LOADER_MAP`
LOADER_MAP_LOCK = threading.Lock()

# registry of the built-in loader handlers, keyed by `LoaderHandler.uid`.
# every `check` first makes sure the item is a mapping carrying the format key
# and then matches on the prefix of the format string.
LOADER_MAP: dict[str, LoaderHandler] = {
    lh.uid: lh
    for lh in [
        # array external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("numpy.ndarray")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: np.array(  # type: ignore[misc]
                load_array(json_item), dtype=np.dtype(json_item["dtype"])
            ),
            uid="numpy.ndarray",
            source_pckg="zanj",
            desc="numpy.ndarray loader",
        ),
        # torch tensor, loading is delegated to `_torch_loaderhandler_load`
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("torch.Tensor")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=_torch_loaderhandler_load,
            uid="torch.Tensor",
            source_pckg="zanj",
            desc="torch.Tensor loader",
        ),
        # pandas
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("pandas.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: pandas_DataFrame(  # type: ignore[misc]
                json_item["data"]
            ),
            uid="pandas.DataFrame",
            source_pckg="zanj",
            desc="pandas.DataFrame loader",
        ),
        # list/tuple external
        # NOTE: every element is loaded via `load_item_recursive` with the path of the container itself
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("list")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: [  # type: ignore[misc]
                load_item_recursive(x, path, z) for x in json_item["data"]
            ],
            uid="list",
            source_pckg="zanj",
            desc="list loader, for externals",
        ),
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("tuple")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: tuple(  # type: ignore[misc]
                [load_item_recursive(x, path, z) for x in json_item["data"]]
            ),
            uid="tuple",
            source_pckg="zanj",
            desc="tuple loader, for externals",
        ),
    ]
}
def register_loader_handler(handler: LoaderHandler):
    """store `handler` in the global `LOADER_MAP` under its `uid`

    a handler that is already stored under the same uid is replaced.
    the write happens while holding `LOADER_MAP_LOCK`.
    """
    # the map is mutated in place and never rebound, so no `global` statement is required
    with LOADER_MAP_LOCK:
        LOADER_MAP[handler.uid] = handler
def get_item_loader(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    # lh_map: dict[str, LoaderHandler] = LOADER_MAP,
) -> LoaderHandler | None:
    """get the loader for a json item

    lookup happens in two steps:

    1. if `json_item` is a mapping with a format key, and a handler is registered
       under exactly that format string, that handler is returned
    2. otherwise the registered handlers are tried in registration order, and the
       first one whose `check` accepts the item is returned

    returns `None` if no handler applies.
    `error_mode` is accepted but not used in this function, and the handlers'
    `priority` is not consulted.

    # Raises:
     - `TypeError` : if the item has a format key whose value is not a `str`
    """
    format_str: str | None = None
    if isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item:
        if not isinstance(json_item[_FORMAT_KEY], str):
            raise TypeError(
                f"invalid __muutils_format__ type '{type(json_item[_FORMAT_KEY])}' in '{path=}': '{json_item[_FORMAT_KEY] = }'"
            )
        format_str = json_item[_FORMAT_KEY]

    # read the registry under the same lock `register_loader_handler` writes with,
    # and iterate over a snapshot: iterating the live dict while another thread
    # registers a handler would raise "dictionary changed size during iteration".
    # the lock is released before any `check` is called.
    with LOADER_MAP_LOCK:
        exact_match: LoaderHandler | None = (
            LOADER_MAP.get(format_str) if format_str is not None else None
        )
        candidates: list[LoaderHandler] = list(LOADER_MAP.values())

    # check if we recognize the format
    if exact_match is not None:
        return exact_match

    # if we dont recognize the format, try to find a loader that can handle it
    for lh in candidates:
        if lh.check(json_item, path, zanj):
            return lh

    # if we still dont have a loader, return None
    return None
def load_item_recursive(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    allow_not_loading: bool = True,
) -> Any:
    """load `json_item`, recursing into containers

    - if a loader handler applies to the item (see `get_item_loader`), the item
      is loaded with that handler
    - otherwise dicts and lists are rebuilt with every element loaded
      recursively, and `str`/`int`/`float`/`bool`/`None` are returned as they are
    - any other type is returned unchanged if `allow_not_loading` is set, and
      raises `ValueError` if it is not

    `allow_not_loading` is passed on to every recursive call, so it also applies
    to nested items and not only to the top-level one.

    # Raises:
     - `ValueError` : unknown item type while `allow_not_loading` is `False`
    """
    lh: LoaderHandler | None = get_item_loader(
        json_item=json_item,
        path=path,
        zanj=zanj,
        error_mode=error_mode,
        # lh_map=lh_map,
    )

    if lh is not None:
        # special case for serializable dataclasses
        if (
            isinstance(json_item, typing.Mapping)
            and (_FORMAT_KEY in json_item)
            and ("SerializableDataclass" in json_item[_FORMAT_KEY])  # type: ignore[operator]
        ):
            # why this horribleness?
            # SerializableDataclass, if it has a field `x` which is also a SerializableDataclass, will automatically call `x.__class__.load()`
            # However, we need to load things in containers, as well as arrays
            processed_json_item: dict = {
                key: (
                    val
                    if (
                        isinstance(val, typing.Mapping)
                        and (_FORMAT_KEY in val)
                        and ("SerializableDataclass" in val[_FORMAT_KEY])
                    )
                    else load_item_recursive(
                        json_item=val,
                        path=tuple(path) + (key,),
                        zanj=zanj,
                        error_mode=error_mode,
                        allow_not_loading=allow_not_loading,
                    )
                )
                for key, val in json_item.items()
            }

            return lh.load(processed_json_item, path, zanj)

        return lh.load(json_item, path, zanj)

    # no handler applies: recurse into plain containers
    if isinstance(json_item, dict):
        return {
            key: load_item_recursive(
                json_item=json_item[key],
                path=tuple(path) + (key,),
                zanj=zanj,
                error_mode=error_mode,
                allow_not_loading=allow_not_loading,
                # lh_map=lh_map,
            )
            for key in json_item
        }

    if isinstance(json_item, list):
        return [
            load_item_recursive(
                json_item=x,
                path=tuple(path) + (i,),
                zanj=zanj,
                error_mode=error_mode,
                allow_not_loading=allow_not_loading,
                # lh_map=lh_map,
            )
            for i, x in enumerate(json_item)
        ]

    # json primitives need no loading
    if isinstance(json_item, (str, int, float, bool, type(None))):
        return json_item

    # anything else is of a type we do not know how to load
    if allow_not_loading:
        return json_item

    raise ValueError(f"unknown type {type(json_item)} at {path}\n{json_item}")
def _each_item_in_externals(
    externals: dict[str, ExternalItem],
    json_data: JSONitem,
) -> typing.Iterable[tuple[str, ExternalItem, Any, ObjectPath]]:
    """yield `(ext_path, ext_item, item, path)` for each external, where `item` is the element of `json_data` found at `ext_item.path`

    externals are visited in order of increasing path length.

    note that you MUST use the raw iterator, dont try to turn into a list or something

    # Raises:
     - `KeyError` : if the path of an external cannot be followed in `json_data`.
       the original `KeyError`/`IndexError`/`TypeError` is attached as the cause
    """
    sorted_externals: list[tuple[str, ExternalItem]] = sorted(
        externals.items(), key=lambda x: len(x[1].path)
    )

    for ext_path, ext_item in sorted_externals:
        # get the path to the item
        path: ObjectPath = tuple(ext_item.path)
        assert len(path) > 0
        assert all(isinstance(key, (str, int)) for key in path), (
            f"improper types in path {path=}"
        )
        # get the item
        item = json_data
        for i, key in enumerate(path):
            try:
                # ignores in this block are because we cannot know the type is indexable in static analysis
                # but, we check the types in the line below
                external_unloaded: bool = _populate_externals_error_checking(key, item)
                if external_unloaded:
                    item = item["data"]  # type: ignore
                item = item[key]  # type: ignore[index]

            except (KeyError, IndexError, TypeError) as e:
                # `item` may be unsized here (that is one of the reasons the check
                # above raises `TypeError`), so `len(item)` must not be evaluated
                # blindly or it would mask the error we are trying to report
                item_len: int | None = len(item) if hasattr(item, "__len__") else None  # type: ignore
                raise KeyError(
                    f"could not find '{key = }' at path '{ext_path = }', specifically at index '{i = }'",
                    f"'{type(item) =}', 'len(item) = {item_len!r}', '{item.keys() if isinstance(item, dict) else None = }'",  # type: ignore
                    f"From error: {e = }",
                    f"\n\n{item=}\n\n{ext_item=}",
                ) from e

        yield (ext_path, ext_item, item, path)
class LoadedZANJ:
    """for loading a zanj file

    the constructor reads the metadata, the main json data and all external
    items out of the archive. the externals are kept separately in
    `self._externals` until `populate_externals` puts them into the json data.
    """

    def __init__(
        self,
        path: str | Path,
        zanj: _ZANJ_pre,
    ) -> None:
        """read the archive at `path`

        the zip file and every member opened from it are closed again before
        this returns, also when reading or decoding fails.
        """
        # path and zanj object
        self._path: str = str(path)
        self._zanj: _ZANJ_pre = zanj

        # load zip file
        with zipfile.ZipFile(file=self._path, mode="r") as zipf:
            # load data
            with zipf.open(ZANJ_META, "r") as fp_meta:
                self._meta: JSONdict = json.load(fp_meta)
            with zipf.open(ZANJ_MAIN, "r") as fp_main:
                self._json_data: JSONitem = json.load(fp_main)

            # read externals
            self._externals: dict[str, ExternalItem] = dict()
            for fname, ext_item in self._meta["externals_info"].items():  # type: ignore
                item_type: str = ext_item["item_type"]  # type: ignore
                with zipf.open(fname, "r") as fp:
                    self._externals[fname] = ExternalItem(
                        item_type=item_type,  # type: ignore[arg-type]
                        # the load function gets this object as well as the open file
                        data=GET_EXTERNAL_LOAD_FUNC(item_type)(self, fp),
                        path=ext_item["path"],  # type: ignore
                    )

    def populate_externals(self) -> None:
        """put all external items into the main json data"""

        # loop over once, populating the externals only
        for ext_path, ext_item, item, path in _each_item_in_externals(
            self._externals, self._json_data
        ):
            # replace the item with the external item
            # the placeholder has to reference the external we are about to insert
            assert _REF_KEY in item  # type: ignore
            assert item[_REF_KEY] == ext_path  # type: ignore
            item["data"] = ext_item.data  # type: ignore