Coverage for polypandas/testing.py: 88%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-02-24 14:21 -0500

1"""Testing utilities for pandas DataFrames.""" 

2 

from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple

from polypandas.exceptions import PolypandasError
from polypandas.protocols import is_pandas_available

7 

8 

class DataFrameComparisonError(PolypandasError):
    """Signals a failed check in this module's DataFrame assertion helpers.

    Covers both mismatches (rows, columns, dtypes, values, duplicates) and
    unmet preconditions such as pandas not being installed.
    """

13 

14 

def assert_dataframe_equal(
    df1: Any,
    df2: Any,
    check_order: bool = False,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    check_dtypes: bool = True,
    check_column_order: bool = False,
) -> None:
    """Assert that two pandas DataFrames are equal.

    Args:
        df1: First DataFrame to compare.
        df2: Second DataFrame to compare.
        check_order: If True, row order (and index labels) must match. If
            False, both DataFrames are sorted by all columns and their indexes
            are reset before comparing, so only the rows' contents matter.
        rtol: Relative tolerance for floating point comparisons.
        atol: Absolute tolerance for floating point comparisons.
        check_dtypes: If True, check that dtypes match.
        check_column_order: If True, column order must match. If False,
            ``df2``'s columns are reordered to match ``df1`` before comparing.

    Raises:
        DataFrameComparisonError: If pandas is unavailable, either argument is
            not a DataFrame, the rows cannot be sorted for an order-insensitive
            comparison, or the DataFrames are not equal.
    """
    if not is_pandas_available():
        raise DataFrameComparisonError(
            "pandas is required for DataFrame comparison. Install it with: pip install pandas"
        )

    import pandas as pd

    if not isinstance(df1, pd.DataFrame) or not isinstance(df2, pd.DataFrame):
        raise DataFrameComparisonError("Both arguments must be pandas DataFrames")

    if len(df1) != len(df2):
        raise DataFrameComparisonError(
            f"DataFrame row counts don't match: {len(df1)} != {len(df2)}"
        )

    if set(df1.columns) != set(df2.columns):
        raise DataFrameComparisonError(
            f"DataFrame columns don't match: {set(df1.columns)} != {set(df2.columns)}"
        )

    if check_column_order and list(df1.columns) != list(df2.columns):
        raise DataFrameComparisonError("Column order does not match")

    if not check_column_order:
        # Align df2's column order with df1 so assert_frame_equal ignores it.
        df2 = df2[df1.columns]

    if not check_order:
        if len(df1) > 0:
            try:
                df1 = df1.sort_values(by=list(df1.columns))
                df2 = df2.sort_values(by=list(df2.columns))
            except TypeError as e:
                # Cells holding values that cannot be ordered (e.g. dicts, or
                # mixed types) make an order-insensitive comparison impossible.
                raise DataFrameComparisonError(
                    "Cannot sort rows for an order-insensitive comparison "
                    f"({e}); pass check_order=True to compare rows positionally"
                ) from e
        # Reset even for empty frames so index labels are ignored consistently
        # whenever row order is not being checked.
        df1 = df1.reset_index(drop=True)
        df2 = df2.reset_index(drop=True)

    try:
        pd.testing.assert_frame_equal(
            df1,
            df2,
            check_dtype=check_dtypes,
            rtol=rtol,
            atol=atol,
        )
    except AssertionError as e:
        raise DataFrameComparisonError(str(e)) from e

78 

79 

def assert_schema_equal(
    df1: Any,
    df2: Any,
    check_order: bool = False,
) -> None:
    """Assert that ``df1`` and ``df2`` have the same dtypes (schema).

    Alias for ``assert_dtypes_equal``; ``check_order`` is passed through
    unchanged.
    """
    assert_dtypes_equal(df1=df1, df2=df2, check_order=check_order)

87 

88 

