Coverage for tests/test_remove_one.py: 90%

131 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-03-25 08:14 +0100

1import ast 

2import hashlib 

3import random 

4import sys 

5from pathlib import Path 

6 

7import pytest 

8from pysource_codegen import generate 

9 

10import pysource_minimize._minimize 

11from . import session_config 

12from .dump_tree import dump_tree 

13from pysource_minimize import minimize 

14from tests.utils import testing_enabled 

15 

16try: 

17 import pysource_minimize_testing # type: ignore 

18except ImportError: 

19 import pysource_minimize as pysource_minimize_testing 

20 

# Folder (next to this test module) holding regression samples; it is created
# on import so that globbing it in the parametrization below never fails.
sample_dir = Path(__file__).parent.joinpath("remove_one_samples")
sample_dir.mkdir(exist_ok=True)

24 

25 

def node_weights(source):
    """Return a list of ``(node, weight)`` pairs for every node of *source*.

    *source* is handed to ``ast.parse`` and the resulting tree is traversed
    with ``ast.walk``. Each node gets an integer weight: 1 by default, 0 for
    structural nodes that are not counted on their own, a size-dependent value
    for literals and a few containers, and a negative correction for nodes whose
    removal drags other nodes along. ``try_remove_one`` relies on the sum of
    these weights dropping by exactly one per removal step.
    """
    module = ast.parse(source)

    # Nodes that never count as a removable unit by themselves.
    weightless = (
        ast.Pass,
        ast.expr_context,
        ast.Expr,
        ast.boolop,
        ast.unaryop,
        ast.keyword,
        ast.withitem,
        ast.For,
        ast.AsyncFor,
        ast.BoolOp,
        ast.AnnAssign,
        ast.AugAssign,
        ast.Compare,
        ast.cmpop,
        ast.BinOp,
        ast.operator,
        ast.Assign,
        ast.Import,
        ast.Delete,
        ast.ImportFrom,
        ast.arguments,
    )

    def literal_weight(value):
        # bool has to be tested before int because bool subclasses int.
        if isinstance(value, bool):
            return int(value) + 1
        if isinstance(value, int):
            return bin(value).count("1")
        if isinstance(value, float):
            return abs(int(value * 10)) + 1
        if isinstance(value, (bytes, str)):
            return len(value) + 1
        # None, Ellipsis, complex, ...
        return 1

    def own_weight(node):
        # The categories below are mutually exclusive, so the first match wins.
        if isinstance(node, weightless):
            return 0
        if isinstance(node, (ast.GeneratorExp, ast.ListComp, ast.SetComp)):
            return 0
        if isinstance(node, ast.DictComp):
            return -2
        if isinstance(node, ast.comprehension):
            # removing a comprehension removes its variable and its iterable
            return -1
        if isinstance(node, ast.Dict):
            # presumably because a key and its value go away together
            # (compare the MatchMapping rule below)
            return 1 - len(node.keys)
        if sys.version_info >= (3, 8) and isinstance(node, ast.NamedExpr):
            return 0
        if isinstance(node, ast.Constant):
            return literal_weight(node.value)
        if isinstance(node, ast.FormattedValue):
            return 0
        if isinstance(node, ast.JoinedStr):
            # work around for https://github.com/python/cpython/issues/110309
            return -sum(isinstance(part, ast.Constant) for part in node.values)
        if isinstance(node, ast.IfExp):
            return -1
        if isinstance(node, (ast.Subscript, ast.Index)):
            return 0
        if isinstance(node, (ast.Nonlocal, ast.Global)):
            return len(node.names)
        if sys.version_info >= (3, 10):
            if isinstance(node, ast.MatchValue):
                return -1
            if isinstance(node, (ast.MatchOr, ast.match_case, ast.MatchClass)):
                return 0
            if isinstance(node, ast.Match):
                return -1  # for the subject
            if isinstance(node, ast.MatchMapping):
                # key-value pairs can only be removed together
                return 1 - len(node.patterns)
        if sys.version_info >= (3, 11) and isinstance(node, ast.TryStar):
            # a bare `except*:` is invalid syntax
            return 1 - len(node.handlers)
        if isinstance(node, ast.excepthandler):
            # the handler itself is free; a bound `as name` counts as one unit
            return 1 if node.name else 0
        if sys.version_info >= (3, 12) and isinstance(node, ast.TypeAlias):
            return 0
        return 1

    def weight(node):
        # A present type comment adds one unit on top of the node's own weight.
        bonus = 0 if getattr(node, "type_comment", None) is None else 1
        return own_weight(node) + bonus

    return [(node, weight(node)) for node in ast.walk(module)]

