muutils.json_serialize.util
utilities for json_serialize
1"""utilities for json_serialize""" 2 3from __future__ import annotations 4 5import dataclasses 6import functools 7import inspect 8import sys 9import typing 10import warnings 11from typing import Any, Callable, Iterable, TypeVar, Union 12 13from muutils.json_serialize.types import BaseType, Hashableitem 14 15if typing.TYPE_CHECKING: 16 pass 17 18_NUMPY_WORKING: bool 19try: 20 _NUMPY_WORKING = True 21except ImportError: 22 warnings.warn("numpy not found, cannot serialize numpy arrays!") 23 _NUMPY_WORKING = False 24 25 26# pyright: reportExplicitAny=false 27 28# At type-checking time, include array serialization types to avoid nominal type errors 29# This avoids superfluous imports at runtime 30# if TYPE_CHECKING: 31# from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta 32 33# JSONitem = Union[ 34# BaseType, 35# typing.Sequence["JSONitem"], 36# typing.Dict[str, "JSONitem"], 37# SerializedArrayWithMeta, 38# NumericList, 39# ] 40# else: 41 42JSONitem = Union[ 43 BaseType, 44 typing.Sequence["JSONitem"], 45 typing.Dict[str, "JSONitem"], 46 # TODO: figure this out 47 # "_SerializedSet", 48 # "_SerializedFrozenset", 49] 50 51JSONdict = typing.Dict[str, JSONitem] 52 53 54# TODO: this bit is very broken 55# or if python version <3.9 56if typing.TYPE_CHECKING or sys.version_info < (3, 9): 57 MonoTuple = typing.Sequence 58else: 59 60 class MonoTuple: # pyright: ignore[reportUnreachable] 61 """tuple type hint, but for a tuple of any length with all the same type""" 62 63 __slots__ = () 64 65 def __new__(cls, *args, **kwargs): 66 raise TypeError("Type MonoTuple cannot be instantiated.") 67 68 def __init_subclass__(cls, *args, **kwargs): 69 raise TypeError(f"Cannot subclass {cls.__module__}") 70 71 # idk why mypy thinks there is no such function in typing 72 @typing._tp_cache # type: ignore 73 def __class_getitem__(cls, params): 74 if getattr(params, "__origin__", None) == typing.Union: 75 return typing.GenericAlias(tuple, (params, Ellipsis)) 76 elif 
isinstance(params, type): 77 typing.GenericAlias(tuple, (params, Ellipsis)) 78 # test if has len and is iterable 79 elif isinstance(params, Iterable): 80 if len(params) == 0: 81 return tuple 82 elif len(params) == 1: 83 return typing.GenericAlias(tuple, (params[0], Ellipsis)) 84 else: 85 raise TypeError(f"MonoTuple expects 1 type argument, got {params = }") 86 87 88# TYPING: we allow `Any` here because the container is... universal 89class UniversalContainer: 90 """contains everything -- `x in UniversalContainer()` is always True""" 91 92 def __contains__(self, x: Any) -> bool: # pyright: ignore[reportAny] 93 return True 94 95 96def isinstance_namedtuple(x: Any) -> bool: # pyright: ignore[reportAny] 97 """checks if `x` is a `namedtuple` 98 99 credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple 100 """ 101 t: type = type(x) # pyright: ignore[reportUnknownVariableType, reportAny] 102 b: tuple[type, ...] = t.__bases__ 103 if len(b) != 1 or (b[0] is not tuple): 104 return False 105 f: Any = getattr(t, "_fields", None) 106 if not isinstance(f, tuple): 107 return False 108 # fine that the type is unknown -- that's what we want to check 109 return all(isinstance(n, str) for n in f) # pyright: ignore[reportUnknownVariableType] 110 111 112T_FuncTryCatchReturn = TypeVar("T_FuncTryCatchReturn") 113 114 115def try_catch( 116 func: Callable[..., T_FuncTryCatchReturn], 117) -> Callable[..., Union[T_FuncTryCatchReturn, str]]: 118 """wraps the function to catch exceptions, returns serialized error message on exception 119 120 returned func will return normal result on success, or error message on exception 121 """ 122 123 @functools.wraps(func) 124 def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]: # pyright: ignore[reportAny] 125 try: 126 return func(*args, **kwargs) 127 except Exception as e: 128 return f"{e.__class__.__name__}: {e}" 129 130 return newfunc 131 132 133# TYPING: can we get rid of 
any of these? 134def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: # pyright: ignore[reportAny] 135 if isinstance(obj, typing.Mapping): 136 return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] 137 elif isinstance(obj, (bool, int, float, str)): 138 return obj 139 elif isinstance(obj, (tuple, list, Iterable)): 140 return tuple(_recursive_hashify(v) for v in obj) # pyright: ignore[reportUnknownVariableType] 141 else: 142 if force: 143 return str(obj) # pyright: ignore[reportAny] 144 else: 145 raise ValueError(f"cannot hashify:\n{obj}") 146 147 148class SerializationException(Exception): 149 pass 150 151 152def string_as_lines(s: str | None) -> list[str]: 153 """for easier reading of long strings in json, split up by newlines 154 155 sort of like how jupyter notebooks do it 156 """ 157 if s is None: 158 return list() 159 else: 160 return s.splitlines(keepends=False) 161 162 163def safe_getsource(func: Callable[..., Any]) -> list[str]: 164 try: 165 return string_as_lines(inspect.getsource(func)) 166 except Exception as e: 167 return string_as_lines(f"Error: Unable to retrieve source code:\n{e}") 168 169 170# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises 171def array_safe_eq(a: Any, b: Any) -> bool: # pyright: ignore[reportAny] 172 """check if two objects are equal, account for if numpy arrays or torch tensors""" 173 if a is b: 174 return True 175 176 if type(a) is not type(b): # pyright: ignore[reportAny] 177 return False 178 179 if ( 180 str(type(a)) == "<class 'numpy.ndarray'>" # pyright: ignore[reportAny, reportUnknownArgumentType] 181 and str(type(b)) == "<class 'numpy.ndarray'>" # pyright: ignore[reportAny, reportUnknownArgumentType] 182 ) or ( 183 str(type(a)) == "<class 'torch.Tensor'>" # pyright: ignore[reportAny, reportUnknownArgumentType] 184 and str(type(b)) == "<class 
'torch.Tensor'>" # pyright: ignore[reportAny, reportUnknownArgumentType] 185 ): 186 return (a == b).all() # pyright: ignore[reportAny] 187 188 if ( 189 str(type(a)) == "<class 'pandas.core.frame.DataFrame'>" # pyright: ignore[reportUnknownArgumentType, reportAny] 190 and str(type(b)) == "<class 'pandas.core.frame.DataFrame'>" # pyright: ignore[reportUnknownArgumentType, reportAny] 191 ): 192 return a.equals(b) # pyright: ignore[reportAny] 193 194 if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence): 195 if len(a) == 0 and len(b) == 0: 196 return True 197 return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b)) 198 199 if isinstance(a, (dict, typing.Mapping)) and isinstance(b, (dict, typing.Mapping)): 200 return len(a) == len(b) and all( # pyright: ignore[reportUnknownArgumentType] 201 array_safe_eq(k1, k2) and array_safe_eq(a[k1], b[k2]) 202 for k1, k2 in zip(a.keys(), b.keys()) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] 203 ) 204 205 try: 206 return bool(a == b) # pyright: ignore[reportAny] 207 except (TypeError, ValueError) as e: 208 warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}") 209 return NotImplemented # type: ignore[return-value] 210 211 212# TYPING: see what can be done about so many `Any`s here 213def dc_eq( 214 dc1: Any, # pyright: ignore[reportAny] 215 dc2: Any, # pyright: ignore[reportAny] 216 except_when_class_mismatch: bool = False, 217 false_when_class_mismatch: bool = True, 218 except_when_field_mismatch: bool = False, 219) -> bool: 220 """ 221 checks if two dataclasses which (might) hold numpy arrays are equal 222 223 # Parameters: 224 225 - `dc1`: the first dataclass 226 - `dc2`: the second dataclass 227 - `except_when_class_mismatch: bool` 228 if `True`, will throw `TypeError` if the classes are different. 
229 if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False` 230 (default: `False`) 231 - `false_when_class_mismatch: bool` 232 only relevant if `except_when_class_mismatch` is `False`. 233 if `True`, will return `False` if the classes are different. 234 if `False`, will attempt to compare the fields. 235 - `except_when_field_mismatch: bool` 236 only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`. 237 if `True`, will throw `AttributeError` if the fields are different. 238 (default: `False`) 239 240 # Returns: 241 - `bool`: True if the dataclasses are equal, False otherwise 242 243 # Raises: 244 - `TypeError`: if the dataclasses are of different classes 245 - `AttributeError`: if the dataclasses have different fields 246 247 ``` 248 [START] 249 ▼ 250 ┌─────────────┐ 251 │ dc1 is dc2? │───Yes───► (True) 252 └──────┬──────┘ 253 │No 254 ▼ 255 ┌───────────────┐ 256 │ classes match?│───Yes───► [compare field values] ───► (True/False) 257 └──────┬────────┘ 258 │No 259 ▼ 260 ┌────────────────────────────┐ 261 │ except_when_class_mismatch?│───Yes───► { raise TypeError } 262 └─────────────┬──────────────┘ 263 │No 264 ▼ 265 ┌────────────────────────────┐ 266 │ false_when_class_mismatch? │───Yes───► (False) 267 └─────────────┬──────────────┘ 268 │No 269 ▼ 270 ┌────────────────────────────┐ 271 │ except_when_field_mismatch?│───No────► [compare field values] 272 └─────────────┬──────────────┘ 273 │Yes 274 ▼ 275 ┌───────────────┐ 276 │ fields match? 
│───Yes───► [compare field values] 277 └──────┬────────┘ 278 │No 279 ▼ 280 { raise AttributeError } 281 ``` 282 283 """ 284 if dc1 is dc2: 285 return True 286 287 if dc1.__class__ is not dc2.__class__: # pyright: ignore[reportAny] 288 if except_when_class_mismatch: 289 # if the classes don't match, raise an error 290 raise TypeError( 291 f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`" # pyright: ignore[reportAny] 292 ) 293 if false_when_class_mismatch: 294 # return False immediately without attempting field comparison 295 return False 296 # classes don't match but we'll try to compare fields anyway 297 if except_when_field_mismatch: 298 dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)]) # pyright: ignore[reportAny] 299 dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)]) # pyright: ignore[reportAny] 300 fields_match: bool = set(dc1_fields) == set(dc2_fields) 301 if not fields_match: 302 # if the fields don't match, raise an error 303 raise AttributeError( 304 f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`" 305 ) 306 307 return all( 308 array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)) # pyright: ignore[reportAny] 309 for fld in dataclasses.fields(dc1) # pyright: ignore[reportAny] 310 if fld.compare 311 )
class UniversalContainer:
    """a container that claims to hold every object

    `x in UniversalContainer()` evaluates to `True` for any `x`
    """

    def __contains__(self, x: Any) -> bool:  # pyright: ignore[reportAny]
        """membership test that succeeds unconditionally; `x` is never inspected"""
        return True
