docs for muutils v0.9.1
View Source on GitHub

muutils.json_serialize.util

utilities for json_serialize


  1"""utilities for json_serialize"""
  2
  3from __future__ import annotations
  4
  5import dataclasses
  6import functools
  7import inspect
  8import sys
  9import typing
 10import warnings
 11from typing import Any, Callable, Iterable, TypeVar, Union
 12
 13from muutils.json_serialize.types import BaseType, Hashableitem
 14
 15if typing.TYPE_CHECKING:
 16    pass
 17
 18_NUMPY_WORKING: bool
 19try:
 20    _NUMPY_WORKING = True
 21except ImportError:
 22    warnings.warn("numpy not found, cannot serialize numpy arrays!")
 23    _NUMPY_WORKING = False
 24
 25
 26# pyright: reportExplicitAny=false
 27
 28# At type-checking time, include array serialization types to avoid nominal type errors
 29# This avoids superfluous imports at runtime
 30# if TYPE_CHECKING:
 31#     from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta
 32
 33#     JSONitem = Union[
 34#         BaseType,
 35#         typing.Sequence["JSONitem"],
 36#         typing.Dict[str, "JSONitem"],
 37#         SerializedArrayWithMeta,
 38#         NumericList,
 39#     ]
 40# else:
 41
 42JSONitem = Union[
 43    BaseType,
 44    typing.Sequence["JSONitem"],
 45    typing.Dict[str, "JSONitem"],
 46    # TODO: figure this out
 47    # "_SerializedSet",
 48    # "_SerializedFrozenset",
 49]
 50
 51JSONdict = typing.Dict[str, JSONitem]
 52
 53
 54# TODO: this bit is very broken
 55# or if python version <3.9
 56if typing.TYPE_CHECKING or sys.version_info < (3, 9):
 57    MonoTuple = typing.Sequence
 58else:
 59
 60    class MonoTuple:  # pyright: ignore[reportUnreachable]
 61        """tuple type hint, but for a tuple of any length with all the same type"""
 62
 63        __slots__ = ()
 64
 65        def __new__(cls, *args, **kwargs):
 66            raise TypeError("Type MonoTuple cannot be instantiated.")
 67
 68        def __init_subclass__(cls, *args, **kwargs):
 69            raise TypeError(f"Cannot subclass {cls.__module__}")
 70
 71        # idk why mypy thinks there is no such function in typing
 72        @typing._tp_cache  # type: ignore
 73        def __class_getitem__(cls, params):
 74            if getattr(params, "__origin__", None) == typing.Union:
 75                return typing.GenericAlias(tuple, (params, Ellipsis))
 76            elif isinstance(params, type):
 77                typing.GenericAlias(tuple, (params, Ellipsis))
 78            # test if has len and is iterable
 79            elif isinstance(params, Iterable):
 80                if len(params) == 0:
 81                    return tuple
 82                elif len(params) == 1:
 83                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
 84            else:
 85                raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")
 86
 87
 88# TYPING: we allow `Any` here because the container is... universal
 89class UniversalContainer:
 90    """contains everything -- `x in UniversalContainer()` is always True"""
 91
 92    def __contains__(self, x: Any) -> bool:  # pyright: ignore[reportAny]
 93        return True
 94
 95
 96def isinstance_namedtuple(x: Any) -> bool:  # pyright: ignore[reportAny]
 97    """checks if `x` is a `namedtuple`
 98
 99    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
100    """
101    t: type = type(x)  # pyright: ignore[reportUnknownVariableType, reportAny]
102    b: tuple[type, ...] = t.__bases__
103    if len(b) != 1 or (b[0] is not tuple):
104        return False
105    f: Any = getattr(t, "_fields", None)
106    if not isinstance(f, tuple):
107        return False
108    # fine that the type is unknown -- that's what we want to check
109    return all(isinstance(n, str) for n in f)  # pyright: ignore[reportUnknownVariableType]
110
111
# return type of the function wrapped by `try_catch`
T_FuncTryCatchReturn = TypeVar("T_FuncTryCatchReturn")


def try_catch(
    func: Callable[..., T_FuncTryCatchReturn],
) -> Callable[..., Union[T_FuncTryCatchReturn, str]]:
    """wrap `func` so that exceptions come back as strings instead of propagating

    on success the wrapper returns whatever `func` returns. if `func` raises an
    `Exception`, the wrapper returns `"<ExceptionClassName>: <message>"` instead.
    exceptions that do not derive from `Exception` (e.g. `KeyboardInterrupt`) still propagate.
    """

    @functools.wraps(func)
    def _guarded(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]:  # pyright: ignore[reportAny]
        try:
            outcome = func(*args, **kwargs)
        except Exception as err:
            return f"{err.__class__.__name__}: {err}"
        return outcome

    return _guarded
131
132
133# TYPING: can we get rid of any of these?
134def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem:  # pyright: ignore[reportAny]
135    if isinstance(obj, typing.Mapping):
136        return tuple((k, _recursive_hashify(v)) for k, v in obj.items())  # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
137    elif isinstance(obj, (bool, int, float, str)):
138        return obj
139    elif isinstance(obj, (tuple, list, Iterable)):
140        return tuple(_recursive_hashify(v) for v in obj)  # pyright: ignore[reportUnknownVariableType]
141    else:
142        if force:
143            return str(obj)  # pyright: ignore[reportAny]
144        else:
145            raise ValueError(f"cannot hashify:\n{obj}")
146
147
# NOTE(review): presumably raised by `json_serialize` code when a value cannot be
# (de)serialized -- confirm against callers
class SerializationException(Exception):
    ...
150
151
def string_as_lines(s: str | None) -> list[str]:
    """split `s` into its lines (line terminators dropped), for easier reading of long strings in json

    sort of like how jupyter notebooks do it. `None` yields an empty list.
    """
    if s is not None:
        return s.splitlines()
    return []
161
162
def safe_getsource(func: Callable[..., Any]) -> list[str]:
    """return the source code of `func`, split into lines

    if `inspect.getsource` raises (e.g. for builtins), the lines of an error message that
    includes the exception text are returned instead.
    """
    try:
        source_text: str = inspect.getsource(func)
    except Exception as err:
        source_text = f"Error: Unable to retrieve source code:\n{err}"
    return string_as_lines(source_text)
168
169
# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:  # pyright: ignore[reportAny]
    """check if two objects are equal, accounting for numpy arrays, torch tensors and pandas DataFrames

    plain `a == b` is unusable for arrays (it compares element-wise and has no single truth
    value), so arrays and DataFrames get dedicated checks. sequences and mappings are
    compared recursively so that containers holding arrays work too.

    # Parameters:
    - `a`: first object
    - `b`: second object

    # Returns:
    - `bool`: `True` if the objects are equal. objects of different types are never equal.
      if `a == b` raises `TypeError`/`ValueError`, a warning is emitted and `False` is returned.
    """
    if a is b:
        return True

    if type(a) is not type(b):  # pyright: ignore[reportAny]
        return False

    # strings (and bytes) are compared directly: every element of a string is itself a
    # string, so recursing into them as sequences would never terminate
    if isinstance(a, (str, bytes)):
        return a == b

    # both objects have the same type from here on, so inspecting `a` alone is enough
    type_repr: str = str(type(a))  # pyright: ignore[reportAny, reportUnknownArgumentType]

    if type_repr in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # `==` broadcasts, so arrays of different shapes (e.g. `(2,)` vs `(1,)`) could
        # otherwise compare equal
        if a.shape != b.shape:  # pyright: ignore[reportAny]
            return False
        # `.all()` yields a numpy bool / 0-d tensor rather than a python `bool`
        return bool((a == b).all())  # pyright: ignore[reportAny]

    if type_repr == "<class 'pandas.core.frame.DataFrame'>":
        return bool(a.equals(b))  # pyright: ignore[reportAny]

    if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping) and isinstance(b, typing.Mapping):
        # like `dict.__eq__`, ignore insertion order: same keys, and equal values per key
        if a.keys() != b.keys():  # pyright: ignore[reportUnknownMemberType]
            return False
        return all(array_safe_eq(a[key], b[key]) for key in a)  # pyright: ignore[reportUnknownVariableType]

    try:
        return bool(a == b)  # pyright: ignore[reportAny]
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        # an undecidable comparison is reported as "not equal". returning `NotImplemented`
        # would be truthy and make callers such as `all(...)` treat the objects as equal
        return False
210
211
# TYPING: see what can be done about so many `Any`s here
def dc_eq(
    dc1: Any,  # pyright: ignore[reportAny]
    dc2: Any,  # pyright: ignore[reportAny]
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    field values are compared pairwise with `array_safe_eq`; fields declared with
    `compare=False` are skipped.

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        if `False`, dataclasses whose field names differ are simply unequal (`False` is returned).
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the dataclasses have different fields and `except_when_field_mismatch` is `True`

    ```
    [START]
       ▼
    ┌─────────────┐
    │ dc1 is dc2? │───Yes───► (True)
    └──────┬──────┘
           │No
           ▼
    ┌───────────────┐
    │ classes match?│───Yes───► [compare field values] ───► (True/False)
    └──────┬────────┘
           │No
           ▼
    ┌────────────────────────────┐
    │ except_when_class_mismatch?│───Yes───► { raise TypeError }
    └─────────────┬──────────────┘
                  │No
                  ▼
    ┌────────────────────────────┐
    │ false_when_class_mismatch? │───Yes───► (False)
    └─────────────┬──────────────┘
                  │No
                  ▼
    ┌───────────────┐
    │ fields match? │───Yes───► [compare field values] ───► (True/False)
    └──────┬────────┘
           │No
           ▼
    ┌────────────────────────────┐
    │ except_when_field_mismatch?│───Yes───► { raise AttributeError }
    └─────────────┬──────────────┘
                  │No
                  ▼
               (False)
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:  # pyright: ignore[reportAny]
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"  # pyright: ignore[reportAny]
            )
        if false_when_class_mismatch:
            # return False immediately without attempting field comparison
            return False
        # classes differ, but we still compare field by field -- which is only
        # meaningful if both dataclasses declare the same field names
        dc1_fields: set[str] = {fld.name for fld in dataclasses.fields(dc1)}  # pyright: ignore[reportAny]
        dc2_fields: set[str] = {fld.name for fld in dataclasses.fields(dc2)}  # pyright: ignore[reportAny]
        if dc1_fields != dc2_fields:
            if except_when_field_mismatch:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
            # a field present on only one side can never compare equal. returning here
            # (instead of falling through) avoids a raw `AttributeError` from `getattr`
            # below and keeps `dc_eq(a, b)` consistent with `dc_eq(b, a)`
            return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))  # pyright: ignore[reportAny]
        for fld in dataclasses.fields(dc1)  # pyright: ignore[reportAny]
        # fields declared with `compare=False` are excluded, as in a dataclass's own `__eq__`
        if fld.compare
    )

JSONitem = typing.Union[bool, int, float, str, NoneType, typing.Sequence[ForwardRef('JSONitem')], typing.Dict[str, ForwardRef('JSONitem')]]
JSONdict = typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.Sequence[ForwardRef('JSONitem')], typing.Dict[str, ForwardRef('JSONitem')]]]
class UniversalContainer:
class UniversalContainer:
    """a container for which every membership test succeeds

    `x in UniversalContainer()` evaluates to `True` for any `x` whatsoever.
    """

    def __contains__(self, x: Any) -> bool:  # pyright: ignore[reportAny]
        # membership never depends on the probed value
        del x
        return True

contains everything -- x in UniversalContainer() is always True

def isinstance_namedtuple(x: Any) -> bool:
 97def isinstance_namedtuple(x: Any) -> bool:  # pyright: ignore[reportAny]
 98    """checks if `x` is a `namedtuple`
 99
100    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
101    """
102    t: type = type(x)  # pyright: ignore[reportUnknownVariableType, reportAny]
103    b: tuple[type, ...] = t.__bases__
104    if len(b) != 1 or (b[0] is not tuple):
105        return False
106    f: Any = getattr(t, "_fields", None)
107    if not isinstance(f, tuple):
108        return False
109    # fine that the type is unknown -- that's what we want to check
110    return all(isinstance(n, str) for n in f)  # pyright: ignore[reportUnknownVariableType]
def try_catch( func: Callable[..., ~T_FuncTryCatchReturn]) -> Callable[..., Union[~T_FuncTryCatchReturn, str]]:
def try_catch(
    func: Callable[..., T_FuncTryCatchReturn],
) -> Callable[..., Union[T_FuncTryCatchReturn, str]]:
    """wrap `func` so that exceptions come back as strings instead of propagating

    on success the wrapper returns whatever `func` returns. if `func` raises an
    `Exception`, the wrapper returns `"<ExceptionClassName>: <message>"` instead.
    exceptions that do not derive from `Exception` (e.g. `KeyboardInterrupt`) still propagate.
    """

    @functools.wraps(func)
    def _guarded(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]:  # pyright: ignore[reportAny]
        try:
            outcome = func(*args, **kwargs)
        except Exception as err:
            return f"{err.__class__.__name__}: {err}"
        return outcome

    return _guarded

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception

class SerializationException(builtins.Exception):
# NOTE(review): presumably raised by `json_serialize` code when a value cannot be
# (de)serialized -- confirm against callers
class SerializationException(Exception):
    ...

(No docstring of its own; the text shown is inherited from `Exception`: "Common base class for all non-exit exceptions.")

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
add_note
args
def string_as_lines(s: str | None) -> list[str]:
def string_as_lines(s: str | None) -> list[str]:
    """split `s` into its lines (line terminators dropped), for easier reading of long strings in json

    sort of like how jupyter notebooks do it. `None` yields an empty list.
    """
    if s is not None:
        return s.splitlines()
    return []

for easier reading of long strings in json, split up by newlines

sort of like how jupyter notebooks do it

def safe_getsource(func: Callable[..., Any]) -> list[str]:
def safe_getsource(func: Callable[..., Any]) -> list[str]:
    """return the source code of `func`, split into lines

    if `inspect.getsource` raises (e.g. for builtins), the lines of an error message that
    includes the exception text are returned instead.
    """
    try:
        source_text: str = inspect.getsource(func)
    except Exception as err:
        source_text = f"Error: Unable to retrieve source code:\n{err}"
    return string_as_lines(source_text)
def array_safe_eq(a: Any, b: Any) -> bool:
def array_safe_eq(a: Any, b: Any) -> bool:  # pyright: ignore[reportAny]
    """check if two objects are equal, accounting for numpy arrays, torch tensors and pandas DataFrames

    plain `a == b` is unusable for arrays (it compares element-wise and has no single truth
    value), so arrays and DataFrames get dedicated checks. sequences and mappings are
    compared recursively so that containers holding arrays work too.

    # Parameters:
    - `a`: first object
    - `b`: second object

    # Returns:
    - `bool`: `True` if the objects are equal. objects of different types are never equal.
      if `a == b` raises `TypeError`/`ValueError`, a warning is emitted and `False` is returned.
    """
    if a is b:
        return True

    if type(a) is not type(b):  # pyright: ignore[reportAny]
        return False

    # strings (and bytes) are compared directly: every element of a string is itself a
    # string, so recursing into them as sequences would never terminate
    if isinstance(a, (str, bytes)):
        return a == b

    # both objects have the same type from here on, so inspecting `a` alone is enough
    type_repr: str = str(type(a))  # pyright: ignore[reportAny, reportUnknownArgumentType]

    if type_repr in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # `==` broadcasts, so arrays of different shapes (e.g. `(2,)` vs `(1,)`) could
        # otherwise compare equal
        if a.shape != b.shape:  # pyright: ignore[reportAny]
            return False
        # `.all()` yields a numpy bool / 0-d tensor rather than a python `bool`
        return bool((a == b).all())  # pyright: ignore[reportAny]

    if type_repr == "<class 'pandas.core.frame.DataFrame'>":
        return bool(a.equals(b))  # pyright: ignore[reportAny]

    if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping) and isinstance(b, typing.Mapping):
        # like `dict.__eq__`, ignore insertion order: same keys, and equal values per key
        if a.keys() != b.keys():  # pyright: ignore[reportUnknownMemberType]
            return False
        return all(array_safe_eq(a[key], b[key]) for key in a)  # pyright: ignore[reportUnknownVariableType]

    try:
        return bool(a == b)  # pyright: ignore[reportAny]
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        # an undecidable comparison is reported as "not equal". returning `NotImplemented`
        # would be truthy and make callers such as `all(...)` treat the objects as equal
        return False

check if two objects are equal, accounting for numpy arrays and torch tensors

def dc_eq( dc1: Any, dc2: Any, except_when_class_mismatch: bool = False, false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False) -> bool:
def dc_eq(
    dc1: Any,  # pyright: ignore[reportAny]
    dc2: Any,  # pyright: ignore[reportAny]
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    field values are compared pairwise with `array_safe_eq`; fields declared with
    `compare=False` are skipped.

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        if `False`, dataclasses whose field names differ are simply unequal (`False` is returned).
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the dataclasses have different fields and `except_when_field_mismatch` is `True`

    ```
    [START]
       ▼
    ┌─────────────┐
    │ dc1 is dc2? │───Yes───► (True)
    └──────┬──────┘
           │No
           ▼
    ┌───────────────┐
    │ classes match?│───Yes───► [compare field values] ───► (True/False)
    └──────┬────────┘
           │No
           ▼
    ┌────────────────────────────┐
    │ except_when_class_mismatch?│───Yes───► { raise TypeError }
    └─────────────┬──────────────┘
                  │No
                  ▼
    ┌────────────────────────────┐
    │ false_when_class_mismatch? │───Yes───► (False)
    └─────────────┬──────────────┘
                  │No
                  ▼
    ┌───────────────┐
    │ fields match? │───Yes───► [compare field values] ───► (True/False)
    └──────┬────────┘
           │No
           ▼
    ┌────────────────────────────┐
    │ except_when_field_mismatch?│───Yes───► { raise AttributeError }
    └─────────────┬──────────────┘
                  │No
                  ▼
               (False)
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:  # pyright: ignore[reportAny]
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"  # pyright: ignore[reportAny]
            )
        if false_when_class_mismatch:
            # return False immediately without attempting field comparison
            return False
        # classes differ, but we still compare field by field -- which is only
        # meaningful if both dataclasses declare the same field names
        dc1_fields: set[str] = {fld.name for fld in dataclasses.fields(dc1)}  # pyright: ignore[reportAny]
        dc2_fields: set[str] = {fld.name for fld in dataclasses.fields(dc2)}  # pyright: ignore[reportAny]
        if dc1_fields != dc2_fields:
            if except_when_field_mismatch:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
            # a field present on only one side can never compare equal. returning here
            # (instead of falling through) avoids a raw `AttributeError` from `getattr`
            # below and keeps `dc_eq(a, b)` consistent with `dc_eq(b, a)`
            return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))  # pyright: ignore[reportAny]
        for fld in dataclasses.fields(dc1)  # pyright: ignore[reportAny]
        # fields declared with `compare=False` are excluded, as in a dataclass's own `__eq__`
        if fld.compare
    )

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

  • dc1: the first dataclass
  • dc2: the second dataclass
  • except_when_class_mismatch: bool if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
  • false_when_class_mismatch: bool only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
  • except_when_field_mismatch: bool only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw AttributeError if the fields are different. (default: False)

Returns:

  • bool: True if the dataclasses are equal, False otherwise

Raises:

  • TypeError: if the dataclasses are of different classes
  • AttributeError: if the dataclasses have different fields
[START]
   ▼
┌─────────────┐
│ dc1 is dc2? │───Yes───► (True)
└──────┬──────┘
       │No
       ▼
┌───────────────┐
│ classes match?│───Yes───► [compare field values] ───► (True/False)
└──────┬────────┘
       │No
       ▼
┌────────────────────────────┐
│ except_when_class_mismatch?│───Yes───► { raise TypeError }
└─────────────┬──────────────┘
              │No
              ▼
┌────────────────────────────┐
│ false_when_class_mismatch? │───Yes───► (False)
└─────────────┬──────────────┘
              │No
              ▼
┌────────────────────────────┐
│ except_when_field_mismatch?│───No────► [compare field values]
└─────────────┬──────────────┘
              │Yes
              ▼
┌───────────────┐
│ fields match? │───Yes───► [compare field values]
└──────┬────────┘
       │No
       ▼
{ raise AttributeError }
class MonoTuple:
class MonoTuple:  # pyright: ignore[reportUnreachable]
    """tuple type hint, but for a tuple of any length with all the same type

    `MonoTuple[T]` evaluates to `tuple[T, ...]`; `MonoTuple[()]` evaluates to bare `tuple`.
    the class itself can be neither instantiated nor subclassed.
    """

    __slots__ = ()

    def __new__(cls, *args, **kwargs):
        raise TypeError("Type MonoTuple cannot be instantiated.")

    def __init_subclass__(cls, *args, **kwargs):
        raise TypeError(f"Cannot subclass {cls.__module__}")

    # idk why mypy thinks there is no such function in typing
    # (`_tp_cache` is a private `typing` helper, not part of the public API)
    @typing._tp_cache  # type: ignore
    def __class_getitem__(cls, params):
        # a single type argument: a class, `Any`, a forward reference, a type variable,
        # or a parameterized generic / union such as `list[int]` or `Optional[str]`
        if (
            isinstance(params, (type, str, typing.ForwardRef, typing.TypeVar))
            or params is typing.Any
            or typing.get_origin(params) is not None
        ):
            return tuple[params, ...]
        # an explicit sequence of arguments, e.g. `MonoTuple[()]` or `MonoTuple[(int,)]`
        if isinstance(params, Iterable):
            if len(params) == 0:
                return tuple
            if len(params) == 1:
                return tuple[params[0], ...]
        # several arguments, or something that is not a type at all
        raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")

tuple type hint, but for a tuple of any length with all the same type