Coverage for src/inline_snapshot_pandas/__init__.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-21 22:19 +0200

1from functools import wraps 

2from typing import Optional 

3 

4import pandas.testing 

5from inline_snapshot import customize_repr 

6from inline_snapshot import snapshot 

7from inline_snapshot._inline_snapshot import GenericValue 

8 

# Public API of the package: the ``setup()`` installer, the three
# snapshot-aware pandas assertions, and ``snapshot`` re-exported from
# inline_snapshot so tests can import everything from one place.
__all__ = (
    "setup",
    # drop-in replacements for the pandas.testing assertions
    "assert_frame_equal", "assert_series_equal", "assert_index_equal",
    # re-export
    "snapshot",
)

16 

17 

def make_assert_equal(data_type, assert_equal, repr_function):
    """Build a snapshot-aware version of a pandas ``assert_*_equal`` function.

    The returned function takes the same arguments as *assert_equal*, but its
    second argument may also be an inline-snapshot value. The comparison is
    routed through a private ``Wrapper`` object so that the snapshot can both
    compare against and record the value under test.

    Side effect: ``data_type.__eq__`` is replaced when this function is
    called. The replacement returns ``NotImplemented`` when the right-hand
    operand is a ``GenericValue`` or a ``Wrapper``, so Python asks that
    operand to perform the comparison. Every other comparison is forwarded
    to the original ``__eq__``.

    Args:
        data_type: The class being compared, e.g. ``pandas.DataFrame``.
        assert_equal: The underlying assertion. It is called as
            ``assert_equal(a, b, *args, **kwargs)`` and is expected to raise
            ``AssertionError`` on mismatch.
        repr_function: Converts a *data_type* instance into a plain Python
            value. ``Wrapper.__repr__`` renders the instance as
            ``<data_type.__name__>(<repr of that value>)``.

    Returns:
        A function ``(df, df_snapshot, *args, **kwargs) -> None`` that raises
        ``AssertionError`` when *df* does not equal *df_snapshot*.
    """

    class Wrapper:
        # Carries the value under test plus the comparison closure through
        # the ``==`` protocol.

        def __init__(self, df, cmp):
            self.df = df
            self.cmp = cmp

        def __repr__(self):
            # Presumably the text inline-snapshot writes into the test source,
            # so it is shaped as a constructor call such as
            # ``DataFrame([...])``. TODO confirm against inline_snapshot.
            return f"{data_type.__name__}({repr_function(self.df)!r})"

        def __eq__(self, other):
            if isinstance(other, data_type):
                return self.cmp(self.df, other)
            if isinstance(other, Wrapper) and isinstance(other.df, data_type):
                return self.cmp(self.df, other.df)
            # Let the other operand (e.g. a snapshot value) decide.
            return NotImplemented

    original = data_type.__eq__

    def new_eq(a, b):
        # Without this, pandas' element-wise ``__eq__`` would run against a
        # snapshot/Wrapper instead of deferring to the reflected ``__eq__``.
        if isinstance(b, (GenericValue, Wrapper)):
            return NotImplemented
        return original(a, b)

    data_type.__eq__ = new_eq

    @wraps(assert_equal)
    def result(df, df_snapshot, *args, **kwargs):
        # Last AssertionError raised by ``assert_equal`` during the comparison.
        error: Optional[AssertionError] = None

        def cmp(a, b):
            # Run the real assertion, but turn its failure into ``False`` so
            # it can be used as an ``__eq__`` result. The error is kept so
            # its detailed message can be re-raised below.
            nonlocal error
            try:
                assert_equal(a, b, *args, **kwargs)
            except AssertionError as e:
                error = e
                return False
            return True

        if not Wrapper(df, cmp) == df_snapshot:
            if error is not None:
                raise error
            # The comparison failed without ``assert_equal`` ever running
            # (e.g. *df_snapshot* is not a *data_type* instance). Raise an
            # explicit, descriptive error rather than relying on a bare
            # ``assert``, which is stripped under ``python -O`` and would
            # otherwise lead to ``raise None``.
            raise AssertionError(
                f"{data_type.__name__} {df!r} does not equal {df_snapshot!r}"
            )

    return result

62 

63 

def _frame_to_records(frame):
    # One ``{column: value}`` dict per row; the index is not included.
    # NOTE(review): so a frame with a non-default index presumably can never
    # round-trip through its snapshot text — confirm whether that is intended.
    return frame.to_dict("records")


def _series_to_dict(series):
    # ``{index label: value}`` mapping of the series.
    return series.to_dict()


def _index_to_list(index):
    # The index labels as a plain list.
    return index.to_list()


# Snapshot-aware replacement for ``pandas.testing.assert_frame_equal``.
assert_frame_equal = make_assert_equal(
    pandas.DataFrame, pandas.testing.assert_frame_equal, _frame_to_records
)

# Snapshot-aware replacement for ``pandas.testing.assert_series_equal``.
assert_series_equal = make_assert_equal(
    pandas.Series, pandas.testing.assert_series_equal, _series_to_dict
)

# Snapshot-aware replacement for ``pandas.testing.assert_index_equal``.
assert_index_equal = make_assert_equal(
    pandas.Index, pandas.testing.assert_index_equal, _index_to_list
)

75 

76 

def setup():
    """Install the snapshot-aware assertions into ``pandas.testing``.

    After this call, ``pandas.testing.assert_frame_equal``,
    ``pandas.testing.assert_series_equal`` and
    ``pandas.testing.assert_index_equal`` are rebound to this module's
    wrappers. Code that calls them through the ``pandas.testing`` module can
    then pass a ``snapshot()`` as the expected value.

    Note: only the attributes of ``pandas.testing`` are rebound. Names that
    were bound earlier via ``from pandas.testing import ...`` still refer to
    the original functions.
    """
    patches = {
        "assert_frame_equal": assert_frame_equal,
        "assert_series_equal": assert_series_equal,
        "assert_index_equal": assert_index_equal,
    }
    for attribute, replacement in patches.items():
        setattr(pandas.testing, attribute, replacement)