contains everything -- x in UniversalContainer() is always True
def isinstance_namedtuple(x: Any) -> bool:  # pyright: ignore[reportAny]
    """tell whether `x` looks like an instance of a `namedtuple` class

    the class of `x` must derive directly (and only) from `tuple`,
    and must carry a `_fields` tuple whose entries are all strings

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)  # pyright: ignore[reportUnknownVariableType, reportAny]
    bases: tuple[type, ...] = cls.__bases__
    derives_only_from_tuple: bool = len(bases) == 1 and bases[0] is tuple
    if not derives_only_from_tuple:
        return False

    field_names: Any = getattr(cls, "_fields", None)
    # the element types are unknown on purpose -- checking them is the point
    return isinstance(field_names, tuple) and all(
        isinstance(name, str) for name in field_names  # pyright: ignore[reportUnknownVariableType]
    )
checks if x is a namedtuple
credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def try_catch(
    func: Callable[..., T_FuncTryCatchReturn],
) -> Callable[..., Union[T_FuncTryCatchReturn, str]]:
    """decorate `func` so that exceptions become a string instead of propagating

    the wrapper forwards all arguments to `func`:

    - on success, the result of `func` is returned unchanged
    - if an `Exception` is raised, the string `"<ExceptionClassName>: <message>"` is returned
    """

    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]:  # pyright: ignore[reportAny]
        try:
            result = func(*args, **kwargs)
        except Exception as err:
            err_kind: str = err.__class__.__name__
            return f"{err_kind}: {err}"
        return result

    return wrapped
wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- add_note
- args
def string_as_lines(s: str | None) -> list[str]:
    """break a string into its lines, so long strings are easier to read in json

    similar to how jupyter notebooks store text

    `None` gives an empty list; line endings are not kept
    """
    return [] if s is None else s.splitlines()
for easier reading of long strings in json, split up by newlines
sort of like how jupyter notebooks do it
def array_safe_eq(a: Any, b: Any) -> bool:  # pyright: ignore[reportAny]
    """check if two objects are equal, account for if numpy arrays or torch tensors

    # Parameters:

    - `a`, `b`: the objects to compare

    # Returns:

    - `bool`: always a plain `bool`
        - `True` if `a is b`
        - `False` if the two objects are not of exactly the same type
        - numpy arrays / torch tensors: equal shape and all elements equal
        - pandas dataframes: result of `a.equals(b)`
        - strings and bytes: plain `==`
        - other sequences: same length and elementwise equal (recursive)
        - mappings: same keys and equal values (recursive), regardless of key order
        - anything else: `bool(a == b)`; if that raises `TypeError` or `ValueError`,
          a warning is emitted and `False` is returned
    """
    if a is b:
        return True

    if type(a) is not type(b):  # pyright: ignore[reportAny]
        return False

    # types are matched by name so that numpy/torch/pandas need not be imported here;
    # `a` and `b` share a type at this point, so checking `a` is enough
    type_name: str = str(type(a))  # pyright: ignore[reportAny, reportUnknownArgumentType]

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # compare shapes first: `==` broadcasts, so e.g. shapes (1,) and (3,)
        # could otherwise compare as equal
        if a.shape != b.shape:  # pyright: ignore[reportAny]
            return False
        return bool((a == b).all())  # pyright: ignore[reportAny]

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return bool(a.equals(b))  # pyright: ignore[reportAny]

    # the elements of a string are themselves strings, so sending strings through
    # the generic sequence branch below can recurse without end
    if isinstance(a, (str, bytes, bytearray)):
        return bool(a == b)

    if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping) and isinstance(b, typing.Mapping):
        # look values up by key rather than zipping the key views,
        # so insertion order does not affect the result
        return len(a) == len(b) and all(  # pyright: ignore[reportUnknownArgumentType]
            k in b and array_safe_eq(v, b[k])
            for k, v in a.items()  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
        )

    try:
        return bool(a == b)  # pyright: ignore[reportAny]
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        # objects that cannot be compared are reported as not equal;
        # `NotImplemented` is truthy and is not a valid `bool` result
        return False
check if two objects are equal, account for if numpy arrays or torch tensors
def dc_eq(
    dc1: Any,  # pyright: ignore[reportAny]
    dc2: Any,  # pyright: ignore[reportAny]
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    compare two dataclass instances field by field, via `array_safe_eq`,
    so that fields which (might) hold numpy arrays can be compared

    # Parameters:

    - `dc1`: the first dataclass; its fields with `compare=True` are the ones compared
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, raise `TypeError` when the classes differ.
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only consulted if `except_when_class_mismatch` is `False`.
        if `True`, return `False` as soon as the classes differ;
        if `False`, go on to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only consulted if the classes differ and both flags above are `False`.
        if `True`, raise `AttributeError` when the sets of field names differ.
        (default: `False`)

    # Returns:

    - `bool`: `True` if every compared field is equal, `False` otherwise

    # Raises:

    - `TypeError`: the classes differ and `except_when_class_mismatch` is `True`
    - `AttributeError`: the field names differ and `except_when_field_mismatch` is `True`;
      also raised by `getattr` if `dc2` lacks one of the fields of `dc1`

    ```
              [START]
                  ▼
           ┌─────────────┐
           │ dc1 is dc2? │───Yes───► (True)
           └──────┬──────┘
                  │No
                  ▼
           ┌───────────────┐
           │ classes match?│───Yes───► [compare field values] ───► (True/False)
           └──────┬────────┘
                  │No
                  ▼
    ┌────────────────────────────┐
    │ except_when_class_mismatch?│───Yes───► { raise TypeError }
    └─────────────┬──────────────┘
                  │No
                  ▼
    ┌────────────────────────────┐
    │ false_when_class_mismatch? │───Yes───► (False)
    └─────────────┬──────────────┘
                  │No
                  ▼
    ┌────────────────────────────┐
    │ except_when_field_mismatch?│───No────► [compare field values]
    └─────────────┬──────────────┘
                  │Yes
                  ▼
           ┌───────────────┐
           │ fields match? │───Yes───► [compare field values]
           └──────┬────────┘
                  │No
                  ▼
      { raise AttributeError }
    ```

    """
    # identical object: nothing to compare
    if dc1 is dc2:
        return True

    classes_differ: bool = dc1.__class__ is not dc2.__class__  # pyright: ignore[reportAny]

    if classes_differ and except_when_class_mismatch:
        raise TypeError(
            f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"  # pyright: ignore[reportAny]
        )

    if classes_differ and false_when_class_mismatch:
        # different classes count as unequal, no field comparison attempted
        return False

    if classes_differ and except_when_field_mismatch:
        # the classes differ but the caller wants the fields compared anyway,
        # provided both sides declare the same field names
        dc1_fields: set[str] = {fld.name for fld in dataclasses.fields(dc1)}  # pyright: ignore[reportAny]
        dc2_fields: set[str] = {fld.name for fld in dataclasses.fields(dc2)}  # pyright: ignore[reportAny]
        if dc1_fields != dc2_fields:
            raise AttributeError(
                f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
            )

    # NOTE(review): only the fields of `dc1` are walked, so extra fields on `dc2`
    # are ignored while a field missing from `dc2` raises `AttributeError` -- confirm
    # this asymmetry is intended
    for fld in dataclasses.fields(dc1):  # pyright: ignore[reportAny]
        if not fld.compare:
            continue
        if not array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)):  # pyright: ignore[reportAny]
            return False
    return True
