muutils.json_serialize.array
this utilities module handles serialization and loading of numpy and torch arrays as json
- array_list_meta is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
- array_b64_meta is the most efficient, but is not human readable.
- external is mostly for use in ZANJ
1"""this utilities module handles serialization and loading of numpy and torch arrays as json 2 3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability. 4- `array_b64_meta` is the most efficient, but is not human readable. 5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ) 6 7""" 8 9from __future__ import annotations 10 11import base64 12import typing 13import warnings 14from typing import ( 15 TYPE_CHECKING, 16 Any, 17 Iterable, 18 Literal, 19 Optional, 20 Sequence, 21 TypedDict, 22 Union, 23 overload, 24) 25 26try: 27 import numpy as np 28except ImportError as e: 29 warnings.warn( 30 f"numpy is not installed, array serialization will not work: \n{e}", 31 ImportWarning, 32 ) 33 34if TYPE_CHECKING: 35 import numpy as np 36 import torch 37 from muutils.json_serialize.json_serialize import JsonSerializer 38 39from muutils.json_serialize.types import _FORMAT_KEY # pyright: ignore[reportPrivateUsage] 40 41# TYPING: pyright complains way too much here 42# pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false 43 44# Recursive type for nested numeric lists (output of arr.tolist()) 45NumericList = typing.Union[ 46 typing.List[typing.Union[int, float, bool]], 47 typing.List["NumericList"], 48] 49 50ArrayMode = Literal[ 51 "list", 52 "array_list_meta", 53 "array_hex_meta", 54 "array_b64_meta", 55 "external", 56 "zero_dim", 57] 58 59# Modes that produce SerializedArrayWithMeta (dict with metadata) 60ArrayModeWithMeta = Literal[ 61 "array_list_meta", 62 "array_hex_meta", 63 "array_b64_meta", 64 "zero_dim", 65 "external", 66] 67 68 69def array_n_elements(arr: Any) -> int: # type: ignore[name-defined] # pyright: ignore[reportAny] 70 """get the number of elements in an array""" 71 if isinstance(arr, np.ndarray): 72 return arr.size 73 elif str(type(arr)) == "<class 'torch.Tensor'>": # pyright: 
ignore[reportUnknownArgumentType, reportAny] 74 assert hasattr(arr, "nelement"), ( 75 "torch Tensor does not have nelement() method? this should not happen" 76 ) # pyright: ignore[reportAny] 77 return arr.nelement() # pyright: ignore[reportAny] 78 else: 79 raise TypeError(f"invalid type: {type(arr)}") # pyright: ignore[reportAny] 80 81 82class ArrayMetadata(TypedDict): 83 """Metadata for a numpy/torch array""" 84 85 shape: list[int] 86 dtype: str 87 n_elements: int 88 89 90class SerializedArrayWithMeta(TypedDict): 91 """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)""" 92 93 __muutils_format__: str 94 data: typing.Union[ 95 NumericList, str, int, float, bool 96 ] # list, hex str, b64 str, or scalar for zero_dim 97 shape: list[int] 98 dtype: str 99 n_elements: int 100 101 102def arr_metadata(arr: Any) -> ArrayMetadata: # pyright: ignore[reportAny] 103 """get metadata for a numpy array""" 104 return { 105 "shape": list(arr.shape), # pyright: ignore[reportAny] 106 "dtype": ( 107 arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype) # pyright: ignore[reportAny] 108 ), 109 "n_elements": array_n_elements(arr), 110 } 111 112 113@overload 114def serialize_array( 115 jser: "JsonSerializer", 116 arr: "Union[np.ndarray, torch.Tensor]", 117 path: str | Sequence[str | int], 118 array_mode: Literal["list"], 119) -> NumericList: ... 120@overload 121def serialize_array( 122 jser: "JsonSerializer", 123 arr: "Union[np.ndarray, torch.Tensor]", 124 path: str | Sequence[str | int], 125 array_mode: ArrayModeWithMeta, 126) -> SerializedArrayWithMeta: ... 127@overload 128def serialize_array( 129 jser: "JsonSerializer", 130 arr: "Union[np.ndarray, torch.Tensor]", 131 path: str | Sequence[str | int], 132 array_mode: None = None, 133) -> SerializedArrayWithMeta | NumericList: ... 
def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],  # pyright: ignore[reportUnusedParameter]
    array_mode: ArrayMode | None = None,
) -> SerializedArrayWithMeta | NumericList:
    """serialize a numpy or pytorch array in one of several modes

    in `list` mode the result is always `arr.tolist()`. in every other mode, a
    zero-dimensional array is serialized as a `zero_dim` dict holding the
    scalar `arr.item()` under the key `data`, regardless of the requested mode.

    `array_mode: ArrayMode` can be one of:
    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`

    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
    ```
    {
        "__muutils_format__": "<module>.<type name of arr>:<array_list_meta|array_hex_meta|array_b64_meta>",
        "shape": list(arr.shape),
        "dtype": str(arr.dtype),
        "n_elements": <number of elements>,
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
    }
    ```

    # Parameters:
     - `jser : JsonSerializer` only read (for `jser.array_mode`) when `array_mode` is `None`
     - `arr : Any` array to serialize
     - `path : str | Sequence[str | int]` unused by this function
     - `array_mode : ArrayMode` mode in which to serialize the array
       (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `JSONitem`
       json serialized array

    # Raises:
     - `KeyError` : if the array is not zero-dimensional and the array mode is
       not one of the four modes listed above
    """

    if array_mode is None:
        array_mode = jser.array_mode

    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"

    arr_np: np.ndarray
    if isinstance(arr, np.ndarray):  # pyright: ignore[reportUnnecessaryIsInstance]
        arr_np = arr
    else:
        # torch tensors that track gradients or live on an accelerator cannot
        # be handed to numpy directly; drop the autograd graph and move the
        # data to host memory first. detected by duck typing so torch is never
        # imported here; objects without these methods are converted as before.
        host_arr: Any = arr
        if callable(getattr(host_arr, "detach", None)):
            host_arr = host_arr.detach()
        if callable(getattr(host_arr, "cpu", None)):
            host_arr = host_arr.cpu()
        arr_np = np.array(host_arr)

    # Handle list mode first (no metadata needed)
    if array_mode == "list":
        return arr_np.tolist()  # pyright: ignore[reportAny]

    # For all other modes, compute metadata once
    # (zero-dim arrays are described from the original object, not the numpy copy)
    metadata: ArrayMetadata = arr_metadata(arr if len(arr.shape) == 0 else arr_np)

    # TYPING: ty<=0.0.1a24 does not appear to support unpacking TypedDicts, so we do things manually. change it back later maybe?

    # handle zero-dimensional arrays
    if len(arr.shape) == 0:
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:zero_dim",
            data=arr.item(),  # pyright: ignore[reportAny]
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )

    # Handle the metadata modes
    if array_mode == "array_list_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_list_meta",
            data=arr_np.tolist(),  # pyright: ignore[reportAny]
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    elif array_mode == "array_hex_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_hex_meta",
            data=arr_np.tobytes().hex(),
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    elif array_mode == "array_b64_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_b64_meta",
            data=base64.b64encode(arr_np.tobytes()).decode(),
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    else:
        raise KeyError(f"invalid array_mode: {array_mode}")


