docs for muutils v0.9.1
View Source on GitHub

muutils.json_serialize.array

This utility module handles serialization of numpy and torch arrays to JSON, and loading them back.

  • array_list_meta is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
  • array_b64_meta is the most efficient, but is not human readable.
  • external is mostly for use in ZANJ

  1"""this utilities module handles serialization and loading of numpy and torch arrays as json
  2
  3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
  4- `array_b64_meta` is the most efficient, but is not human readable.
  5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ)
  6
  7"""
  8
  9from __future__ import annotations
 10
 11import base64
 12import typing
 13import warnings
 14from typing import (
 15    TYPE_CHECKING,
 16    Any,
 17    Iterable,
 18    Literal,
 19    Optional,
 20    Sequence,
 21    TypedDict,
 22    Union,
 23    overload,
 24)
 25
 26try:
 27    import numpy as np
 28except ImportError as e:
 29    warnings.warn(
 30        f"numpy is not installed, array serialization will not work: \n{e}",
 31        ImportWarning,
 32    )
 33
 34if TYPE_CHECKING:
 35    import numpy as np
 36    import torch
 37    from muutils.json_serialize.json_serialize import JsonSerializer
 38
 39from muutils.json_serialize.types import _FORMAT_KEY  # pyright: ignore[reportPrivateUsage]
 40
 41# TYPING: pyright complains way too much here
 42# pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false
 43
 44# Recursive type for nested numeric lists (output of arr.tolist())
 45NumericList = typing.Union[
 46    typing.List[typing.Union[int, float, bool]],
 47    typing.List["NumericList"],
 48]
 49
 50ArrayMode = Literal[
 51    "list",
 52    "array_list_meta",
 53    "array_hex_meta",
 54    "array_b64_meta",
 55    "external",
 56    "zero_dim",
 57]
 58
 59# Modes that produce SerializedArrayWithMeta (dict with metadata)
 60ArrayModeWithMeta = Literal[
 61    "array_list_meta",
 62    "array_hex_meta",
 63    "array_b64_meta",
 64    "zero_dim",
 65    "external",
 66]
 67
 68
 69def array_n_elements(arr: Any) -> int:  # type: ignore[name-defined]  # pyright: ignore[reportAny]
 70    """get the number of elements in an array"""
 71    if isinstance(arr, np.ndarray):
 72        return arr.size
 73    elif str(type(arr)) == "<class 'torch.Tensor'>":  # pyright: ignore[reportUnknownArgumentType, reportAny]
 74        assert hasattr(arr, "nelement"), (
 75            "torch Tensor does not have nelement() method? this should not happen"
 76        )  # pyright: ignore[reportAny]
 77        return arr.nelement()  # pyright: ignore[reportAny]
 78    else:
 79        raise TypeError(f"invalid type: {type(arr)}")  # pyright: ignore[reportAny]
 80
 81
class ArrayMetadata(TypedDict):
    """Metadata for a numpy/torch array

    this is the dict shape returned by `arr_metadata()`
    """

    # `list(arr.shape)`, one entry per dimension (empty for a zero-dim array)
    shape: list[int]
    # `arr.dtype.__name__` when the dtype has one, otherwise `str(arr.dtype)`
    dtype: str
    # total number of elements, as computed by `array_n_elements()`
    n_elements: int
 88
 89
class SerializedArrayWithMeta(TypedDict):
    """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)

    `serialize_array()` builds this dict for every mode except `list`.
    the `load_array()` overloads also accept this type for the `external` mode.
    """

    # format tag, written by `serialize_array()` as "<module>.<type name>:<mode>"
    __muutils_format__: str
    data: typing.Union[
        NumericList, str, int, float, bool
    ]  # list, hex str, b64 str, or scalar for zero_dim
    # the remaining keys mirror `ArrayMetadata`
    shape: list[int]
    dtype: str
    n_elements: int
