zanj
ZANJ
Overview
The ZANJ format is meant to be a way of saving arbitrary objects to disk, in a way that is flexible, allows keeping configuration and data together, and is human readable. It is very loosely inspired by HDF5 and the derived exdir format, and the implementation is inspired by npz files.
- You can take any SerializableDataclass from the muutils library and save it to disk: any large arrays or lists will be stored efficiently as external files in the zip archive, while the basic structure and metadata will be stored in readable JSON files.
- You can also specify a special ConfiguredModel, a subclass of torch.nn.Module that lets you save not just your model weights, but all required configuration information, plus any other metadata (like training logs), in a single file.
This library was originally a module in muutils.
Installation
Available on PyPI as zanj:
pip install zanj
Usage
You can find a runnable example of this in demo.ipynb
Saving a basic object
Any SerializableDataclass of basic types can be saved as zanj:
```python
import numpy as np
import pandas as pd
from muutils.json_serialize import SerializableDataclass, serializable_dataclass, serializable_field

from zanj import ZANJ


@serializable_dataclass
class BasicZanj(SerializableDataclass):
    a: str
    q: int = 42
    c: list[int] = serializable_field(default_factory=list)


# initialize a zanj reader/writer
zj = ZANJ()

# create an instance
instance: BasicZanj = BasicZanj("hello", 42, [1, 2, 3])
path: str = "tests/junk_data/path_to_save_instancezanj.zanj"
zj.save(instance, path)
recovered: BasicZanj = zj.read(path)
```
ZANJ automatically handles nested serializable dataclasses, numpy arrays, pytorch tensors, and pandas dataframes:
```python
import torch
import pandas as pd


@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: list[BasicZanj]
    torch_tensor: torch.Tensor
```
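To make the idea of nested handling concrete, here is a minimal stdlib-only sketch of recursive serialization over dataclasses, lists, and dicts. It is an illustration under assumptions, not ZANJ's handler machinery: the function `to_jsonable` and the `"__class__"` tag are hypothetical, and ZANJ additionally moves large arrays and DataFrames out to external files.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any


def to_jsonable(obj: Any) -> Any:
    """Hypothetical recursive serializer: nested dataclasses become tagged dicts."""
    # dataclass *instances* (not the class objects themselves) become dicts
    if is_dataclass(obj) and not isinstance(obj, type):
        out = {"__class__": type(obj).__name__}
        for f in fields(obj):
            out[f.name] = to_jsonable(getattr(obj, f.name))
        return out
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(x) for x in obj]
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    # str, int, float, bool and None are already JSON-compatible
    return obj


@dataclass
class Inner:
    a: str
    c: list = field(default_factory=list)


@dataclass
class Outer:
    name: str
    container: list


data = to_jsonable(Outer("outer", [Inner("x", [1, 2]), Inner("y")]))
# data == {"__class__": "Outer", "name": "outer",
#          "container": [{"__class__": "Inner", "a": "x", "c": [1, 2]},
#                        {"__class__": "Inner", "a": "y", "c": []}]}
```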
For custom classes, you can specify a serialization_fn and loading_fn to handle the logic of converting to and from a json-serializable format:
```python
@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    # serialization_fn receives the field's value; loading_fn receives the whole serialized dict
    device: torch.device = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda data: torch.device(data["device"]),
    )
```
Note that loading_fn takes the whole serialized dictionary of the object, not just the field's own value. This is in case you've stored data in multiple fields of the dict which are needed to reconstruct the object.
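A minimal stdlib-only sketch of that calling convention, assuming (as the examples in this README suggest) that serialization_fn receives the field's value while loading_fn receives the whole serialized dict. The `HOOKS` table, the `serialize`/`load` helpers, and the `Scaled` class are hypothetical; they mimic the contract, not muutils' implementation.

```python
from __future__ import annotations

from dataclasses import dataclass, fields
from fractions import Fraction
from typing import Any, Callable

# hypothetical per-field hooks keyed by field name:
#   serialization_fn(value) -> JSON-friendly value
#   loading_fn(whole_serialized_dict) -> reconstructed value
HOOKS: dict[str, tuple[Callable[[Any], Any], Callable[[dict], Any]]] = {
    "ratio": (lambda value: str(value), lambda data: Fraction(data["ratio"])),
}


@dataclass
class Scaled:
    name: str
    ratio: Fraction


def serialize(obj: Scaled) -> dict:
    out = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        if f.name in HOOKS:
            value = HOOKS[f.name][0](value)  # serialization_fn gets only the field value
        out[f.name] = value
    return out


def load(data: dict) -> Scaled:
    kwargs = {}
    for f in fields(Scaled):
        if f.name in HOOKS:
            kwargs[f.name] = HOOKS[f.name][1](data)  # loading_fn gets the whole dict
        else:
            kwargs[f.name] = data[f.name]
    return Scaled(**kwargs)


original = Scaled("half", Fraction(1, 2))
stored = serialize(original)  # {'name': 'half', 'ratio': '1/2'}
restored = load(stored)
```

Passing the whole dict means a loading_fn could also read sibling keys (for example a unit stored in another field) when reconstructing a value.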
Saving Models
First, define a configuration class for your model. This class will hold the parameters for your model and any associated objects (like losses and optimizers). The configuration class should be a subclass of SerializableDataclass and use the serializable_field function to define fields that need special serialization.
Here's an example that defines the configuration for a simple feed-forward (MLP) model:
```python
from zanj.torchutil import ConfiguredModel, set_config_class


@serializable_dataclass
class MyNNConfig(SerializableDataclass):
    input_dim: int
    hidden_dim: int
    output_dim: int

    # store the activation function by name, reconstruct it by looking it up in torch.nn
    act_fn: torch.nn.Module = serializable_field(
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["act_fn"]),
    )

    # same for the loss function
    loss_kwargs: dict = serializable_field(default_factory=dict)
    loss_factory: torch.nn.modules.loss._Loss = serializable_field(
        default_factory=lambda: torch.nn.CrossEntropyLoss,
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
    )
    loss = property(lambda self: self.loss_factory(**self.loss_kwargs))
```
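Storing a callable by name works because the name can be resolved again with getattr on a known namespace. The stdlib-only sketch below uses the math module as a stand-in for torch.nn (an assumption, so it runs without PyTorch):

```python
import json
import math

# serialize the function by name, as serialization_fn=lambda x: x.__name__ does above
config = {"act_fn": math.sqrt.__name__}
text = json.dumps(config)  # '{"act_fn": "sqrt"}'

# reconstruct by looking the name up again, as the loading_fn does with torch.nn
restored = getattr(math, json.loads(text)["act_fn"])
```

The trade-off: this only round-trips objects that are reachable by name in that namespace. A lambda or a locally defined activation class would need a different serialization_fn/loading_fn pair.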
Then, define your model class. It should be a subclass of ConfiguredModel, and use the set_config_class decorator to associate it with your configuration class. The __init__ method should take a single argument, which is an instance of your configuration class. You must also call the superclass __init__ method with the configuration instance.
```python
@set_config_class(MyNNConfig)
class MyNN(ConfiguredModel[MyNNConfig]):
    def __init__(self, config: MyNNConfig):
        # call the superclass init!
        # this will store the config in the zanj_model_config field
        super().__init__(config)

        # whatever you want here
        self.net = torch.nn.Sequential(
            torch.nn.Linear(config.input_dim, config.hidden_dim),
            config.act_fn(),
            torch.nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x):
        return self.net(x)
```
You can now create instances of your model, save them to disk, and load them back into memory:
```python
config = MyNNConfig(
    input_dim=10,
    hidden_dim=20,
    output_dim=2,
    act_fn=torch.nn.ReLU,
    loss_kwargs=dict(reduction="mean"),
)

# create your model from the config, and save
model = MyNN(config)
fname = "tests/junk_data/path_to_save_modelzanj.zanj"
ZANJ().save(model, fname)

# load by calling the class method `read()`
loaded_model = MyNN.read(fname)

# zanj will actually infer the type of the object in the file
# -- and will warn you if you don't have the correct package installed
loaded_another_way = ZANJ().read(fname)
```
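The round trip hinges on two things: the model keeps its config, and a read classmethod can rebuild the model from the stored config alone. A minimal stdlib-only sketch of that pattern follows; `TinyConfig`, `ConfiguredThing`, `save_config`, and `read_config` are hypothetical names, the `config_class` attribute stands in for what set_config_class presumably records, and the real ConfiguredModel also saves weights.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class TinyConfig:
    input_dim: int
    output_dim: int


class ConfiguredThing:
    """Hypothetical stand-in: keeps its config, like zanj_model_config."""

    config_class = TinyConfig  # assumed role of set_config_class

    def __init__(self, config: TinyConfig):
        self.zanj_model_config = config

    def save_config(self, path: str) -> None:
        with open(path, "w") as fp:
            json.dump(asdict(self.zanj_model_config), fp)

    @classmethod
    def read_config(cls, path: str) -> "ConfiguredThing":
        # rebuild the config first, then construct the object from it
        with open(path) as fp:
            return cls(cls.config_class(**json.load(fp)))


with tempfile.TemporaryDirectory() as tmp:
    cfg_path = os.path.join(tmp, "model_config.json")
    ConfiguredThing(TinyConfig(10, 2)).save_config(cfg_path)
    loaded = ConfiguredThing.read_config(cfg_path)
```

This is also why __init__ must accept the config as its single argument: the reader has nothing else to pass it.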
Configuration
When initializing a ZANJ object, you can specify some configuration info about saving, such as:
- thresholds for how big an array/table has to be before moving to external file
- compression settings
- error modes
- additional handlers for serialization
```python
# error handling behavior during serialization and loading
error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode
# how big an array or list (including pandas DataFrame) can be before moving it from the core JSON file to an external file
external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold
external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold
# compression settings passed to `zipfile` package
compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress
# for doing very cursed things in your own custom loading or serialization functions
custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings
# specify additional serialization handlers
handlers_pre: MonoTuple[SerializerHandler] = tuple()
handlers_default: MonoTuple[SerializerHandler] = DEFAULT_SERIALIZER_HANDLERS_ZANJ
```
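The compress setting accepts either a bool or a zipfile compression constant. The sketch below reproduces the bool handling shown in ZANJ.__init__ as a standalone function; the name `normalize_compress` is hypothetical, and passing other ints through unchanged is an assumption that they are valid zipfile constants.

```python
from __future__ import annotations

import io
import zipfile


def normalize_compress(compress: bool | int) -> int:
    # bool is a subclass of int, so the bool check has to come first
    if isinstance(compress, bool):
        return zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
    return compress  # assumed to already be a zipfile constant, e.g. zipfile.ZIP_LZMA


# a highly repetitive payload shrinks noticeably under DEFLATE
buf = io.BytesIO()
with zipfile.ZipFile(buf, mode="w", compression=normalize_compress(True)) as zf:
    zf.writestr("data.json", "[" + ",".join(["0"] * 1000) + "]")
with zipfile.ZipFile(buf) as zf:
    info = zf.getinfo("data.json")
```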
Implementation
The on-disk format is a file <filename>.zanj, which is a zip file containing:
- __zanj_meta__.json: a file containing zanj-specific metadata, including:
  - system information
  - installed packages
  - information about external files
- __zanj__.json: a file containing user-specified data
  - when an element is too big, it can be moved to an external file:
    - .npy for numpy arrays or torch tensors
    - .jsonl for pandas dataframes or large sequences
  - the list of external files is stored in __zanj_meta__.json
  - a "$ref" key, specified in _REF_KEY in muutils, will have a value pointing to the external file
  - a _FORMAT_KEY key will detail an external format type
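The layout above can be sketched with the stdlib alone: small data stays in the main JSON, a larger table goes to an external .jsonl member, and a "$ref" value points at it. This is an illustration, not ZANJ's writer; the member name table.jsonl, the metadata contents, and the `resolve` helper are hypothetical, and only .jsonl externals are handled (.npy would need numpy).

```python
import io
import json
import zipfile

REF_KEY = "$ref"  # assumed to match muutils' _REF_KEY, per the description above

rows = [{"x": i, "y": i * i} for i in range(5)]

# write: metadata, main JSON with a reference, and the external .jsonl member
buf = io.BytesIO()
with zipfile.ZipFile(buf, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
    meta = {"externals_info": {"table.jsonl": {"len(data)": len(rows)}}}
    zf.writestr("__zanj_meta__.json", json.dumps(meta))
    zf.writestr("__zanj__.json", json.dumps({"name": "demo", "table": {REF_KEY: "table.jsonl"}}))
    zf.writestr("table.jsonl", "\n".join(json.dumps(r) for r in rows))


def resolve(node, zf: zipfile.ZipFile):
    """Replace {"$ref": member} dicts with the parsed lines of that member."""
    if isinstance(node, dict):
        if REF_KEY in node:
            text = zf.read(node[REF_KEY]).decode("utf-8")
            return [json.loads(line) for line in text.splitlines() if line]
        return {k: resolve(v, zf) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve(v, zf) for v in node]
    return node


# read: load the main JSON, then pull referenced externals out of the archive
with zipfile.ZipFile(buf) as zf:
    names = zf.namelist()
    loaded = resolve(json.loads(zf.read("__zanj__.json")), zf)
```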
Comparison to other formats
| Format | Safe | Zero-copy | Lazy loading | No file size limit | Layout control | Flexibility | Bfloat16 |
|---|---|---|---|---|---|---|---|
| pickle (PyTorch) | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ✅ |
| H5 (Tensorflow) | ✅ | ❌ | ✅ | ✅ | ~ | ~ | ❌ |
| HDF5 | ✅ | ? | ✅ | ✅ | ~ | ✅ | ❌ |
| SavedModel (Tensorflow) | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ |
| MsgPack (flax) | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ |
| Protobuf (ONNX) | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ |
| Cap'n'Proto | ✅ | ✅ | ~ | ✅ | ✅ | ~ | ❌ |
| Numpy (npy,npz) | ✅ | ? | ? | ❌ | ✅ | ❌ | ❌ |
| SafeTensors | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| exdir | ✅ | ? | ? | ? | ? | ✅ | ❌ |
| ZANJ | ✅ | ❌ | ❌* | ✅ | ✅ | ✅ | ❌* |
- Safe: Can I use a randomly downloaded file and expect not to run arbitrary code?
- Zero-copy: Does reading the file require more memory than the original file?
- Lazy loading: Can I inspect the file without loading everything? And load only some tensors in it without scanning the whole file (distributed setting)?
- Layout control: Lazy loading is not necessarily enough, since if the information about tensors is spread out in your file, then even if the information is lazily accessible you might have to access most of your file to read the available tensors (incurring many DISK -> RAM copies). Controlling the layout to keep fast access to single tensors is important.
- No file size limit: Is there a limit to the file size?
- Flexibility: Can I save custom code in the format and be able to use it later with zero extra code? (~ means we can store more than pure tensors, but no custom code)
- Bfloat16: Does the format support native bfloat16 (meaning no weird workarounds are necessary)? This is becoming increasingly important in the ML world.
* denotes this feature may be coming at a future date :)
(This table was stolen from safetensors)
```python
"""
.. include:: ../README.md
"""

from __future__ import annotations

from zanj.loading import register_loader_handler
from zanj.zanj import ZANJ

__all__ = [
    "register_loader_handler",
    "ZANJ",
    # modules
    "externals",
    "loading",
    "serializing",
    "torchutil",
    "zanj",
]
```
```python
def register_loader_handler(handler: LoaderHandler):
    """register a custom loader handler"""
    global LOADER_MAP, LOADER_MAP_LOCK
    with LOADER_MAP_LOCK:
        LOADER_MAP[handler.uid] = handler
```
register a custom loader handler
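register_loader_handler stores handlers in a lock-protected map keyed by uid, so registering the same uid twice replaces the earlier entry. Below is a minimal stdlib-only sketch of that registry pattern. `LoaderHandlerSketch`, its `check`/`load` fields, the `"__format__"` tag, and the first-match dispatch in `load_with_handlers` are hypothetical simplifications, not zanj's LoaderHandler or its dispatch rules.

```python
from __future__ import annotations

import threading
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class LoaderHandlerSketch:
    """Hypothetical stand-in for zanj's LoaderHandler, reduced to what this sketch needs."""

    uid: str
    check: Callable[[Any], bool]
    load: Callable[[Any], Any]


LOADER_MAP: dict[str, LoaderHandlerSketch] = {}
LOADER_MAP_LOCK = threading.Lock()


def register(handler: LoaderHandlerSketch) -> None:
    # keyed by uid: re-registering the same uid overwrites instead of duplicating
    with LOADER_MAP_LOCK:
        LOADER_MAP[handler.uid] = handler


def load_with_handlers(data: Any) -> Any:
    # snapshot the handlers under the lock, then try them in insertion order
    with LOADER_MAP_LOCK:
        handlers = list(LOADER_MAP.values())
    for h in handlers:
        if h.check(data):
            return h.load(data)
    return data  # no handler matched: return the raw JSON data


complex_handler = LoaderHandlerSketch(
    uid="complex",
    check=lambda d: isinstance(d, dict) and d.get("__format__") == "complex",
    load=lambda d: complex(d["re"], d["im"]),
)
register(complex_handler)
register(complex_handler)  # same uid, so the map still holds one entry
```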
```python
class ZANJ(JsonSerializer):
    """Zip up: Arrays in Numpy, JSON for everything else

    given an arbitrary object, throw into a zip file, with arrays stored in .npy files, and everything else stored in a json file

    (basically npz file with json)

    - numpy (or pytorch) arrays are stored in paths according to their name and structure in the object
    - everything else about the object is stored in a json file `zanj.json` in the root of the archive, via `muutils.json_serialize.JsonSerializer`
    - metadata about ZANJ configuration, and optionally packages and versions, is stored in a `__zanj_meta__.json` file in the root of the archive

    create a ZANJ-class via `z_cls = ZANJ().create(obj)`, and save/read instances of the object via `z_cls.save(obj, path)`, `z_cls.load(path)`. be sure to pass an **instance** of the object, to make sure that the attributes of the class can be correctly recognized

    """

    def __init__(
        self,
        error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode,
        internal_array_mode: ArrayMode = ZANJ_GLOBAL_DEFAULTS.internal_array_mode,
        external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold,
        external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold,
        compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress,
        custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[
            SerializerHandler
        ] = DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    ) -> None:
        super().__init__(
            array_mode=internal_array_mode,
            error_mode=error_mode,
            handlers_pre=handlers_pre,
            handlers_default=handlers_default,
        )

        self.external_array_threshold: int = external_array_threshold
        self.external_list_threshold: int = external_list_threshold
        self.custom_settings: dict = (
            custom_settings if custom_settings is not None else dict()
        )

        # process compression to int if bool given
        self.compress = compress
        if isinstance(compress, bool):
            if compress:
                self.compress = zipfile.ZIP_DEFLATED
            else:
                self.compress = zipfile.ZIP_STORED

        # create the externals, leave it empty
        self._externals: dict[str, ExternalItem] = dict()

    def externals_info(self) -> dict[str, dict[str, str | int | list[int]]]:
        """return information about the current externals"""
        output: dict[str, dict] = dict()

        key: str
        item: ExternalItem
        for key, item in self._externals.items():
            data = item.data
            output[key] = {
                "item_type": item.item_type,
                "path": item.path,
                "type(data)": str(type(data)),
                "len(data)": len(data),
            }

            if item.item_type == "ndarray":
                output[key].update(arr_metadata(data))
            elif item.item_type.startswith("jsonl") and len(data) > 0:
                output[key]["data[0]"] = data[0]

        return {
            key: val
            for key, val in sorted(output.items(), key=lambda x: len(x[1]["path"]))
        }

    def meta(self) -> JSONitem:
        """return the metadata of the ZANJ archive"""

        serialization_handlers = {h.uid: h.serialize() for h in self.handlers}
        load_handlers = {h.uid: h.serialize() for h in LOADER_MAP.values()}

        return dict(
            # configuration of this ZANJ instance
            zanj_cfg=dict(
                error_mode=str(self.error_mode),
                array_mode=str(self.array_mode),
                external_array_threshold=self.external_array_threshold,
                external_list_threshold=self.external_list_threshold,
                compress=self.compress,
                serialization_handlers=serialization_handlers,
                load_handlers=load_handlers,
            ),
            # system info (python, pip packages, torch & cuda, platform info, git info)
            sysinfo=json_serialize(SysInfo.get_all(include=("python", "pytorch"))),
            externals_info=self.externals_info(),
            timestamp=time.time(),
        )

    def save(self, obj: Any, file_path: str | Path) -> str:
        """save the object to a ZANJ archive. returns the path to the archive"""

        # adjust extension
        file_path = str(file_path)
        if not file_path.endswith(".zanj"):
            file_path += ".zanj"

        # make directory
        dir_path: str = os.path.dirname(file_path)
        if dir_path != "":
            if not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=False)

        # clear the externals!
        self._externals = dict()

        # serialize the object -- this will populate self._externals
        # TODO: calling self.json_serialize again here might be slow
        json_data: JSONitem = self.json_serialize(self.json_serialize(obj))

        # open the zip file
        zipf: zipfile.ZipFile = zipfile.ZipFile(
            file=file_path, mode="w", compression=self.compress
        )

        # store base json data and metadata
        zipf.writestr(
            ZANJ_META,
            json.dumps(
                self.json_serialize(self.meta()),
                indent="\t",
            ),
        )
        zipf.writestr(
            ZANJ_MAIN,
            json.dumps(
                json_data,
                indent="\t",
            ),
        )

        # store externals
        for key, (ext_type, ext_data, ext_path) in self._externals.items():
            # why force zip64? numpy.savez does it
            with zipf.open(key, "w", force_zip64=True) as fp:
                EXTERNAL_STORE_FUNCS[ext_type](self, fp, ext_data)

        zipf.close()

        # clear the externals, again
        self._externals = dict()

        return file_path

    def read(
        self,
        file_path: Union[str, Path],
    ) -> Any:
        """load the object from a ZANJ archive
        # TODO: load only some part of the zanj file by passing an ObjectPath
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"file not found: {file_path}")
        if not file_path.is_file():
            raise FileNotFoundError(f"not a file: {file_path}")

        loaded_zanj: LoadedZANJ = LoadedZANJ(
            path=file_path,
            zanj=self,
        )

        loaded_zanj.populate_externals()

        return load_item_recursive(
            loaded_zanj._json_data,
            path=tuple(),
            zanj=self,
            error_mode=self.error_mode,
            # lh_map=loader_handlers,
        )
```
Zip up: Arrays in Numpy, JSON for everything else
given an arbitrary object, throw into a zip file, with arrays stored in .npy files, and everything else stored in a json file
(basically npz file with json)
- numpy (or pytorch) arrays are stored in paths according to their name and structure in the object
- everything else about the object is stored in a json file __zanj__.json in the root of the archive, via muutils.json_serialize.JsonSerializer
- metadata about ZANJ configuration, and optionally packages and versions, is stored in a __zanj_meta__.json file in the root of the archive

create a ZANJ instance via zj = ZANJ(), and save/read objects via zj.save(obj, path) and zj.read(path). Be sure to pass an instance of the object, to make sure that the attributes of the class can be correctly recognized.
def externals_info(self) -> dict[str, dict[str, str | int | list[int]]]:
return information about the current externals
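externals_info builds one summary dict per external item and orders the result by the length of each item's object path, so shallower entries come first. A stdlib-only sketch of that logic, assuming path is a tuple of keys; `ExternalItemSketch` and `externals_summary` are hypothetical, and the ndarray metadata branch (arr_metadata) is omitted.

```python
from __future__ import annotations

from typing import Any, NamedTuple


class ExternalItemSketch(NamedTuple):
    """Hypothetical stand-in for ExternalItem: (item_type, data, path), as save() unpacks it."""

    item_type: str
    data: Any
    path: tuple


def externals_summary(externals: dict[str, ExternalItemSketch]) -> dict[str, dict]:
    output = {}
    for key, item in externals.items():
        output[key] = {
            "item_type": item.item_type,
            "path": item.path,
            "len(data)": len(item.data),
        }
        # show the first row of non-empty jsonl externals as a preview
        if item.item_type.startswith("jsonl") and len(item.data) > 0:
            output[key]["data[0]"] = item.data[0]
    # shallower object paths first
    return dict(sorted(output.items(), key=lambda kv: len(kv[1]["path"])))


externals = {
    "a/b/c.jsonl": ExternalItemSketch("jsonl", [{"x": 1}], ("a", "b", "c")),
    "a.npy": ExternalItemSketch("ndarray", [0.0, 0.0, 0.0], ("a",)),
}
summary = externals_summary(externals)
```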
def meta(self) -> JSONitem:
return the metadata of the ZANJ archive
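meta returns three groups of information: the ZANJ configuration, system info, and a timestamp, which save then writes to __zanj_meta__.json. The reduced sketch below has the same overall shape. `meta_sketch` is hypothetical, the platform/sys calls stand in for muutils' SysInfo, handler info and externals_info are left out, and the threshold values in the usage line are arbitrary, not ZANJ's defaults.

```python
import json
import platform
import sys
import time
import zipfile


def meta_sketch(compress: int, external_array_threshold: int, external_list_threshold: int) -> dict:
    """Hypothetical reduced version of ZANJ.meta()."""
    return dict(
        # configuration of this writer
        zanj_cfg=dict(
            compress=compress,
            external_array_threshold=external_array_threshold,
            external_list_threshold=external_list_threshold,
        ),
        # stdlib stand-in for SysInfo.get_all(...)
        sysinfo=dict(python=sys.version, platform=platform.platform()),
        timestamp=time.time(),
    )


m = meta_sketch(zipfile.ZIP_DEFLATED, 256, 256)  # arbitrary example thresholds
text = json.dumps(m, indent="\t")  # tab-indented, as in save()
```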
def save(self, obj: Any, file_path: str | Path) -> str:
save the object to a ZANJ archive. returns the path to the archive
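Before writing anything, save forces a .zanj extension and creates the parent directory, then returns the adjusted path. A standalone sketch of just those steps follows (`normalize_zanj_path` is a hypothetical name). It uses makedirs(..., exist_ok=True) in place of the exists check plus exist_ok=False, which has the same effect but avoids a race if another process creates the directory in between.

```python
from __future__ import annotations

import os
import tempfile
from pathlib import Path


def normalize_zanj_path(file_path: str | Path) -> str:
    # force the .zanj extension (note: "model.json" becomes "model.json.zanj")
    file_path = str(file_path)
    if not file_path.endswith(".zanj"):
        file_path += ".zanj"
    # create the parent directory if the path has one
    dir_path = os.path.dirname(file_path)
    if dir_path != "":
        os.makedirs(dir_path, exist_ok=True)
    return file_path


with tempfile.TemporaryDirectory() as tmp:
    out = normalize_zanj_path(Path(tmp) / "runs" / "model")
    out_again = normalize_zanj_path(out)  # idempotent once the extension is present
    made_dir = os.path.isdir(os.path.join(tmp, "runs"))
```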
def read(self, file_path: Union[str, Path]) -> Any:
load the object from a ZANJ archive
TODO: load only some part of the zanj file by passing an ObjectPath
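read first validates the path (raising FileNotFoundError for both a missing path and a non-file) and only then opens the archive. The sketch below copies those checks and loads just the main JSON member. `read_main_json` is hypothetical, the member name __zanj__.json is taken from the format description above, and unlike ZANJ.read it does not populate externals or reconstruct objects.

```python
from __future__ import annotations

import json
import tempfile
import zipfile
from pathlib import Path


def read_main_json(file_path: str | Path):
    # same existence checks as ZANJ.read
    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError(f"file not found: {file_path}")
    if not file_path.is_file():
        raise FileNotFoundError(f"not a file: {file_path}")
    # externals are NOT resolved here; "$ref" entries stay as plain dicts
    with zipfile.ZipFile(file_path) as zf:
        return json.loads(zf.read("__zanj__.json"))


with tempfile.TemporaryDirectory() as tmp:
    archive = Path(tmp) / "example.zanj"
    with zipfile.ZipFile(archive, mode="w") as zf:
        zf.writestr("__zanj__.json", json.dumps({"a": 1}))
    main = read_main_json(archive)
    try:
        read_main_json(tmp)  # a directory, not a file
        dir_rejected = False
    except FileNotFoundError:
        dir_rejected = True
```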
Inherited Members
- muutils.json_serialize.json_serialize.JsonSerializer
- array_mode
- error_mode
- write_only_format
- handlers
- json_serialize
- hashify