def assert_dtypes_equal(
    df1: Any,
    df2: Any,
    check_order: bool = False,
) -> None:
    """Assert that two DataFrames have the same dtypes.

    Args:
        df1: First DataFrame.
        df2: Second DataFrame.
        check_order: If True, column order must match as well as each
            column's dtype. If False, columns are matched by label only.

    Raises:
        DataFrameComparisonError: If pandas is unavailable, the column sets
            differ, the column order differs (when ``check_order`` is True),
            or any column's dtype differs.
    """
    if not is_pandas_available():
        raise DataFrameComparisonError("pandas is required")

    dtypes1 = df1.dtypes
    dtypes2 = df2.dtypes

    if set(dtypes1.index) != set(dtypes2.index):
        raise DataFrameComparisonError("DataFrames have different columns")

    # Dtypes are looked up by column label below, which is order-independent,
    # so the order requirement has to be enforced explicitly here.
    if check_order and list(dtypes1.index) != list(dtypes2.index):
        raise DataFrameComparisonError("Column order does not match")

    for col in dtypes1.index:
        if dtypes1[col] != dtypes2[col]:
            raise DataFrameComparisonError(
                f"Column '{col}' has different dtypes: {dtypes1[col]} != {dtypes2[col]}"
            )

121 

122 

123def assert_approx_count(df: Any, expected_count: int, tolerance: float = 0.1) -> None: 

124 """Assert that DataFrame row count is approximately equal to expected.""" 

125 if not is_pandas_available(): 

126 raise DataFrameComparisonError("pandas is required") 

127 

128 actual_count = len(df) 

129 min_count = int(expected_count * (1 - tolerance)) 

130 max_count = int(expected_count * (1 + tolerance)) 

131 

132 if not (min_count <= actual_count <= max_count): 

133 raise DataFrameComparisonError( 

134 f"DataFrame count {actual_count} is not within {tolerance * 100:.1f}% " 

135 f"of expected {expected_count}. Expected range: [{min_count}, {max_count}]" 

136 ) 

137 

138 

def get_column_stats(df: Any, column: str) -> Dict[str, Any]:
    """Summarize a single column of ``df``.

    Always reports ``count`` (non-null values), ``null_count`` and
    ``distinct_count``. For numeric columns, ``min``, ``max``, ``mean`` and
    ``std`` are added as floats, or ``None`` when the column has no non-null
    values.

    Raises:
        DataFrameComparisonError: If pandas is not available.
    """
    if not is_pandas_available():
        raise DataFrameComparisonError("pandas is required")

    import pandas as pd

    series = df[column]
    non_null = series.count()
    stats: Dict[str, Any] = {
        "count": int(non_null),
        "null_count": int(series.isna().sum()),
        "distinct_count": int(series.nunique()),
    }

    if not pd.api.types.is_numeric_dtype(series):
        return stats

    has_values = bool(non_null)
    reducers = (
        ("min", series.min),
        ("max", series.max),
        ("mean", series.mean),
        ("std", series.std),
    )
    for key, reduce in reducers:
        stats[key] = float(reduce()) if has_values else None
    return stats

160 

161 

162def assert_column_exists(df: Any, *columns: str) -> None: 

163 """Assert that specified columns exist in DataFrame.""" 

164 df_columns = set(df.columns) 

165 missing = [c for c in columns if c not in df_columns] 

166 

167 if missing: 

168 raise DataFrameComparisonError( 

169 f"Columns missing from DataFrame: {missing}. Available columns: {sorted(df_columns)}" 

170 ) 

171 

172 

def assert_no_duplicates(df: Any, columns: Optional[List[str]] = None) -> None:
    """Assert that ``df`` contains no duplicate rows.

    Args:
        df: DataFrame to check.
        columns: Optional subset of columns that define a duplicate; when
            ``None``, whole rows are compared.

    Raises:
        DataFrameComparisonError: If pandas is unavailable or duplicates exist.
    """
    if not is_pandas_available():
        raise DataFrameComparisonError("pandas is required")

    # ``subset=None`` is drop_duplicates' own default, so a single call
    # covers both the whole-row and the column-subset cases.
    total_count = len(df)
    unique_count = len(df.drop_duplicates(subset=columns))
    duplicate_count = total_count - unique_count
    if duplicate_count == 0:
        return

    raise DataFrameComparisonError(
        f"DataFrame contains {duplicate_count} duplicate row(s). "
        f"Total rows: {total_count}, Unique rows: {unique_count}"
    )