100
101
def arr_metadata(arr: Any) -> ArrayMetadata:  # pyright: ignore[reportAny]
    """collect the shape, dtype name and element count of an array

    # Parameters:
     - `arr : Any`
       object exposing `.shape` and `.dtype`, and accepted by `array_n_elements()`

    # Returns:
     - `ArrayMetadata`
       dict with the keys `shape`, `dtype` and `n_elements`
    """
    dims = list(arr.shape)  # pyright: ignore[reportAny]

    # prefer the dtype's own `__name__` when present, fall back to its string form
    if hasattr(arr.dtype, "__name__"):
        dtype_label = arr.dtype.__name__  # pyright: ignore[reportAny]
    else:
        dtype_label = str(arr.dtype)  # pyright: ignore[reportAny]

    count = array_n_elements(arr)
    return {"shape": dims, "dtype": dtype_label, "n_elements": count}
111
112
@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: Literal["list"],
) -> NumericList: ...
@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: ArrayModeWithMeta,
) -> SerializedArrayWithMeta: ...
@overload
def serialize_array(
    jser: "JsonSerializer",
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],
    array_mode: None = None,
) -> SerializedArrayWithMeta | NumericList: ...
def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],  # pyright: ignore[reportUnusedParameter]
    array_mode: ArrayMode | None = None,
) -> SerializedArrayWithMeta | NumericList:
    """serialize a numpy or pytorch array in one of several modes

    `array_mode: ArrayMode` can be one of:
    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`

    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
    ```
    {
        _FORMAT_KEY: "<module>.<type name of arr>:<array_list_meta|array_hex_meta|array_b64_meta>",
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
        "shape": list(arr.shape),
        "dtype": <name of the dtype>,
        "n_elements": <number of elements>,
    }
    ```

    if the array is zero-dimensional and the mode is anything other than `list`,
    the requested mode is ignored: the result uses the `zero_dim` format tag and
    stores `arr.item()` under `data`.

    # Parameters:
     - `jser : JsonSerializer`
       only consulted for `jser.array_mode` when `array_mode` is `None`
     - `arr : Union[np.ndarray, torch.Tensor]`
       array to serialize; anything that is not an `np.ndarray` goes through `np.array(arr)`
     - `path : str | Sequence[str | int]`
       unused by this function
     - `array_mode : ArrayMode | None`
       mode in which to serialize the array
       (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `SerializedArrayWithMeta | NumericList`
       nested list for `list` mode, dict with metadata otherwise

    # Raises:
     - `KeyError` : if the array is not zero-dimensional and the array mode is not
       one of `list`, `array_list_meta`, `array_hex_meta`, `array_b64_meta`
    """

    if array_mode is None:
        array_mode = jser.array_mode

    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"
    arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)  # pyright: ignore[reportUnnecessaryIsInstance]

    # Handle list mode first (no metadata needed)
    if array_mode == "list":
        return arr_np.tolist()  # pyright: ignore[reportAny]

    # For all other modes, compute metadata once.
    # zero-dim input is described from the original object, everything else
    # from its numpy view
    is_zero_dim: bool = len(arr.shape) == 0
    metadata: ArrayMetadata = arr_metadata(arr if is_zero_dim else arr_np)

    # zero-dimensional arrays take precedence over the requested mode
    mode_tag: str
    data: Any
    if is_zero_dim:
        mode_tag, data = "zero_dim", arr.item()  # pyright: ignore[reportAny]
    elif array_mode == "array_list_meta":
        mode_tag, data = "array_list_meta", arr_np.tolist()
    elif array_mode == "array_hex_meta":
        mode_tag, data = "array_hex_meta", arr_np.tobytes().hex()
    elif array_mode == "array_b64_meta":
        mode_tag, data = "array_b64_meta", base64.b64encode(arr_np.tobytes()).decode()
    else:
        raise KeyError(f"invalid array_mode: {array_mode}")

    # TYPING: ty<=0.0.1a24 does not appear to support unpacking TypedDicts, so we do things manually. change it back later maybe?
    return SerializedArrayWithMeta(
        __muutils_format__=f"{arr_type}:{mode_tag}",
        data=data,
        shape=metadata["shape"],
        dtype=metadata["dtype"],
        n_elements=metadata["n_elements"],
    )
