Coverage for muutils\json_serialize\array.py: 64%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-15 20:56 -0600

1"""this utilities module handles serialization and loading of numpy and torch arrays as json 

2 

3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability. 

4- `array_b64_meta` is the most efficient, but is not human readable. 

5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ) 

6 

7""" 

8 

9from __future__ import annotations 

10 

11import base64 

12import typing 

13import warnings 

14from typing import Any, Iterable, Literal, Optional, Sequence 

15 

16import numpy as np 

17 

18from muutils.json_serialize.util import JSONitem 

19 

20# pylint: disable=unused-argument 

21 

# Names of the supported array (de)serialization modes:
# - "list":            plain nested list, no metadata
# - "array_list_meta": metadata dict, nested list under "data"
# - "array_hex_meta":  metadata dict, hex string under "data"
# - "array_b64_meta":  metadata dict, base64 string under "data"
# - "external":        payload stored elsewhere (see module docstring, ZANJ)
# - "zero_dim":        zero-dimensional array stored as a single item
ArrayMode = Literal[
    "list", "array_list_meta", "array_hex_meta", "array_b64_meta", "external", "zero_dim"
]

30 

31 

def array_n_elements(arr) -> int:  # type: ignore[name-defined]
    """get the number of elements in an array

    # Parameters:
     - `arr`
       a `numpy.ndarray`, or a `torch.Tensor` (including subclasses of it)

    # Returns:
     - `int`
       `arr.size` for numpy arrays, `arr.nelement()` for torch tensors

    # Raises:
     - `TypeError` : if `arr` is neither a numpy array nor a torch tensor
    """
    if isinstance(arr, np.ndarray):
        return arr.size

    # torch is never imported here, so tensors are recognized by the printed
    # name of their class. Walking the MRO (instead of looking only at
    # `type(arr)`) also accepts subclasses of `torch.Tensor`.
    if any(str(cls) == "<class 'torch.Tensor'>" for cls in type(arr).__mro__):
        return arr.nelement()

    raise TypeError(f"invalid type: {type(arr)}")

40 

41 

def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """collect the metadata stored alongside a serialized array

    # Returns:
     - `dict` with the keys
       - `"shape"`: `arr.shape` as a list
       - `"dtype"`: `arr.dtype.__name__` when that attribute exists, otherwise `str(arr.dtype)`
       - `"n_elements"`: result of `array_n_elements(arr)`
    """
    shape = list(arr.shape)

    dtype = arr.dtype
    if hasattr(dtype, "__name__"):
        dtype_name = dtype.__name__
    else:
        dtype_name = str(dtype)

    return dict(
        shape=shape,
        dtype=dtype_name,
        n_elements=array_n_elements(arr),
    )

51 

52 

def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: np.ndarray,
    path: str | Sequence[str | int],
    array_mode: ArrayMode | None = None,
) -> JSONitem:
    """serialize a numpy or pytorch array in one of several modes

    a zero-dimensional array is always serialized as its single item
    (format suffix `zero_dim`), regardless of `array_mode`

    `array_mode: ArrayMode` can be one of:
    - `list`: plain list of values, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: dict with metadata, nested list under the key `data`
    - `array_hex_meta`: dict with metadata, hex string under the key `data`
    - `array_b64_meta`: dict with metadata, base64 string under the key `data`

    for the three `*_meta` modes, the serialized object is:
    ```
    {
        "__format__": "<module>.<type name of arr>:<array_mode>",
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
        "shape": ..., "dtype": ..., "n_elements": ...,   # from `arr_metadata`
    }
    ```

    # Parameters:
     - `jser : JsonSerializer`
       only consulted for `jser.array_mode` when `array_mode` is `None`
     - `arr : Any`
       array to serialize; anything that is not an `np.ndarray` is converted via `np.array(arr)`
     - `path : str | Sequence[str | int]`
       not used by this function
     - `array_mode : ArrayMode | None`
       mode in which to serialize the array
       (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `JSONitem`
       json serialized array

    # Raises:
     - `KeyError` : if the array mode is not valid
    """
    mode = jser.array_mode if array_mode is None else array_mode

    # the format tag records the type of the *original* object, not of the numpy copy
    type_tag: str = f"{type(arr).__module__}.{type(arr).__name__}"
    np_arr: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)

    # zero-dimensional arrays: store the single item, whatever the mode is
    if len(arr.shape) == 0:
        return {
            "__format__": f"{type_tag}:zero_dim",
            "data": arr.item(),
            **arr_metadata(arr),
        }

    # bare list: no wrapper dict, no metadata
    if array_mode_is_list := (mode == "list"):
        return np_arr.tolist()

    # the remaining valid modes differ only in how the payload is encoded
    if mode == "array_list_meta":
        suffix, payload = "array_list_meta", np_arr.tolist()
    elif mode == "array_hex_meta":
        suffix, payload = "array_hex_meta", np_arr.tobytes().hex()
    elif mode == "array_b64_meta":
        suffix, payload = "array_b64_meta", base64.b64encode(np_arr.tobytes()).decode()
    else:
        raise KeyError(f"invalid array_mode: {mode}")

    return {
        "__format__": f"{type_tag}:{suffix}",
        "data": payload,
        **arr_metadata(np_arr),
    }

128 

129 

def infer_array_mode(arr: JSONitem) -> ArrayMode:
    """work out which `ArrayMode` a serialized array was written with

    assumes the array was serialized via `serialize_array()`

    a mapping is classified by the suffix of its `"__format__"` entry;
    for the three `*_meta` modes the type of `arr["data"]` is checked as well.
    a plain `list` is classified as `"list"`.

    # Raises:
     - `ValueError` : if the format suffix is unknown, the `data` entry has the
       wrong type, or `arr` is neither a mapping nor a list
    """
    if not isinstance(arr, typing.Mapping):
        if isinstance(arr, list):
            return "list"
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")

    fmt: str = arr.get("__format__", "")

    # (mode, label used in the error message, required type of arr["data"])
    checked_modes = (
        ("array_list_meta", "list", Iterable),
        ("array_hex_meta", "hex", str),
        ("array_b64_meta", "b64", str),
    )
    for mode, label, data_type in checked_modes:
        if fmt.endswith(f":{mode}"):
            if not isinstance(arr["data"], data_type):
                raise ValueError(
                    f"invalid {label} format: {type(arr['data']) = }\t{arr}"
                )
            return mode

    # these modes carry no constraint on the payload
    for mode in ("external", "zero_dim"):
        if fmt.endswith(f":{mode}"):
            return mode

    raise ValueError(f"invalid format: {arr}")

159 

160 

def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr : JSONitem`
       serialized array: a list, or a mapping with a `"__format__"` entry
       (plus `"data"`, `"shape"`, `"dtype"` depending on the mode).
       an `np.ndarray` is returned unchanged when `array_mode` is `None`
     - `array_mode : ArrayMode | None`
       mode to load with; when `None` the mode from `infer_array_mode(arr)` is used.
       a mismatch between the given and the inferred mode only emits a warning,
       the given mode is still used

    # Returns:
     - `np.ndarray` for every mode except `external`, where `arr["data"]` is
       returned as-is

    # Raises:
     - `ValueError` : if the recorded shape does not match the loaded data
       (`array_list_meta`, `zero_dim`), or the mode is not valid
     - `KeyError` : if an `external` array has no `"data"` entry
     - `AssertionError` : if `arr` has the wrong container type for the mode
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid hex format: {type(arr) = }\n{arr = }"

        # NOTE: `np.frombuffer` over an immutable `bytes` object yields a
        # read-only array; callers that need to mutate the result must copy it
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid b64 format: {type(arr) = }\n{arr = }"

        # NOTE: read-only result, same as the hex branch above
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)

    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]

    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        item = arr["data"]
        # honor the recorded dtype so that e.g. a float32 scalar does not come
        # back as float64; if numpy cannot use the recorded dtype (for example
        # a non-numpy name such as "torch.float32", or a value that does not
        # fit), fall back to letting numpy infer the dtype from the item
        try:
            data = np.array(item, dtype=arr.get("dtype"))
        except (TypeError, ValueError, OverflowError):
            data = np.array(item)
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    else:
        raise ValueError(f"invalid array_mode: {array_mode}")