131 

132 

def count_nodes(source):
    """Return the total weight of *source*, i.e. the sum of ``node_weights``."""
    total = 0
    for _node, node_weight in node_weights(source):
        total += node_weight
    return total

135 

136 

def try_remove_one(source):
    """Shrink *source* step by step, one weighted node per step.

    Each round runs ``pysource_minimize_testing.minimize`` with a checker that
    only accepts candidates that still compile and whose ``count_nodes`` weight
    is at least the current weight minus one. As soon as a candidate weighs
    exactly one less, the checker raises ``StopMinimization`` to stop the
    round. After each round the ``assert`` requires the result to weigh exactly
    one less than before; it fails when no single-node removal was found.

    Rounds repeat until the weight reaches 1.
    """
    node_count = count_nodes(source)

    def checker(candidate):
        # Reject candidates that do not compile. ``except Exception`` instead of
        # a bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
        try:
            compile(candidate, "<string>", "exec")
        except Exception:
            return False

        count = count_nodes(candidate)

        if count == node_count - 1:
            # exactly one node removed: end this minimization round
            raise pysource_minimize_testing.StopMinimization

        # reuse the count computed above instead of re-parsing the candidate
        return count >= node_count - 1

    while node_count > 1:
        # remove only one "node" from the ast at a time
        print("node_count:", node_count)

        with testing_enabled():
            new_source = pysource_minimize_testing.minimize(source, checker, retries=0)

        # Debug dump, deliberately disabled by ``and False``; remove that
        # operand to print every intermediate tree with its node weights.
        if session_config.verbose and False:
            print("\nnew_source:")
            print(new_source)
            tree = ast.parse(new_source)
            weights = dict(node_weights(tree))

            dump_tree(tree, lambda node: f"w={weights[node]}")

        assert count_nodes(new_source) == node_count - 1

        source = new_source

        # ``checker`` reads ``node_count`` from this scope, so it sees the new
        # target on the next round.
        node_count -= 1

173 

174 

@pytest.mark.parametrize(
    "file", [pytest.param(f, id=f.stem) for f in sample_dir.glob("*.py")]
)
def test_samples(file):
    """Run ``try_remove_one`` on one stored sample from ``sample_dir``.

    Samples that do not compile on the running interpreter are skipped. The
    source, its non-zero node weights and its AST dump are printed first so a
    failing run shows everything needed to debug it.
    """
    # Python source is UTF-8 by default; don't depend on the locale encoding.
    source = file.read_text(encoding="utf-8")

    try:
        compile(source, file, "exec")
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt must still abort the run.
        pytest.skip("the sample does not compile for the current python version")

    print("source")
    print(source)

    print("weights:")
    for n, v in node_weights(source):
        if v:
            print(f" {n}: {v}")
    print("ast")
    if sys.version_info >= (3, 9):
        # the ``indent`` argument of ast.dump only exists on 3.9+
        print(ast.dump(ast.parse(source), indent=2))
    else:
        print(ast.dump(ast.parse(source)))

    try_remove_one(source)

201 

202 

def generate_remove_one():
    """Look for a new sample on which ``try_remove_one`` fails.

    A random program is generated with ``pysource_codegen.generate``. If
    ``try_remove_one`` raises for it, the program is minimized to the smallest
    source that still makes ``try_remove_one`` raise. That source is stored in
    ``sample_dir`` under its SHA-256 hex digest, and
    ``ValueError("new sample found")`` is raised, chained to the original
    failure. If no failure occurs, the function returns ``None``.
    """
    seed = random.randrange(0, 100000000)

    source = generate(seed, node_limit=1000, depth_limit=6)

    try:
        try_remove_one(source)
    except Exception as failure:
        # find minimal source where it is not possible to remove one "node"

        def still_fails(candidate):
            try:
                try_remove_one(candidate)
            except Exception:
                return True

            return False

        min_source = minimize(source, still_fails)

        digest = hashlib.sha256(min_source.encode("utf-8")).hexdigest()
        # write with the same encoding that was used for the hash
        (sample_dir / f"{digest}.py").write_text(min_source, encoding="utf-8")

        raise ValueError("new sample found") from failure