225
226
@overload
def infer_array_mode(
    arr: SerializedArrayWithMeta,
) -> ArrayModeWithMeta: ...
@overload
def infer_array_mode(arr: NumericList) -> Literal["list"]: ...
def infer_array_mode(
    arr: Union[SerializedArrayWithMeta, NumericList],
) -> ArrayMode:
    """given a serialized array, infer the mode

    assumes the array was serialized via `serialize_array()`

    a mapping is classified by the suffix of the string stored under
    `_FORMAT_KEY`; a plain `list` is always mode `list`.

    # Parameters:
     - `arr : Union[SerializedArrayWithMeta, NumericList]`
       serialized array

    # Returns:
     - `ArrayMode`
       the inferred mode

    # Raises:
     - `ValueError` : if `arr` is neither a mapping nor a list, if the format
       value is not a string or has no known suffix, or if `data` has the wrong
       type for the detected mode
     - `KeyError` : if a list/hex/b64 format is detected but `data` is missing
    """
    if isinstance(arr, typing.Mapping):
        fmt = arr.get(_FORMAT_KEY, "")  # type: ignore
        # a non-string format value cannot match any suffix; report it like any
        # other bad format instead of failing on `.endswith`
        if not isinstance(fmt, str):
            raise ValueError(f"invalid format: {arr}")

        if fmt.endswith(":array_list_meta"):
            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
            if not isinstance(arr_data, Iterable):
                raise ValueError(f"invalid list format: {type(arr_data) = }\t{arr}")
            return "array_list_meta"

        if fmt.endswith(":array_hex_meta"):
            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
            if not isinstance(arr_data, str):
                raise ValueError(f"invalid hex format: {type(arr_data) = }\t{arr}")
            return "array_hex_meta"

        if fmt.endswith(":array_b64_meta"):
            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
            if not isinstance(arr_data, str):
                raise ValueError(f"invalid b64 format: {type(arr_data) = }\t{arr}")
            return "array_b64_meta"

        # these two modes carry no constraint on `data` that is checked here
        if fmt.endswith(":external"):
            return "external"
        if fmt.endswith(":zero_dim"):
            return "zero_dim"

        raise ValueError(f"invalid format: {arr}")

    if isinstance(arr, list):  # pyright: ignore[reportUnnecessaryIsInstance]
        return "list"

    raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")  # pyright: ignore[reportUnreachable]