checks if two dataclasses which (might) hold numpy arrays are equal
Parameters:
- dc1: the first dataclass
- dc2: the second dataclass
- except_when_class_mismatch: bool — if True, will throw TypeError if the classes are different. If not, will return False by default, or attempt to compare the fields if false_when_class_mismatch is False. (default: False)
- false_when_class_mismatch: bool — only relevant if except_when_class_mismatch is False. If True, will return False if the classes are different. If False, will attempt to compare the fields.
- except_when_field_mismatch: bool — only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. If True, will throw AttributeError if the fields are different. (default: False)
Returns:
bool: True if the dataclasses are equal, False otherwise
Raises:
- TypeError: if the dataclasses are of different classes
- AttributeError: if the dataclasses have different fields
          [START]
              ▼
       ┌─────────────┐
       │ dc1 is dc2? │───Yes───► (True)
       └──────┬──────┘
              │No
              ▼
       ┌───────────────┐
       │ classes match?│───Yes───► [compare field values] ───► (True/False)
       └──────┬────────┘
              │No
              ▼
┌────────────────────────────┐
│ except_when_class_mismatch?│───Yes───► { raise TypeError }
└─────────────┬──────────────┘
              │No
              ▼
┌────────────────────────────┐
│ false_when_class_mismatch? │───Yes───► (False)
└─────────────┬──────────────┘
              │No
              ▼
