Coverage for tests/test_needle.py: 48%

107 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-02-21 22:00 +0100

1import ast 

2import hashlib 

3import itertools 

4import random 

5import sys 

6from pathlib import Path 

7from typing import Any 

8 

9import pytest 

10from pysource_codegen import generate 

11 

12import pysource_minimize._minimize 

13from .utils import testing_enabled 

14 

15try: 

16 import pysource_minimize_testing # type: ignore 

17except ImportError: 

18 import pysource_minimize as pysource_minimize_testing 

19 

20 

21from pysource_minimize import minimize 

22from pysource_minimize._minimize import unparse 

23 

24 

# Directory next to this test module that stores the needle sample files;
# it is created on import so that globbing and writing samples always work.
sample_dir = Path(__file__).parent.joinpath("needle_samples")
sample_dir.mkdir(exist_ok=True)

# Identifier which is hidden inside generated code and has to be found again.
needle_name = "needle_17597"

30 

31 

def contains_one_needle(source):
    """Return True when *source* compiles and names the needle exactly once.

    Args:
        source: Python source text to check.

    Returns:
        False if ``compile`` rejects the source, otherwise whether
        ``needle_count(source)`` equals 1.
    """
    try:
        compile(source, "<string>", "exec")
    except Exception:
        # Any compile failure just means "not an interesting candidate".
        # Catching Exception (instead of a bare except) lets
        # KeyboardInterrupt and SystemExit propagate to the caller.
        return False
    return needle_count(source) == 1

38 

39 

def needle_count(source):
    """Count the ``ast.Name`` nodes in *source* whose id is ``needle_name``.

    The source is parsed with ``ast.parse``, so invalid source raises.
    """
    hits = 0
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Name) and node.id == needle_name:
            hits += 1
    return hits

46 

47 

def try_find_needle(source):
    """Assert that minimizing *source* leaves nothing but the needle name.

    *source* must already satisfy ``contains_one_needle``. It is minimized
    inside ``testing_enabled()`` with ``contains_one_needle`` as the checker
    and ``retries=0``; the stripped result has to equal ``needle_name``.

    Raises:
        AssertionError: if the precondition or the final comparison fails.
    """
    assert contains_one_needle(source)

    with testing_enabled():
        minimized = pysource_minimize_testing.minimize(
            source, contains_one_needle, retries=0
        )

    remaining = minimized.strip()
    assert remaining == needle_name

57 

58 

@pytest.mark.parametrize(
    "file", [pytest.param(f, id=f.stem) for f in sample_dir.glob("*.py")]
)
def test_needle(file):
    """Check that a stored sample can be minimized down to the needle.

    One test case is created per ``*.py`` file in ``sample_dir``. Samples
    which do not compile are skipped. The source (and, where supported, its
    AST) is printed up front so that it shows up in the captured output when
    ``try_find_needle`` fails.
    """
    source = file.read_text()

    try:
        compile(source, file, "exec")
    except Exception:
        # Only compile errors should lead to a skip; a bare except would
        # also turn KeyboardInterrupt/SystemExit into a skipped test.
        pytest.skip("sample does not compile with this interpreter")

    print("the following code can not be minimized to needle:")
    print(source)

    # ast.dump() accepts the indent argument starting with Python 3.9.
    if sys.version_info >= (3, 9):
        print()
        print("ast:")
        print(ast.dump(ast.parse(source), indent=2))

    try_find_needle(source)

79 

80 

class HideNeedle(ast.NodeTransformer):
    """Replace the ``num``-th visited expression node by the needle name.

    Expression nodes are numbered in visiting order through ``index``. After
    a visit ``needle_hidden`` tells whether a replacement took place.
    """

    def __init__(self, num):
        self.num = num
        self.index = 0
        self.needle_hidden = False

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Count expression nodes and swap the selected one for the needle."""
        if not isinstance(node, ast.expr):
            return super().generic_visit(node)

        position = self.index
        self.index += 1

        if position == self.num and not self.needle_hidden:
            self.needle_hidden = True
            print("replace", node, "with needle")
            return ast.Name(id=needle_name)

        return super().generic_visit(node)

    if sys.version_info >= (3, 10):

        def visit_Match(self, node: ast.Match) -> Any:
            """Visit only the subject and the case bodies of a match statement.

            Patterns and guards are not visited here, so the needle is never
            placed inside them by this method.
            """
            node.subject = self.visit(node.subject)
            for match_case in node.cases:
                match_case.body = list(map(self.visit, match_case.body))
            return node

106 

107 

108import sys 

109 

110 

def generate_needle():
    """Search randomly generated code for a needle the minimizer can not find.

    A random seed is drawn and printed, source code is generated from it and
    the needle is hidden at one expression position after the other. Every
    variant which compiles and contains exactly one needle is handed to
    ``try_find_needle``. The first variant for which that fails is minimized
    to a small reproducer, printed, written to ``sample_dir`` (file name is
    the sha256 of the source) and reported with a ``ValueError``.

    Returns:
        None, if the needle was found at every tried position.

    Raises:
        ValueError: after a new failing sample was written to ``sample_dir``.
    """
    seed = random.randrange(0, 100000000)
    print("seed:", seed)

    source = generate(seed, node_limit=10000, depth_limit=6)

    def checker(candidate):
        # A candidate is kept while it compiles, contains exactly one needle
        # and try_find_needle still fails for it.
        try:
            compile(candidate, "<string>", "exec")
        except Exception:
            return False

        if needle_count(candidate) != 1:
            return False

        try:
            try_find_needle(candidate)
        except Exception:
            return True

        return False

    for i in itertools.count():
        # Parse again in every round because the transformer changes the
        # tree it visits.
        original_tree = ast.parse(source)

        hide_needle = HideNeedle(i)
        needle_tree = hide_needle.visit(original_tree)

        if not hide_needle.needle_hidden:
            # No expression with index i exists: all positions were tried.
            break

        try:
            needle_source = unparse(needle_tree)
            compile(needle_source, "<string>", "exec")
        except Exception:
            print("skip this needle")
            continue

        if not contains_one_needle(needle_source):
            # match 0:
            #     case needle:
            # could be generated which can not be reduced to needle
            continue

        try:
            try_find_needle(needle_source)
        except Exception:
            # Exception instead of a bare except: an interrupt (Ctrl-C) must
            # not be recorded as a minimizer failure.
            print("minimize")

            new_source = minimize(needle_source, checker)
            print(new_source)

            digest = hashlib.sha256(new_source.encode("utf-8")).hexdigest()
            (sample_dir / f"{digest}.py").write_text(new_source)

            raise ValueError("new sample found")

166 

167 raise ValueError("new sample found")