Coverage for muutils\json_serialize\util.py: 45%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-15 21:53 -0600

1"""utilities for json_serialize""" 

2 

3from __future__ import annotations 

4 

5import dataclasses 

6import functools 

7import inspect 

8import sys 

9import typing 

10import warnings 

11from typing import Any, Callable, Iterable, Union 

12 

# `_NUMPY_WORKING` records whether numpy could be imported when this module
# was loaded.
_NUMPY_WORKING: bool
try:
    # the import is the only statement here that can raise `ImportError`;
    # without it the `except` branch below could never run and the flag would
    # always be `True`
    import numpy  # noqa: F401

    _NUMPY_WORKING = True
except ImportError:
    warnings.warn("numpy not found, cannot serialize numpy arrays!")
    _NUMPY_WORKING = False

19 

20 

# a single JSON-compatible value: a scalar, a list, a string-keyed dict, or None
JSONitem = typing.Union[
    bool,
    int,
    float,
    str,
    list,
    typing.Dict[str, typing.Any],
    None,
]

# a JSON object: string keys mapped to JSON-compatible values
JSONdict = typing.Dict[str, JSONitem]

# the value types `_recursive_hashify` is annotated to return
Hashableitem = typing.Union[bool, int, float, str, tuple]

24 

# During type checking, or on python < 3.9, `MonoTuple` is just an alias for
# `typing.Sequence`. Otherwise it is a small non-instantiable class whose
# subscription (`MonoTuple[T]`) produces the generic alias `tuple[T, ...]`.
if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    MonoTuple = typing.Sequence
else:

    class MonoTuple:
        """tuple type hint, but for a tuple of any length with all the same type"""

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # idk why mypy thinks there is no such function in typing
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            """build the alias for `MonoTuple[params]`

            - a `typing.Union[...]` or a plain class `T` gives `tuple[T, ...]`
            - an empty iterable (`MonoTuple[()]`) gives plain `tuple`
            - a one-element iterable `(T,)` gives `tuple[T, ...]`
            - anything that is neither of the above nor iterable raises `TypeError`
            """
            if getattr(params, "__origin__", None) == typing.Union:
                return typing.GenericAlias(tuple, (params, Ellipsis))
            elif isinstance(params, type):
                # the alias has to be returned, otherwise `MonoTuple[int]`
                # silently evaluates to `None`
                return typing.GenericAlias(tuple, (params, Ellipsis))
            # test if has len and is iterable
            elif isinstance(params, Iterable):
                if len(params) == 0:
                    return tuple
                elif len(params) == 1:
                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
                # NOTE(review): an iterable with more than one element falls
                # through and yields `None`; left unchanged here -- confirm
                # whether it should raise like the branch below
            else:
                raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")

56 

57 

class UniversalContainer:
    """a container for which every membership test succeeds

    `x in UniversalContainer()` evaluates to `True` for any `x`
    """

    def __contains__(self, x: Any) -> bool:
        # the argument is never inspected: membership always holds
        return True

63 

64 