271
272
@overload
def load_array(
    arr: SerializedArrayWithMeta,
    array_mode: Optional[ArrayModeWithMeta] = None,
) -> np.ndarray: ...
@overload
def load_array(
    arr: NumericList,
    array_mode: Optional[Literal["list"]] = None,
) -> np.ndarray: ...
@overload
def load_array(
    arr: np.ndarray,
    array_mode: None = None,
) -> np.ndarray: ...
def load_array(
    arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList],
    array_mode: Optional[ArrayMode] = None,
) -> np.ndarray:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr : Union[SerializedArrayWithMeta, np.ndarray, NumericList]`
       serialized array; an `np.ndarray` is returned unchanged
     - `array_mode : Optional[ArrayMode]`
       mode to load with. if given and different from the mode inferred by
       `infer_array_mode()`, a warning is emitted and the given mode is used
       (defaults to `None`, meaning the inferred mode)

    # Returns:
     - `np.ndarray`
       the loaded array. for `external` mode, `arr["data"]` is returned as-is

    # Raises:
     - `ValueError` : if the stored `shape` does not match the loaded data
       (list and zero_dim modes), or if the mode is not recognized
     - `KeyError` : if an `external` array has no `data` key
     - `AssertionError` : if `arr` is a numpy array and `array_mode` is given,
       or if `arr` has the wrong container type for the mode
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray):
        assert array_mode is None, (
            "array_mode should not be specified when loading a numpy array, since that is a no-op"
        )
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid hex format: {type(arr) = }\n{arr = }"
        )
        raw_hex: bytes = bytes.fromhex(arr["data"])  # type: ignore
        # `np.frombuffer` over immutable bytes yields a read-only view;
        # copy so the result is writeable like the one from list mode
        return np.frombuffer(raw_hex, dtype=arr["dtype"]).reshape(arr["shape"]).copy()  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid b64 format: {type(arr) = }\n{arr = }"
        )
        raw_b64: bytes = base64.b64decode(arr["data"])  # type: ignore
        # same read-only concern as for the hex payload above
        return np.frombuffer(raw_b64, dtype=arr["dtype"]).reshape(arr["shape"]).copy()  # type: ignore

    elif array_mode == "list":
        assert isinstance(arr, typing.Sequence), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        return np.array(arr)  # type: ignore

    elif array_mode == "external":
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(  # pyright: ignore[reportUnreachable]
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        # we can ignore here since we assume ZANJ has taken care of it
        return arr["data"]  # type: ignore[return-value] # pyright: ignore[reportReturnType]

    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        # NOTE: the stored `dtype` is not applied here, numpy picks one from the scalar
        data = np.array(arr["data"])  # ty: ignore[invalid-argument-type]
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    else:
        raise ValueError(f"invalid array_mode: {array_mode}")  # pyright: ignore[reportUnreachable]

NumericList = typing.Union[typing.List[typing.Union[int, float, bool]], typing.List[ForwardRef('NumericList')]]
ArrayMode = typing.Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
ArrayModeWithMeta = typing.Literal['array_list_meta', 'array_hex_meta', 'array_b64_meta', 'zero_dim', 'external']
def array_n_elements(arr: Any) -> int:
70def array_n_elements(arr: Any) -> int:  # type: ignore[name-defined]  # pyright: ignore[reportAny]
71    """get the number of elements in an array"""
72    if isinstance(arr, np.ndarray):
73        return arr.size
74    elif str(type(arr)) == "<class 'torch.Tensor'>":  # pyright: ignore[reportUnknownArgumentType, reportAny]
75        assert hasattr(arr, "nelement"), (
76            "torch Tensor does not have nelement() method? this should not happen"
77        )  # pyright: ignore[reportAny]
78        return arr.nelement()  # pyright: ignore[reportAny]
79    else:
80        raise TypeError(f"invalid type: {type(arr)}")  # pyright: ignore[reportAny]

get the number of elements in an array

class ArrayMetadata(typing.TypedDict):
83class ArrayMetadata(TypedDict):
84    """Metadata for a numpy/torch array"""
85
86    shape: list[int]
87    dtype: str
88    n_elements: int

Metadata for a numpy/torch array

shape: list[int]
dtype: str
n_elements: int
Inherited Members
builtins.dict
get
setdefault
pop
popitem
keys
items
values
update
fromkeys
clear
copy
class SerializedArrayWithMeta(typing.TypedDict):
 91class SerializedArrayWithMeta(TypedDict):
 92    """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)"""
 93
 94    __muutils_format__: str
 95    data: typing.Union[
 96        NumericList, str, int, float, bool
 97    ]  # list, hex str, b64 str, or scalar for zero_dim
 98    shape: list[int]
 99    dtype: str
100    n_elements: int

Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)

data: Union[List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]], str, int, float, bool]
shape: list[int]
dtype: str
n_elements: int
Inherited Members
builtins.dict
get
setdefault
pop
popitem
keys
items
values
update
fromkeys
clear
copy
def arr_metadata(arr: Any) -> ArrayMetadata:
103def arr_metadata(arr: Any) -> ArrayMetadata:  # pyright: ignore[reportAny]
104    """get metadata for a numpy array"""
105    return {
106        "shape": list(arr.shape),  # pyright: ignore[reportAny]
107        "dtype": (
108            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)  # pyright: ignore[reportAny]
109        ),
110        "n_elements": array_n_elements(arr),
111    }

get metadata for a numpy array

def serialize_array( jser: muutils.json_serialize.JsonSerializer, arr: Union[numpy.ndarray, torch.Tensor], path: Union[str, Sequence[str | int]], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> Union[SerializedArrayWithMeta, List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]]]:
135def serialize_array(
136    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
137    arr: "Union[np.ndarray, torch.Tensor]",
138    path: str | Sequence[str | int],  # pyright: ignore[reportUnusedParameter]
139    array_mode: ArrayMode | None = None,
140) -> SerializedArrayWithMeta | NumericList:
141    """serialize a numpy or pytorch array in one of several modes
142
143    if the object is zero-dimensional, simply get the unique item
144
145    `array_mode: ArrayMode` can be one of:
146    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
147    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
148    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
149    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`
150
151    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
152    ```
153    {
154        _FORMAT_KEY: <array_list_meta|array_hex_meta>,
155        "shape": arr.shape,
156        "dtype": str(arr.dtype),
157        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
158    }
159    ```
160
161    # Parameters:
162     - `arr : Any` array to serialize
163     - `array_mode : ArrayMode` mode in which to serialize the array
164       (defaults to `None` and inheriting from `jser: JsonSerializer`)
165
166    # Returns:
167     - `JSONitem`
168       json serialized array
169
170    # Raises:
171     - `KeyError` : if the array mode is not valid
172    """
173
174    if array_mode is None:
175        array_mode = jser.array_mode
176
177    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"
178    arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)  # pyright: ignore[reportUnnecessaryIsInstance]
179
180    # Handle list mode first (no metadata needed)
181    if array_mode == "list":
182        return arr_np.tolist()  # pyright: ignore[reportAny]
183
184    # For all other modes, compute metadata once
185    metadata: ArrayMetadata = arr_metadata(arr if len(arr.shape) == 0 else arr_np)
186
187    # TYPING: ty<=0.0.1a24 does not appear to support unpacking TypedDicts, so we do things manually. change it back later maybe?
188
189    # handle zero-dimensional arrays
190    if len(arr.shape) == 0:
191        return SerializedArrayWithMeta(
192            __muutils_format__=f"{arr_type}:zero_dim",
193            data=arr.item(),  # pyright: ignore[reportAny]
194            shape=metadata["shape"],
195            dtype=metadata["dtype"],
196            n_elements=metadata["n_elements"],
197        )
198
199    # Handle the metadata modes
200    if array_mode == "array_list_meta":
201        return SerializedArrayWithMeta(
202            __muutils_format__=f"{arr_type}:array_list_meta",
203            data=arr_np.tolist(),  # pyright: ignore[reportAny]
204            shape=metadata["shape"],
205            dtype=metadata["dtype"],
206            n_elements=metadata["n_elements"],
207        )
208    elif array_mode == "array_hex_meta":
209        return SerializedArrayWithMeta(
210            __muutils_format__=f"{arr_type}:array_hex_meta",
211            data=arr_np.tobytes().hex(),
212            shape=metadata["shape"],
213            dtype=metadata["dtype"],
214            n_elements=metadata["n_elements"],
215        )
216    elif array_mode == "array_b64_meta":
217        return SerializedArrayWithMeta(
218            __muutils_format__=f"{arr_type}:array_b64_meta",
219            data=base64.b64encode(arr_np.tobytes()).decode(),
220            shape=metadata["shape"],
221            dtype=metadata["dtype"],
222            n_elements=metadata["n_elements"],
223        )
224    else:
225        raise KeyError(f"invalid array_mode: {array_mode}")

serialize a numpy or pytorch array in one of several modes

if the object is zero-dimensional, simply get the unique item

array_mode: ArrayMode can be one of:

  • list: serialize as a list of values, no metadata (equivalent to arr.tolist())
  • array_list_meta: serialize dict with metadata, actual list under the key data
  • array_hex_meta: serialize dict with metadata, actual hex string under the key data
  • array_b64_meta: serialize dict with metadata, actual base64 string under the key data

for array_list_meta, array_hex_meta, and array_b64_meta, the serialized object is:

{
    _FORMAT_KEY: "<module>.<type name of arr>:<array_list_meta|array_hex_meta|array_b64_meta>",
    "shape": arr.shape,
    "dtype": str(arr.dtype),
    "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
}

Parameters:

  • arr : Any array to serialize
  • array_mode : ArrayMode mode in which to serialize the array (defaults to None and inheriting from jser: JsonSerializer)

Returns:

  • JSONitem json serialized array

Raises:

  • KeyError : if the array mode is not valid
def infer_array_mode( arr: Union[SerializedArrayWithMeta, List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]]]) -> Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']:
234def infer_array_mode(
235    arr: Union[SerializedArrayWithMeta, NumericList],
236) -> ArrayMode:
237    """given a serialized array, infer the mode
238
239    assumes the array was serialized via `serialize_array()`
240    """
241    return_mode: ArrayMode
242    if isinstance(arr, typing.Mapping):
243        # _FORMAT_KEY always maps to a string
244        fmt: str = arr.get(_FORMAT_KEY, "")  # type: ignore
245        if fmt.endswith(":array_list_meta"):
246            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
247            if not isinstance(arr_data, Iterable):
248                raise ValueError(f"invalid list format: {type(arr_data) = }\t{arr}")
249            return_mode = "array_list_meta"
250        elif fmt.endswith(":array_hex_meta"):
251            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
252            if not isinstance(arr_data, str):
253                raise ValueError(f"invalid hex format: {type(arr_data) = }\t{arr}")
254            return_mode = "array_hex_meta"
255        elif fmt.endswith(":array_b64_meta"):
256            arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
257            if not isinstance(arr_data, str):
258                raise ValueError(f"invalid b64 format: {type(arr_data) = }\t{arr}")
259            return_mode = "array_b64_meta"
260        elif fmt.endswith(":external"):
261            return_mode = "external"
262        elif fmt.endswith(":zero_dim"):
263            return_mode = "zero_dim"
264        else:
265            raise ValueError(f"invalid format: {arr}")
266    elif isinstance(arr, list):  # pyright: ignore[reportUnnecessaryIsInstance]
267        return_mode = "list"
268    else:
269        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")  # pyright: ignore[reportUnreachable]
270
271    return return_mode

given a serialized array, infer the mode

assumes the array was serialized via serialize_array()

def load_array( arr: Union[SerializedArrayWithMeta, numpy.ndarray, List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]]], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> numpy.ndarray:
289def load_array(
290    arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList],
291    array_mode: Optional[ArrayMode] = None,
292) -> np.ndarray:
293    """load a json-serialized array, infer the mode if not specified"""
294    # return arr if its already a numpy array
295    if isinstance(arr, np.ndarray):
296        assert array_mode is None, (
297            "array_mode should not be specified when loading a numpy array, since that is a no-op"
298        )
299        return arr
300
301    # try to infer the array_mode
302    array_mode_inferred: ArrayMode = infer_array_mode(arr)
303    if array_mode is None:
304        array_mode = array_mode_inferred
305    elif array_mode != array_mode_inferred:
306        warnings.warn(
307            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
308        )
309
310    # actually load the array
311    if array_mode == "array_list_meta":
312        assert isinstance(arr, typing.Mapping), (
313            f"invalid list format: {type(arr) = }\n{arr = }"
314        )
315        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
316        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
317            raise ValueError(f"invalid shape: {arr}")
318        return data
319
320    elif array_mode == "array_hex_meta":
321        assert isinstance(arr, typing.Mapping), (
322            f"invalid list format: {type(arr) = }\n{arr = }"
323        )
324        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
325        return data.reshape(arr["shape"])  # type: ignore
326
327    elif array_mode == "array_b64_meta":
328        assert isinstance(arr, typing.Mapping), (
329            f"invalid list format: {type(arr) = }\n{arr = }"
330        )
331        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
332        return data.reshape(arr["shape"])  # type: ignore
333
334    elif array_mode == "list":
335        assert isinstance(arr, typing.Sequence), (
336            f"invalid list format: {type(arr) = }\n{arr = }"
337        )
338        return np.array(arr)  # type: ignore
339    elif array_mode == "external":
340        assert isinstance(arr, typing.Mapping)
341        if "data" not in arr:
342            raise KeyError(  # pyright: ignore[reportUnreachable]
343                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
344            )
345        # we can ignore here since we assume ZANJ has taken care of it
346        return arr["data"]  # type: ignore[return-value] # pyright: ignore[reportReturnType]
347    elif array_mode == "zero_dim":
348        assert isinstance(arr, typing.Mapping)
349        data = np.array(arr["data"])  # ty: ignore[invalid-argument-type]
350        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
351            raise ValueError(f"invalid shape: {arr}")
352        return data
353    else:
354        raise ValueError(f"invalid array_mode: {array_mode}")  # pyright: ignore[reportUnreachable]

load a json-serialized array, infer the mode if not specified