Coverage for muutils\json_serialize\array.py: 64%
92 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-15 20:56 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-15 20:56 -0600
1"""this utilities module handles serialization and loading of numpy and torch arrays as json
3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
4- `array_b64_meta` is the most efficient, but is not human readable.
5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ)
7"""
9from __future__ import annotations
11import base64
12import typing
13import warnings
14from typing import Any, Iterable, Literal, Optional, Sequence
16import numpy as np
18from muutils.json_serialize.util import JSONitem
20# pylint: disable=unused-argument
# every array serialization mode known to this module
ArrayMode = Literal[
    # modes that `serialize_array` can be asked to produce
    "list", "array_list_meta", "array_hex_meta", "array_b64_meta",
    # modes that `serialize_array` does not accept as `array_mode` (it raises
    # `KeyError` for them); `zero_dim` is chosen automatically for 0-d arrays
    "external", "zero_dim",
]
def array_n_elements(arr) -> int:  # type: ignore[name-defined]
    """get the number of elements in an array

    supports numpy arrays and torch tensors, including subclasses of either
    (e.g. `torch.nn.Parameter`). torch is recognized by class name rather than
    by `isinstance`, so this module never has to import torch.

    # Parameters:
     - `arr : np.ndarray | torch.Tensor`
        array whose elements to count

    # Returns:
     - `int`
        total number of elements (`arr.size` for numpy, `arr.nelement()` for torch)

    # Raises:
     - `TypeError` : if `arr` is neither a numpy array nor a torch tensor
    """
    if isinstance(arr, np.ndarray):
        return arr.size
    # look for `torch.Tensor` anywhere in the MRO, not only as the exact type,
    # so tensor subclasses are accepted too
    if any(
        f"{cls.__module__}.{cls.__qualname__}" == "torch.Tensor"
        for cls in type(arr).__mro__
    ):
        return arr.nelement()
    raise TypeError(f"invalid type: {type(arr)}")
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """build the metadata dict that is stored next to a serialized array

    accepts whatever `array_n_elements` accepts (numpy arrays and torch tensors).

    # Returns:
     - `dict[str, list[int] | str | int]` with these keys, in this order:
        - `"shape"`: `list(arr.shape)`
        - `"dtype"`: `arr.dtype.__name__` if the dtype object has a `__name__`
          attribute, otherwise `str(arr.dtype)`
        - `"n_elements"`: the result of `array_n_elements(arr)`
    """
    shape: list[int] = list(arr.shape)

    dtype = arr.dtype
    # prefer a `__name__` when the dtype object provides one; otherwise fall
    # back to its string form
    if hasattr(dtype, "__name__"):
        dtype_label = dtype.__name__
    else:
        dtype_label = str(dtype)

    return dict(shape=shape, dtype=dtype_label, n_elements=array_n_elements(arr))
def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: np.ndarray,
    path: str | Sequence[str | int],
    array_mode: ArrayMode | None = None,
) -> JSONitem:
    """serialize a numpy or pytorch array in one of several modes

    if the object is zero-dimensional, the result is always a dict holding the
    single item under `data` plus metadata (format tag `<type>:zero_dim`),
    regardless of `array_mode`

    `array_mode: ArrayMode` can be one of:
    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`

    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
    ```
    {
        "__format__": "<module>.<type name>:<array_list_meta|array_hex_meta|array_b64_meta>",
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
        "shape": list(arr.shape),
        "dtype": <dtype name, as produced by `arr_metadata`>,
        "n_elements": <number of elements>,
    }
    ```
    where `arr` here is the input converted to a numpy array, and
    `<module>.<type name>` is taken from the *original* object
    (e.g. `numpy.ndarray` or `torch.Tensor`).

    # Parameters:
     - `jser : JsonSerializer`
        serializer whose `array_mode` is used when `array_mode` is `None`
     - `arr : Any`
        array to serialize: a numpy array, a torch tensor, or an array-like
        that `np.array` can convert (e.g. nested lists)
     - `path : str | Sequence[str | int]`
        not used by this function (presumably kept so the signature matches
        other serializer handlers -- confirm against callers)
     - `array_mode : ArrayMode | None`
        mode in which to serialize the array
        (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `JSONitem`
        json serialized array

    # Raises:
     - `KeyError` : if the array mode is not valid (including `external` and `zero_dim`)
    """
    if array_mode is None:
        array_mode = jser.array_mode

    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"
    arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)

    # handle zero-dimensional arrays. the dimension check uses the numpy
    # conversion so inputs without a `.shape` (nested lists, scalars) work too
    if arr_np.ndim == 0:
        # read item + metadata from the original object when it is array-like,
        # so a torch tensor keeps its torch dtype name; otherwise use numpy
        source = arr if hasattr(arr, "shape") else arr_np
        return {
            "__format__": f"{arr_type}:zero_dim",
            "data": source.item(),
            **arr_metadata(source),
        }

    if array_mode == "array_list_meta":
        return {
            "__format__": f"{arr_type}:array_list_meta",
            "data": arr_np.tolist(),
            **arr_metadata(arr_np),
        }
    elif array_mode == "list":
        return arr_np.tolist()
    elif array_mode == "array_hex_meta":
        return {
            "__format__": f"{arr_type}:array_hex_meta",
            "data": arr_np.tobytes().hex(),
            **arr_metadata(arr_np),
        }
    elif array_mode == "array_b64_meta":
        return {
            "__format__": f"{arr_type}:array_b64_meta",
            "data": base64.b64encode(arr_np.tobytes()).decode(),
            **arr_metadata(arr_np),
        }
    else:
        raise KeyError(f"invalid array_mode: {array_mode}")