def isinstance_namedtuple(x: Any) -> bool:
    """checks if `x` is a `namedtuple`

    `x` counts as a namedtuple when its class derives directly (and only) from
    `tuple` and carries a `_fields` tuple consisting solely of strings

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)
    bases: tuple = cls.__bases__

    # the class must have `tuple` as its one and only direct base
    derives_only_from_tuple: bool = len(bases) == 1 and bases[0] is tuple
    if not derives_only_from_tuple:
        return False

    # `_fields` must exist, be a tuple, and hold nothing but field-name strings
    field_names: Any = getattr(cls, "_fields", None)
    if isinstance(field_names, tuple):
        return all(isinstance(name, str) for name in field_names)
    return False

78 

79 

def try_catch(func: Callable):
    """wrap `func` so that exceptions are reported as strings instead of raised

    the returned callable forwards all arguments to `func`:

    - on success, the result of `func` is returned unchanged
    - if `func` raises an `Exception`, the string `"<ExceptionClass>: <message>"`
      is returned instead
    """

    def _guarded(*args, **kwargs):
        try:
            outcome = func(*args, **kwargs)
        except Exception as err:
            return f"{err.__class__.__name__}: {err}"
        return outcome

    # copy the name, docstring, etc. of `func` onto the wrapper
    return functools.wraps(func)(_guarded)

94 

95 

def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem:
    """recursively convert `obj` into nested tuples and scalars

    # Parameters:
     - `obj`: the object to convert
     - `force: bool`
       if `True`, anything that is not a scalar, mapping or iterable is
       replaced by `str(obj)`; if `False`, such objects raise `ValueError`.
       the flag applies at every nesting level
       (default: `True`)

    # Returns:
     - `bool`, `int`, `float` and `str` values are returned unchanged
     - a mapping becomes a tuple of `(key, converted value)` pairs
       (keys are passed through as-is)
     - any other iterable becomes a tuple of converted elements

    # Raises:
     - `ValueError`: if `force` is `False` and an object cannot be converted
    """
    # scalars must be checked before iterables: a `str` is itself iterable and
    # iterating it yields one-character strings, which would recurse forever
    if isinstance(obj, (bool, int, float, str)):
        return obj

    if isinstance(obj, typing.Mapping):
        return tuple((k, _recursive_hashify(v, force)) for k, v in obj.items())

    # tuples and lists are covered by `Iterable`
    if isinstance(obj, Iterable):
        return tuple(_recursive_hashify(v, force) for v in obj)

    if force:
        return str(obj)

    raise ValueError(f"cannot hashify:\n{obj}")

108 

109 

class SerializationException(Exception):
    """exception type for serialization errors; adds nothing to `Exception`"""

112 

113 

def string_as_lines(s: str | None) -> list[str]:
    """split a string into its lines, for easier reading of long strings in json

    line terminators are dropped; `None` gives an empty list

    sort of like how jupyter notebooks do it
    """
    return [] if s is None else s.splitlines()

123 

124 

def safe_getsource(func) -> list[str]:
    """return the source code of `func` as a list of lines, never raising

    if the source cannot be retrieved, the returned lines hold an error
    message that includes the exception text
    """
    try:
        source_text = inspect.getsource(func)
        return string_as_lines(source_text)
    except Exception as e:
        failure_text = f"Error: Unable to retrieve source code:\n{e}"
        return string_as_lines(failure_text)

130 

131 

132# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises 

def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, account for if numpy arrays or torch tensors

    # Parameters:
     - `a`, `b`: the objects to compare

    # Returns:
     - `True` if `a is b`
     - `False` if the two objects are of different types
     - numpy arrays / torch tensors: `True` only if the shapes match and all
       elements compare equal
     - pandas dataframes: the result of `a.equals(b)`
     - `str` / `bytes`: plain `==`
     - other sequences: same length and pairwise-equal elements (recursive)
     - mappings: same keys and equal values per key (recursive), regardless of
       key order
     - anything else: `bool(a == b)`; if that comparison raises `TypeError` or
       `ValueError`, a warning is emitted and `NotImplemented` is returned
    """
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    # both objects share one type from here on, so inspecting `a` is enough.
    # types are matched by name so that numpy / torch / pandas need not be imported
    type_name: str = str(type(a))

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # arrays of different shape are never equal; checking first also keeps
        # `==` from broadcasting one operand against the other
        if a.shape != b.shape:
            return False
        # `.all()` gives an array/tensor scalar; convert to a plain bool
        return bool((a == b).all())

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return a.equals(b)

    # strings are sequences of one-character strings; recursing into them can
    # go on without bound, so they are compared directly
    if isinstance(a, (str, bytes)):
        return a == b

    if isinstance(a, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping):
        # look values up by key rather than pairing keys positionally, so that
        # insertion order does not affect the result
        return len(a) == len(b) and all(
            key in b and array_safe_eq(a[key], b[key]) for key in a
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        # NOTE(review): `NotImplemented` is not a `bool`; kept for backward
        # compatibility with existing callers
        return NotImplemented  # type: ignore[return-value]

172 

173 

def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if the classes differ and `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if the classes differ and `except_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the field names are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    # Decision procedure:

    1. `dc1 is dc2` -> `True`
    2. classes match -> compare field values (step 6)
    3. classes differ and `except_when_class_mismatch` -> raise `TypeError`
    4. classes differ and field names differ ->
       raise `AttributeError` if `except_when_field_mismatch`, else `False`
    5. classes differ, field names match ->
       `False` if `false_when_class_mismatch`, else compare field values (step 6)
    6. `True` if every field of `dc1` with `compare=True` has an equal value
       in `dc2` (as judged by `array_safe_eq`), `False` otherwise
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )

        # with both options at their defaults the field names are irrelevant;
        # returning here also avoids calling `dataclasses.fields` on an
        # argument that may not be a dataclass at all
        if false_when_class_mismatch and not except_when_field_mismatch:
            return False

        dc1_fields: set = {fld.name for fld in dataclasses.fields(dc1)}
        dc2_fields: set = {fld.name for fld in dataclasses.fields(dc2)}
        if dc1_fields != dc2_fields:
            if except_when_field_mismatch:
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
            return False

        if false_when_class_mismatch:
            return False
        # the field names match and the caller asked for a field-wise
        # comparison despite the class mismatch: fall through

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )