Coverage for src/inline_snapshot_pandas/__init__.py: 100%
47 statements
coverage.py v7.6.1, created at 2024-08-21 22:19 +0200
1from functools import wraps
2from typing import Optional
4import pandas.testing
5from inline_snapshot import customize_repr
6from inline_snapshot import snapshot
7from inline_snapshot._inline_snapshot import GenericValue
9__all__ = (
10 "setup",
11 "assert_frame_equal",
12 "assert_series_equal",
13 "assert_index_equal",
14 "snapshot",
15)
def make_assert_equal(data_type, assert_equal, repr_function):
    """Build a snapshot-aware replacement for a pandas ``assert_*_equal`` function.

    The returned function has the same calling convention as ``assert_equal``
    (it is wrapped with ``functools.wraps``), but its second argument may also
    be an inline-snapshot value instead of a ``data_type`` instance.

    Args:
        data_type: The class being compared (e.g. ``pandas.DataFrame``). As a
            side effect, ``data_type.__eq__`` is replaced so that it defers to
            snapshot values and to the internal wrapper.
        assert_equal: The original assertion function. It is expected to raise
            ``AssertionError`` on mismatch.
        repr_function: Turns a ``data_type`` instance into a plain Python value.
            The wrapper's ``repr`` is ``data_type.__name__(<that value>!r)``.
            Presumably this is the text written into the snapshot source.

    Returns:
        ``result(df, df_snapshot, *args, **kargs)``. It returns None on a match
        and raises ``AssertionError`` otherwise. Extra arguments are forwarded
        to ``assert_equal``.
    """

    class Wrapper:
        # Pairs a value with a comparison callback, so that ``==`` runs the
        # assertion function rather than the data type's own ``__eq__``.
        def __init__(self, df, cmp):
            self.df = df
            self.cmp = cmp

        def __repr__(self):
            return f"{data_type.__name__}({repr_function(self.df)!r})"

        def __eq__(self, other):
            if isinstance(other, data_type):
                return self.cmp(self.df, other)
            if isinstance(other, Wrapper) and isinstance(other.df, data_type):
                return self.cmp(self.df, other.df)
            # Let Python try the reflected comparison (e.g. the snapshot's __eq__).
            return NotImplemented

    original = data_type.__eq__

    def new_eq(a, b):
        # Step aside for snapshot values and wrappers. Returning NotImplemented
        # makes Python fall back to the other operand's ``__eq__``.
        if isinstance(b, (GenericValue, Wrapper)):
            return NotImplemented
        return original(a, b)

    data_type.__eq__ = new_eq

    @wraps(assert_equal)
    def result(df, df_snapshot, *args, **kargs):
        error: Optional[AssertionError] = None

        def cmp(a, b):
            # Run the real assertion. Any failure is remembered, so the detailed
            # message can be re-raised once ``==`` has returned False.
            nonlocal error
            try:
                assert_equal(a, b, *args, **kargs)
            except AssertionError as e:
                error = e
                return False
            return True

        if not Wrapper(df, cmp) == df_snapshot:
            if error is None:
                # ``==`` returned False without ever running ``assert_equal``.
                # This happens e.g. when df_snapshot is neither a data_type
                # instance nor a snapshot value. Raise explicitly: a bare
                # ``assert`` would be stripped under ``python -O``, leaving
                # ``raise None`` (a TypeError).
                raise AssertionError(
                    f"{data_type.__name__} does not match {df_snapshot!r}"
                )
            raise error

    return result
# Snapshot-aware versions of the pandas.testing assertions. The original pandas
# functions are captured here, at import time. This lets setup() later overwrite
# the pandas.testing attributes without the wrappers calling themselves.
assert_frame_equal = make_assert_equal(
    data_type=pandas.DataFrame,
    assert_equal=pandas.testing.assert_frame_equal,
    # One dict per row; the row index is not part of this representation.
    repr_function=lambda frame: frame.to_dict("records"),
)
assert_series_equal = make_assert_equal(
    data_type=pandas.Series,
    assert_equal=pandas.testing.assert_series_equal,
    repr_function=lambda series: series.to_dict(),
)
assert_index_equal = make_assert_equal(
    data_type=pandas.Index,
    assert_equal=pandas.testing.assert_index_equal,
    repr_function=lambda index: index.to_list(),
)
def setup():
    """Install this module's snapshot-aware assertions into ``pandas.testing``.

    After the call, code that looks up ``pandas.testing.assert_frame_equal``
    (or the series/index variants) at call time gets the versions defined here.
    Names already bound elsewhere via ``from pandas.testing import ...`` keep
    pointing at the original pandas functions.
    """
    replacements = (
        ("assert_frame_equal", assert_frame_equal),
        ("assert_series_equal", assert_series_equal),
        ("assert_index_equal", assert_index_equal),
    )
    for attribute, replacement in replacements:
        setattr(pandas.testing, attribute, replacement)