def infer_array_mode(arr: JSONitem) -> ArrayMode:
    """given a serialized array, infer the mode

    assumes the array was serialized via `serialize_array()`

    - a plain `list` is `"list"` mode
    - a mapping is classified by the suffix of its `"__format__"` tag
      (`:array_list_meta`, `:array_hex_meta`, `:array_b64_meta`, `:external`,
      `:zero_dim`); for the first three, the type of `data` is also checked

    # Raises:
     - `ValueError` : if the format tag is unrecognized, `data` has the wrong
       type for its mode, or `arr` is neither a mapping nor a list
     - `KeyError` : if a mode that checks `data` is tagged but `data` is missing
    """
    if not isinstance(arr, typing.Mapping):
        # outside of mappings, only a bare list is recognized
        if isinstance(arr, list):
            return "list"
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")

    fmt: str = arr.get("__format__", "")

    # (mode, label used in the error message, required type of `data`)
    payload_checks: tuple[tuple[ArrayMode, str, Any], ...] = (
        ("array_list_meta", "list", Iterable),
        ("array_hex_meta", "hex", str),
        ("array_b64_meta", "b64", str),
    )
    for mode, label, payload_type in payload_checks:
        if not fmt.endswith(f":{mode}"):
            continue
        if not isinstance(arr["data"], payload_type):
            raise ValueError(
                f"invalid {label} format: {type(arr['data']) = }\t{arr}"
            )
        return mode

    # these modes are identified by their tag alone
    if fmt.endswith(":external"):
        return "external"
    if fmt.endswith(":zero_dim"):
        return "zero_dim"
    raise ValueError(f"invalid format: {arr}")
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr : JSONitem`
        serialized array, as produced by `serialize_array`. a numpy array is
        also accepted and returned unchanged when `array_mode` is `None`
     - `array_mode : ArrayMode | None`
        mode to load with; inferred via `infer_array_mode` when `None`. if it
        is given and differs from the inferred mode, a warning is emitted and
        the given mode is used anyway

    # Returns:
     - a numpy array for every mode except `external`, which returns
       `arr["data"]` as-is

    # Raises:
     - `ValueError` : if the mode cannot be inferred or is invalid, or the
       loaded data does not match the recorded `shape`
     - `KeyError` : for `external` mode when `arr` has no `data` key
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        expected_shape = tuple(arr["shape"])
        # nested lists cannot express axes that follow a zero-length axis
        # (e.g. shape (0, 3) serializes to `[]`), so restore the recorded
        # shape for empty arrays before comparing
        if data.size == 0 and 0 in expected_shape:
            data = data.reshape(expected_shape)
        if expected_shape != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid hex format: {type(arr) = }\n{arr = }"

        # decode into a mutable bytearray: `np.frombuffer` over immutable
        # `bytes` yields a read-only array, unlike the list-based modes
        data = np.frombuffer(bytearray.fromhex(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid b64 format: {type(arr) = }\n{arr = }"

        # same writability concern as the hex branch
        data = np.frombuffer(
            bytearray(base64.b64decode(arr["data"])), dtype=arr["dtype"]
        )
        return data.reshape(arr["shape"])

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        # NOTE: the recorded `dtype` is not applied here, so the result's dtype
        # is whatever numpy infers from the JSON scalar
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")