┌────────────────────────────┐
│ except_when_field_mismatch?│───No────► [compare field values]
└─────────────┬──────────────┘
              │Yes
              ▼
       ┌───────────────┐
       │ fields match? │───Yes───► [compare field values]
       └──────┬────────┘
              │No
              ▼
  { raise AttributeError }
class MonoTuple:  # pyright: ignore[reportUnreachable]
    """tuple type hint, but for a tuple of any length with all the same type

    `MonoTuple[T]` evaluates to `tuple[T, ...]`.
    the class itself can be neither instantiated nor subclassed.
    """

    __slots__ = ()

    def __new__(cls, *args, **kwargs):
        raise TypeError("Type MonoTuple cannot be instantiated.")

    def __init_subclass__(cls, *args, **kwargs):
        raise TypeError(f"Cannot subclass {cls.__module__}")

    # idk why mypy thinks there is no such function in typing
    @typing._tp_cache  # type: ignore
    def __class_getitem__(cls, params):
        """build `tuple[T, ...]` from the subscript `params`

        - `MonoTuple[T]` and `MonoTuple[(T,)]` give `tuple[T, ...]`
        - `MonoTuple[()]` gives the bare `tuple`
        - more than one type argument raises `TypeError`
        """
        # several (or zero) arguments arrive packed in a tuple; a list is
        # accepted as well to stay compatible with the earlier iterable check
        if isinstance(params, (tuple, list)):
            if len(params) == 0:
                return tuple
            if len(params) == 1:
                return tuple[params[0], ...]
            raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")
        # a single argument of any kind (class, `Union`, generic alias, ...)
        # is taken as the element type
        return tuple[params, ...]
tuple type hint, but for a tuple of any length with all the same type