@overload
def infer_array_mode(
    arr: SerializedArrayWithMeta,
) -> ArrayModeWithMeta: ...
@overload
def infer_array_mode(arr: NumericList) -> Literal["list"]: ...
def infer_array_mode(
    arr: Union[SerializedArrayWithMeta, NumericList],
) -> ArrayMode:
    """work out which array mode a serialized array was written in

    a plain list is `"list"` mode. a mapping is classified by the suffix of
    the string stored under `_FORMAT_KEY`; for the list/hex/b64 modes the type
    of the `data` entry is checked as well.

    assumes the array was serialized via `serialize_array()`

    # Raises:
     - `ValueError` : if `arr` is neither a mapping nor a list, if the format
       string has no known suffix, or if `data` has the wrong type for the mode
    """
    if not isinstance(arr, typing.Mapping):
        if isinstance(arr, list):  # pyright: ignore[reportUnnecessaryIsInstance]
            return "list"
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")  # pyright: ignore[reportUnreachable]

    # _FORMAT_KEY always maps to a string
    fmt: str = arr.get(_FORMAT_KEY, "")  # type: ignore

    # modes whose `data` entry is type-checked:
    # (format suffix / mode name, required type of `data`, label for the error)
    checked_modes = (
        ("array_list_meta", Iterable, "list"),
        ("array_hex_meta", str, "hex"),
        ("array_b64_meta", str, "b64"),
    )
    for mode_name, required_type, label in checked_modes:
        if not fmt.endswith(":" + mode_name):
            continue
        arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
        if isinstance(arr_data, required_type):
            return mode_name  # type: ignore[return-value]
        raise ValueError(f"invalid {label} format: {type(arr_data) = }\t{arr}")

    # modes accepted on the format suffix alone
    for mode_name in ("external", "zero_dim"):
        if fmt.endswith(":" + mode_name):
            return mode_name  # type: ignore[return-value]

    raise ValueError(f"invalid format: {arr}")


