Coverage for tests/test_needle.py: 48%
107 statements
coverage.py v6.5.0, created at 2024-02-21 22:00 +0100
1import ast
2import hashlib
3import itertools
4import random
5import sys
6from pathlib import Path
7from typing import Any
9import pytest
10from pysource_codegen import generate
12import pysource_minimize._minimize
13from .utils import testing_enabled
15try:
16 import pysource_minimize_testing # type: ignore
17except ImportError:
18 import pysource_minimize as pysource_minimize_testing
21from pysource_minimize import minimize
22from pysource_minimize._minimize import unparse
# Identifier that is hidden inside generated programs; the minimizer is
# expected to reduce each such program down to exactly this name.
needle_name = "needle_17597"

# Regression samples (programs the minimizer could not reduce to the needle)
# are stored next to this test file. Create the directory up front so both
# globbing in the parametrized test and writing new samples work.
sample_dir = Path(__file__).parent.joinpath("needle_samples")
sample_dir.mkdir(exist_ok=True)
def contains_one_needle(source):
    """Return True if *source* compiles and references the needle name exactly once.

    This is the predicate handed to the minimizer, so an invalid candidate
    program simply yields False instead of raising.
    """
    try:
        compile(source, "<string>", "exec")
    except Exception:
        # compile() raises SyntaxError for invalid code and ValueError for null
        # bytes (SyntaxError on 3.12+), and may hit resource errors on
        # pathologically deep input. Catching Exception rather than using a
        # bare except keeps KeyboardInterrupt/SystemExit working during long
        # minimization runs.
        return False
    return needle_count(source) == 1
def needle_count(source):
    """Count the ``ast.Name`` nodes in *source* whose identifier is the needle.

    *source* is parsed with ``ast.parse``; a SyntaxError propagates to the caller.
    """
    count = 0
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Name) and node.id == needle_name:
            count += 1
    return count
def try_find_needle(source):
    """Assert that the minimizer reduces *source* to exactly the needle name.

    *source* must already be a compilable program containing a single needle.
    Minimization runs inside ``testing_enabled()`` with ``retries=0``, using
    ``contains_one_needle`` as the predicate. An AssertionError is raised if
    the precondition fails or if the minimized output is anything other than
    the bare needle.
    """
    # Precondition: the input itself is a valid one-needle program.
    assert contains_one_needle(source)

    predicate = contains_one_needle
    with testing_enabled():
        reduced = pysource_minimize_testing.minimize(source, predicate, retries=0)

    stripped = reduced.strip()
    assert stripped == needle_name
@pytest.mark.parametrize(
    # sorted() keeps the collection order deterministic across runs/workers;
    # raw glob order depends on the filesystem.
    "file",
    [pytest.param(f, id=f.stem) for f in sorted(sample_dir.glob("*.py"))],
)
def test_needle(file):
    """Regression test: a stored sample must minimize down to the bare needle.

    Samples that do not compile on the running interpreter are skipped.
    """
    source = file.read_text()

    try:
        compile(source, file, "exec")
    except Exception:
        # Not a bare except: Ctrl-C should abort the run, not skip the test.
        pytest.skip("sample source does not compile")

    # Printed before the check so pytest shows the offending program on failure.
    print("the following code can not be minimized to needle:")
    print(source)

    if sys.version_info >= (3, 9):
        # ast.dump(..., indent=...) is only available from Python 3.9 on.
        print()
        print("ast:")
        print(ast.dump(ast.parse(source), indent=2))

    try_find_needle(source)
class HideNeedle(ast.NodeTransformer):
    """AST transformer that swaps one expression for the needle name.

    Expressions are numbered in the order ``generic_visit`` meets them (a
    pre-order walk). The expression whose number equals ``num`` is replaced by
    ``ast.Name(id=needle_name)``, and its subtree is not visited further.
    After ``visit``:

    - ``needle_hidden`` tells whether such an expression existed;
    - ``index`` holds how many expressions were numbered.
    """

    def __init__(self, num):
        self.num = num
        self.index = 0
        self.needle_hidden = False

    def generic_visit(self, node: ast.AST) -> ast.AST:
        # Non-expression nodes are just traversed.
        if not isinstance(node, ast.expr):
            return super().generic_visit(node)

        position = self.index
        self.index = position + 1
        if position == self.num and not self.needle_hidden:
            self.needle_hidden = True
            print("replace", node, "with needle")
            return ast.Name(id=needle_name)
        return super().generic_visit(node)

    if sys.version_info >= (3, 10):

        def visit_Match(self, node: ast.Match) -> Any:
            # Only the subject and the case bodies are searched; patterns and
            # guards are left as-is (presumably so the needle never ends up
            # inside a pattern, where it would not be an ast.Name).
            node.subject = self.visit(node.subject)
            for match_case in node.cases:
                match_case.body = list(map(self.visit, match_case.body))
            return node
108import sys
def generate_needle(seed=None):
    """Search a generated program for a needle position the minimizer misses.

    A program is produced with ``generate(seed, ...)``. Then, for each
    expression position ``i`` in turn, that expression is replaced by the
    needle name via ``HideNeedle(i)``, and ``try_find_needle`` checks that the
    minimizer reduces the result back to the bare needle. The first position
    where this fails is minimized while keeping the failure, then written to
    ``sample_dir`` as a regression sample.

    Args:
        seed: Seed for the program generator. ``None`` (the default) picks a
            random seed. The seed is printed either way, so a failing run can
            be replayed by passing it back in.

    Raises:
        ValueError: After a new failing sample has been written to
            ``sample_dir``; it is chained to the minimizer failure.
    """
    if seed is None:
        seed = random.randrange(0, 100000000)
    print("seed:", seed)

    source = generate(seed, node_limit=10000, depth_limit=6)

    def checker(candidate):
        # Minimizer predicate: keep *candidate* while it is still a valid
        # one-needle program that the minimizer fails to reduce to the needle.
        if not contains_one_needle(candidate):
            return False
        try:
            try_find_needle(candidate)
        except Exception:
            return True
        return False

    for i in itertools.count():
        # HideNeedle mutates the tree it visits, so start from a fresh parse
        # for every position.
        hide_needle = HideNeedle(i)
        needle_tree = hide_needle.visit(ast.parse(source))

        if not hide_needle.needle_hidden:
            # i is past the last expression: every position has been tried.
            break

        try:
            needle_source = unparse(needle_tree)
            compile(needle_source, "<string>", "exec")
        except Exception:
            # The substituted tree does not round-trip to valid Python.
            print("skip this needle")
            continue

        if not contains_one_needle(needle_source):
            # match 0:
            #     case needle:
            # could be generated which can not be reduced to needle
            continue

        try:
            try_find_needle(needle_source)
        except Exception as err:
            # The minimizer missed this needle: shrink the program while
            # keeping the failure, and store it as a regression sample named
            # after its content hash.
            print("minimize")
            new_source = minimize(needle_source, checker)
            print(new_source)
            digest = hashlib.sha256(new_source.encode("utf-8")).hexdigest()
            (sample_dir / f"{digest}.py").write_text(new_source)
            raise ValueError("new sample found") from err