@overload
def load_array(
    arr: SerializedArrayWithMeta,
    array_mode: Optional[ArrayModeWithMeta] = None,
) -> np.ndarray: ...
@overload
def load_array(
    arr: NumericList,
    array_mode: Optional[Literal["list"]] = None,
) -> np.ndarray: ...
@overload
def load_array(
    arr: np.ndarray,
    array_mode: None = None,
) -> np.ndarray: ...
def load_array(
    arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList],
    array_mode: Optional[ArrayMode] = None,
) -> np.ndarray:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr` a serialized array (dict with metadata, or nested list), or a
       numpy array, which is returned unchanged
     - `array_mode : ArrayMode | None` mode to decode with. when `None`, the
       inferred mode is used; when given and different from the inferred mode,
       a warning is emitted and the given mode is used anyway

    # Returns:
     - `np.ndarray` the loaded array. in `external` mode, whatever is stored
       under the key `data` is returned as-is

    # Raises:
     - `ValueError` : if the mode cannot be inferred, the mode is unknown, or
       the stored shape does not match the data
     - `KeyError` : in `external` mode, if there is no `data` key
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray):
        assert array_mode is None, (
            "array_mode should not be specified when loading a numpy array, since that is a no-op"
        )
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid hex format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        # frombuffer over `bytes` gives a read-only view; copy so the result
        # is writable, like the arrays returned by the other modes
        return data.reshape(arr["shape"]).copy()  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid b64 format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
        # same as the hex mode: copy to get a writable array
        return data.reshape(arr["shape"]).copy()  # type: ignore

    elif array_mode == "list":
        assert isinstance(arr, typing.Sequence), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        return np.array(arr)  # type: ignore
    elif array_mode == "external":
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(  # pyright: ignore[reportUnreachable]
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        # we can ignore here since we assume ZANJ has taken care of it
        return arr["data"]  # type: ignore[return-value] # pyright: ignore[reportReturnType]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        # restore the stored dtype, as the `array_list_meta` mode does.
        # a dtype recorded from a torch tensor is prefixed with "torch.",
        # which numpy does not understand, so the prefix is dropped.
        dtype_name = arr.get("dtype")
        if isinstance(dtype_name, str) and dtype_name.startswith("torch."):
            dtype_name = dtype_name[len("torch.") :]
        try:
            data = np.array(arr["data"], dtype=dtype_name)  # ty: ignore[invalid-argument-type]
        except (TypeError, ValueError):
            # dtype has no numpy equivalent (or does not fit the value):
            # fall back to letting numpy pick the dtype
            data = np.array(arr["data"])  # ty: ignore[invalid-argument-type]
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")  # pyright: ignore[reportUnreachable]
NumericList =
typing.Union[typing.List[typing.Union[int, float, bool]], typing.List[ForwardRef('NumericList')]]
ArrayMode =
typing.Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
ArrayModeWithMeta =
typing.Literal['array_list_meta', 'array_hex_meta', 'array_b64_meta', 'zero_dim', 'external']
def
array_n_elements(arr: Any) -> int:
def array_n_elements(arr: Any) -> int:  # type: ignore[name-defined] # pyright: ignore[reportAny]
    """get the number of elements in an array

    # Parameters:
     - `arr : Any` a numpy array, or a torch tensor (including subclasses of `torch.Tensor`)

    # Returns:
     - `int` `arr.size` for numpy arrays, `arr.nelement()` for torch tensors

    # Raises:
     - `TypeError` : if `arr` is neither a numpy array nor a torch tensor
    """
    if isinstance(arr, np.ndarray):
        return arr.size
    # torch is detected by class name so that it never has to be imported here.
    # the whole MRO is searched so that subclasses of `torch.Tensor` are
    # recognized too, not only the exact `torch.Tensor` class.
    elif any(str(cls) == "<class 'torch.Tensor'>" for cls in type(arr).__mro__):  # pyright: ignore[reportUnknownArgumentType, reportAny]
        assert hasattr(arr, "nelement"), (
            "torch Tensor does not have nelement() method? this should not happen"
        )  # pyright: ignore[reportAny]
        return arr.nelement()  # pyright: ignore[reportAny]
    else:
        raise TypeError(f"invalid type: {type(arr)}")  # pyright: ignore[reportAny]
get the number of elements in an array
class
ArrayMetadata(typing.TypedDict):
class ArrayMetadata(TypedDict):
    """Metadata for a numpy/torch array"""

    # dimensions of the array, as a list
    shape: list[int]
    # name (or string form) of the array's dtype
    dtype: str
    # total number of elements in the array
    n_elements: int
Metadata for a numpy/torch array
Inherited Members
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- update
- fromkeys
- clear
- copy
class
SerializedArrayWithMeta(typing.TypedDict):
class SerializedArrayWithMeta(TypedDict):
    """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)"""

    # "<array type>:<mode>" tag identifying how `data` is encoded
    __muutils_format__: str
    data: typing.Union[
        NumericList, str, int, float, bool
    ]  # list, hex str, b64 str, or scalar for zero_dim
    # dimensions of the array, as a list
    shape: list[int]
    # name (or string form) of the array's dtype
    dtype: str
    # total number of elements in the array
    n_elements: int
Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)
data: Union[List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]], str, int, float, bool]
Inherited Members
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- update
- fromkeys
- clear
- copy
def arr_metadata(arr: Any) -> ArrayMetadata:  # pyright: ignore[reportAny]
    """collect the shape, dtype name and element count of an array

    the dtype is reported by its `__name__` when it has one, and by its
    string form otherwise.
    """
    dims = list(arr.shape)  # pyright: ignore[reportAny]

    try:
        dtype_name = arr.dtype.__name__  # pyright: ignore[reportAny]
    except AttributeError:
        dtype_name = str(arr.dtype)  # pyright: ignore[reportAny]

    count = array_n_elements(arr)
    return {"shape": dims, "dtype": dtype_name, "n_elements": count}
get metadata for a numpy array
def
serialize_array( jser: muutils.json_serialize.JsonSerializer, arr: Union[numpy.ndarray, torch.Tensor], path: Union[str, Sequence[str | int]], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> Union[SerializedArrayWithMeta, List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]]]:
def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: "Union[np.ndarray, torch.Tensor]",
    path: str | Sequence[str | int],  # pyright: ignore[reportUnusedParameter]
    array_mode: ArrayMode | None = None,
) -> SerializedArrayWithMeta | NumericList:
    """serialize a numpy or pytorch array in one of several modes

    in `list` mode the result is always `arr.tolist()`. in every other mode, a
    zero-dimensional array is serialized as a `zero_dim` dict holding the
    scalar `arr.item()` under the key `data`, regardless of the requested mode.

    `array_mode: ArrayMode` can be one of:
    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`

    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
    ```
    {
        "__muutils_format__": "<module>.<type name of arr>:<array_list_meta|array_hex_meta|array_b64_meta>",
        "shape": list(arr.shape),
        "dtype": str(arr.dtype),
        "n_elements": <number of elements>,
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
    }
    ```

    # Parameters:
     - `jser : JsonSerializer` only read (for `jser.array_mode`) when `array_mode` is `None`
     - `arr : Any` array to serialize
     - `path : str | Sequence[str | int]` unused by this function
     - `array_mode : ArrayMode` mode in which to serialize the array
       (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `JSONitem`
       json serialized array

    # Raises:
     - `KeyError` : if the array is not zero-dimensional and the array mode is
       not one of the four modes listed above
    """

    if array_mode is None:
        array_mode = jser.array_mode

    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"

    arr_np: np.ndarray
    if isinstance(arr, np.ndarray):  # pyright: ignore[reportUnnecessaryIsInstance]
        arr_np = arr
    else:
        # torch tensors that track gradients or live on an accelerator cannot
        # be handed to numpy directly; drop the autograd graph and move the
        # data to host memory first. detected by duck typing so torch is never
        # imported here; objects without these methods are converted as before.
        host_arr: Any = arr
        if callable(getattr(host_arr, "detach", None)):
            host_arr = host_arr.detach()
        if callable(getattr(host_arr, "cpu", None)):
            host_arr = host_arr.cpu()
        arr_np = np.array(host_arr)

    # Handle list mode first (no metadata needed)
    if array_mode == "list":
        return arr_np.tolist()  # pyright: ignore[reportAny]

    # For all other modes, compute metadata once
    # (zero-dim arrays are described from the original object, not the numpy copy)
    metadata: ArrayMetadata = arr_metadata(arr if len(arr.shape) == 0 else arr_np)

    # TYPING: ty<=0.0.1a24 does not appear to support unpacking TypedDicts, so we do things manually. change it back later maybe?

    # handle zero-dimensional arrays
    if len(arr.shape) == 0:
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:zero_dim",
            data=arr.item(),  # pyright: ignore[reportAny]
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )

    # Handle the metadata modes
    if array_mode == "array_list_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_list_meta",
            data=arr_np.tolist(),  # pyright: ignore[reportAny]
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    elif array_mode == "array_hex_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_hex_meta",
            data=arr_np.tobytes().hex(),
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    elif array_mode == "array_b64_meta":
        return SerializedArrayWithMeta(
            __muutils_format__=f"{arr_type}:array_b64_meta",
            data=base64.b64encode(arr_np.tobytes()).decode(),
            shape=metadata["shape"],
            dtype=metadata["dtype"],
            n_elements=metadata["n_elements"],
        )
    else:
        raise KeyError(f"invalid array_mode: {array_mode}")
serialize a numpy or pytorch array in one of several modes
if the object is zero-dimensional, simply get the unique item
array_mode: ArrayMode can be one of:
- list: serialize as a list of values, no metadata (equivalent to arr.tolist())
- array_list_meta: serialize dict with metadata, actual list under the key data
- array_hex_meta: serialize dict with metadata, actual hex string under the key data
- array_b64_meta: serialize dict with metadata, actual base64 string under the key data
for array_list_meta, array_hex_meta, and array_b64_meta, the serialized object is:
{
_FORMAT_KEY: <array_list_meta|array_hex_meta|array_b64_meta>,
"shape": arr.shape,
"dtype": str(arr.dtype),
"data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
}
Parameters:
- arr : Any — array to serialize
- array_mode : ArrayMode — mode in which to serialize the array (defaults to None and inheriting from jser: JsonSerializer)
Returns:
- JSONitem — json serialized array
Raises:
KeyError: if the array mode is not valid
def
infer_array_mode( arr: Union[SerializedArrayWithMeta, List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]]]) -> Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']:
def infer_array_mode(
    arr: Union[SerializedArrayWithMeta, NumericList],
) -> ArrayMode:
    """work out which array mode a serialized array was written in

    a plain list is `"list"` mode. a mapping is classified by the suffix of
    the string stored under `_FORMAT_KEY`; for the list/hex/b64 modes the type
    of the `data` entry is checked as well.

    assumes the array was serialized via `serialize_array()`

    # Raises:
     - `ValueError` : if `arr` is neither a mapping nor a list, if the format
       string has no known suffix, or if `data` has the wrong type for the mode
    """
    if not isinstance(arr, typing.Mapping):
        if isinstance(arr, list):  # pyright: ignore[reportUnnecessaryIsInstance]
            return "list"
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")  # pyright: ignore[reportUnreachable]

    # _FORMAT_KEY always maps to a string
    fmt: str = arr.get(_FORMAT_KEY, "")  # type: ignore

    # modes whose `data` entry is type-checked:
    # (format suffix / mode name, required type of `data`, label for the error)
    checked_modes = (
        ("array_list_meta", Iterable, "list"),
        ("array_hex_meta", str, "hex"),
        ("array_b64_meta", str, "b64"),
    )
    for mode_name, required_type, label in checked_modes:
        if not fmt.endswith(":" + mode_name):
            continue
        arr_data = arr["data"]  # ty: ignore[invalid-argument-type]
        if isinstance(arr_data, required_type):
            return mode_name  # type: ignore[return-value]
        raise ValueError(f"invalid {label} format: {type(arr_data) = }\t{arr}")

    # modes accepted on the format suffix alone
    for mode_name in ("external", "zero_dim"):
        if fmt.endswith(":" + mode_name):
            return mode_name  # type: ignore[return-value]

    raise ValueError(f"invalid format: {arr}")
given a serialized array, infer the mode
assumes the array was serialized via serialize_array()
def
load_array( arr: Union[SerializedArrayWithMeta, numpy.ndarray, List[Union[int, float, bool]], List[Union[List[Union[int, float, bool]], List[ForwardRef('NumericList')]]]], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> numpy.ndarray:
def load_array(
    arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList],
    array_mode: Optional[ArrayMode] = None,
) -> np.ndarray:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr` a serialized array (dict with metadata, or nested list), or a
       numpy array, which is returned unchanged
     - `array_mode : ArrayMode | None` mode to decode with. when `None`, the
       inferred mode is used; when given and different from the inferred mode,
       a warning is emitted and the given mode is used anyway

    # Returns:
     - `np.ndarray` the loaded array. in `external` mode, whatever is stored
       under the key `data` is returned as-is

    # Raises:
     - `ValueError` : if the mode cannot be inferred, the mode is unknown, or
       the stored shape does not match the data
     - `KeyError` : in `external` mode, if there is no `data` key
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray):
        assert array_mode is None, (
            "array_mode should not be specified when loading a numpy array, since that is a no-op"
        )
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid hex format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        # frombuffer over `bytes` gives a read-only view; copy so the result
        # is writable, like the arrays returned by the other modes
        return data.reshape(arr["shape"]).copy()  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid b64 format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
        # same as the hex mode: copy to get a writable array
        return data.reshape(arr["shape"]).copy()  # type: ignore

    elif array_mode == "list":
        assert isinstance(arr, typing.Sequence), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        return np.array(arr)  # type: ignore
    elif array_mode == "external":
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(  # pyright: ignore[reportUnreachable]
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        # we can ignore here since we assume ZANJ has taken care of it
        return arr["data"]  # type: ignore[return-value] # pyright: ignore[reportReturnType]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        # restore the stored dtype, as the `array_list_meta` mode does.
        # a dtype recorded from a torch tensor is prefixed with "torch.",
        # which numpy does not understand, so the prefix is dropped.
        dtype_name = arr.get("dtype")
        if isinstance(dtype_name, str) and dtype_name.startswith("torch."):
            dtype_name = dtype_name[len("torch.") :]
        try:
            data = np.array(arr["data"], dtype=dtype_name)  # ty: ignore[invalid-argument-type]
        except (TypeError, ValueError):
            # dtype has no numpy equivalent (or does not fit the value):
            # fall back to letting numpy pick the dtype
            data = np.array(arr["data"])  # ty: ignore[invalid-argument-type]
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")  # pyright: ignore[reportUnreachable]
load a json-serialized array, infer the mode if not specified