Coverage for src/pydal2sql_core/magic.py: 62%
143 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-31 19:47 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-31 19:47 +0200
1"""
2This file has methods to guess which variables are unknown and to potentially (monkey-patch) fix this.
3"""
6import ast
7import builtins
8import contextlib
9import importlib
10import textwrap
11import typing
# Every name the interpreter provides without an import (print, len, Exception, ...);
# find_variables treats these as always defined.
BUILTINS = set(vars(builtins))
def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], None]) -> None:
    """
    Calls variable_collector on node and on every descendant node.

    Nodes are visited depth-first, each parent before its children, children in the order
    ``ast.iter_child_nodes`` yields them -- the same order a recursive walk produces.
    An explicit stack is used instead of recursion so that very deep trees (for example a long
    ``1 + 1 + ... + 1`` chain, which nests one BinOp per term) do not hit Python's recursion limit.
    """
    stack: list[ast.AST] = [node]
    while stack:
        current = stack.pop()
        variable_collector(current)
        # push the children reversed so the first child is popped (and therefore visited) next
        stack.extend(reversed(list(ast.iter_child_nodes(current))))
def find_defined_variables(code_str: str) -> set[str]:
    """
    Return the names bound by assignment statements (``x = ...``) anywhere in code_str.

    Names inside tuple/list/starred targets (``a, *b = ...``) are included. Attribute and
    subscript targets (``obj.x = ...``, ``d[k] = ...``) bind no new name and are skipped.

    Raises:
        SyntaxError: if code_str is not valid Python.
    """
    tree: ast.Module = ast.parse(code_str)

    variables: set[str] = set()
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            # Walk the target to reach names nested in unpacking; the Store check excludes names
            # that are only read while assigning, such as `obj` in `obj.x = 1` or `k` in `d[k] = v`.
            variables.update(
                sub.id
                for sub in ast.walk(target)
                if isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Store)
            )
    return variables
def remove_specific_variables(code: str, to_remove: typing.Iterable[str] = ("db", "database")) -> str:
    """
    Strip definitions of the given names from code and return the regenerated source.

    - Plain-name assignment targets with a matching name are dropped (``x = db = 1`` becomes
      ``x = 1``). When no target is left, the right-hand side is kept as a bare expression
      statement (``db = DAL()`` becomes ``DAL()``).
    - ``def``, ``async def`` and ``class`` definitions with a matching name are removed entirely.
    - Any block whose body becomes empty gets a ``pass`` so the output is still valid Python.

    The whole tree is processed, including nested blocks and every statement list
    (bodies, ``else``/``finally`` parts, exception handlers).

    Args:
        code: Python source to clean up.
        to_remove: names whose definitions should be removed; a single string is one name.

    Raises:
        SyntaxError: if code is not valid Python.
    """
    # Parse the code into an Abstract Syntax Tree (AST)
    tree = ast.parse(code)

    # Materialize once: a generator would be exhausted by the first `in` check,
    # and `in` on a plain string would do substring matching ("d" in "db").
    names = frozenset((to_remove,) if isinstance(to_remove, str) else to_remove)

    class RemoveDefinitions(ast.NodeTransformer):
        def visit_Assign(self, node: ast.Assign) -> ast.stmt:
            node.targets = [
                target for target in node.targets if not (isinstance(target, ast.Name) and target.id in names)
            ]
            if node.targets:
                return node
            # nothing left to assign to: keep the right-hand side as an expression statement
            return ast.copy_location(ast.Expr(value=node.value), node)

        def _drop_if_named(
            self, node: typing.Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]
        ) -> typing.Optional[ast.AST]:
            if node.name in names:
                # returning None makes NodeTransformer delete the node from its parent's list
                return None
            return self.generic_visit(node)

        visit_FunctionDef = visit_AsyncFunctionDef = visit_ClassDef = _drop_if_named

        def generic_visit(self, node: ast.AST) -> ast.AST:
            node = super().generic_visit(node)
            body = getattr(node, "body", None)
            # compound statements need a non-empty body; the module itself may legitimately be empty
            if isinstance(body, list) and not body and not isinstance(node, ast.Module):
                node.body = [ast.Pass()]  # type: ignore[attr-defined]
            return node

    # Traverse the AST to remove the definitions
    new_tree = RemoveDefinitions().visit(tree)

    if not new_tree:  # pragma: no cover
        return ""

    # Generate the modified code from the new AST
    return ast.unparse(new_tree)
def find_local_imports(code: str) -> bool:
    """
    Return True if code contains at least one relative import anywhere in the tree.

    Relative imports are ``from`` imports with one or more leading dots
    (``from . import x``, ``from ..pkg import y``), i.e. ``ImportFrom.level > 0``.

    Raises:
        SyntaxError: if code is not valid Python.
    """
    tree = ast.parse(code)
    # a single pass over all nodes; any() stops at the first relative import found
    return any(isinstance(node, ast.ImportFrom) and node.level > 0 for node in ast.walk(tree))
def remove_import(code: str, module_name: typing.Optional[str]) -> str:
    """
    Remove top-level imports of module_name from code and return the regenerated source.

    - ``import a, b``: only the aliases whose module is module_name are dropped; the statement
      disappears only when no alias is left.
    - ``from module_name import ...``: the whole statement is dropped. With ``module_name=None``
      this matches module-less relative imports such as ``from . import x``.

    Only statements directly in the module body are inspected.

    Raises:
        SyntaxError: if code is not valid Python.
    """
    tree = ast.parse(code)
    new_body: list[ast.stmt] = []
    for node in tree.body:
        if isinstance(node, ast.Import):
            remaining = [alias for alias in node.names if alias.name != module_name]
            if not remaining:
                continue
            # keep unrelated modules imported in the same statement (`import os, sys`)
            node.names = remaining
        elif isinstance(node, ast.ImportFrom) and node.module == module_name:
            continue
        new_body.append(node)
    tree.body = new_body
    return ast.unparse(tree)
def remove_local_imports(code: str) -> str:
    """
    Remove every relative import (``from . import x``, ``from ..pkg import y``) from code.

    A block whose body consisted only of relative imports (for example ``try: from .x import y``
    or a function containing just such an import) gets a ``pass`` so the regenerated source is
    still valid Python.

    Raises:
        SyntaxError: if code is not valid Python.
    """

    class RemoveLocalImports(ast.NodeTransformer):
        def visit_ImportFrom(self, node: ast.ImportFrom) -> typing.Optional[ast.ImportFrom]:
            if node.level > 0:  # This means it's a relative import
                return None  # Remove the node
            return node  # Keep the node

        def generic_visit(self, node: ast.AST) -> ast.AST:
            node = super().generic_visit(node)
            body = getattr(node, "body", None)
            # compound statements need a non-empty body; the module itself may legitimately be empty
            if isinstance(body, list) and not body and not isinstance(node, ast.Module):
                node.body = [ast.Pass()]  # type: ignore[attr-defined]
            return node

    tree = ast.parse(code)
    tree = RemoveLocalImports().visit(tree)
    return ast.unparse(tree)
117# def find_function_to_call(code: str, target: str) -> typing.Optional[ast.FunctionDef]:
118# tree = ast.parse(code)
119# for node in ast.walk(tree):
120# if isinstance(node, ast.FunctionDef) and node.name == target:
121# return node
def find_function_to_call(code: str, function_call_hint: str) -> typing.Optional[str]:
    """
    Return the function name from function_call_hint if code defines a function with that name.

    The hint may be a bare name (``define_tables``) or a call (``define_tables(db)``); everything
    from the first ``(`` on is ignored, as is surrounding whitespace. Only regular ``def``
    statements are matched, at any nesting level.

    Returns:
        The function name, or None when no matching ``def`` exists.

    Raises:
        SyntaxError: if code is not valid Python.
    """
    # Extract function name from hint; strip so " define_tables (db)" still matches
    function_name = function_call_hint.split("(")[0].strip()
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == function_name:
            return function_name
    return None
131# def add_function_call(code: str, function_name: str) -> str:
132# tree = ast.parse(code)
133# # Create a function call node
134# func_call = ast.Expr(
135# value=ast.Call(
136# func=ast.Name(id=function_name, ctx=ast.Load()),
137# args=[ast.Name(id='db', ctx=ast.Load())],
138# keywords=[]
139# )
140# )
141#
142# # Insert the function call right after the function definition
143# for i, node in enumerate(tree.body):
144# if isinstance(node, ast.FunctionDef) and node.name == function_name:
145# tree.body.insert(i + 1, func_call)
146# break
147#
148# return ast.unparse(tree)
150# def extract_function_details(function_call):
151# function_name = function_call.split('(')[0] # Extract function name from hint
152# if '(' in function_call:
153# try:
154# tree = ast.parse(function_call)
155# for node in ast.walk(tree):
156# if isinstance(node, ast.Call):
157# return node.func.id, [ast.unparse(arg) for arg in node.args]
158# except SyntaxError:
159# pass
160# return function_name, []
def extract_function_details(function_call: str) -> tuple[typing.Optional[str], list[str]]:
    """
    Split a function-call hint into the function name and its positional argument sources.

    - ``"name"`` (no parentheses) and ``"name()"`` (no positional arguments) default to ``["db"]``.
    - ``"name(a, b)"`` returns ``["a", "b"]``; keyword arguments are ignored.
    - Dotted callees (``"pkg.setup(db)"``) return the text before ``(`` as the name.
    - Surrounding whitespace is ignored.

    Returns:
        ``(name, args)``, or ``(None, [])`` when the hint cannot be parsed or contains no call.
    """
    # Extract function name from hint
    function_name = function_call.split("(")[0].strip()
    if "(" not in function_call:
        return function_name, ["db"]

    try:
        # strip: a leading space would otherwise be an IndentationError
        tree = ast.parse(function_call.strip())
    except SyntaxError:
        return None, []

    call = next((node for node in ast.walk(tree) if isinstance(node, ast.Call)), None)
    if call is None:
        return None, []

    # `pkg.setup(db)` has an Attribute callee with no `.id`; fall back to the text before "("
    name = call.func.id if isinstance(call.func, ast.Name) else function_name

    # If no arguments are given, add 'db' automatically
    if not call.args:
        return name, ["db"]
    return name, [ast.unparse(arg) for arg in call.args]
def add_function_call(code, function_call):
    """
    Insert a call ``name(args...)`` right after the first top-level ``def name`` in code.

    The name and argument sources come from ``extract_function_details(function_call)``;
    each argument string is parsed and its expression reused as a positional argument.
    Only top-level regular ``def`` statements are matched. When none matches, the source is
    returned re-generated (parse + unparse) with no call inserted.
    """
    name, arg_sources = extract_function_details(function_call)

    module = ast.parse(code)

    call_args = []
    for source in arg_sources:
        # "db" -> Module(body=[Expr(value=Name('db'))]) -> Name('db')
        call_args.append(ast.parse(source).body[0].value)

    invocation = ast.Expr(
        value=ast.Call(
            func=ast.Name(id=name, ctx=ast.Load()),
            args=call_args,
            keywords=[],
        )
    )

    position = next(
        (
            index
            for index, statement in enumerate(module.body)
            if isinstance(statement, ast.FunctionDef) and statement.name == name
        ),
        None,
    )
    if position is not None:
        module.body.insert(position + 1, invocation)

    return ast.unparse(module)
def find_variables(code_str: str) -> tuple[set[str], set[str]]:
    """
    Look through the source code in code_str and use ast parsing to detect used and known names.

    A name is known when it is assigned, imported, used as a ``for`` loop target, or is a builtin.
    The analysis ignores scopes: a name assigned anywhere (even inside a function) counts as
    known everywhere. Nodes are processed in traversal order, so a later ``del x`` forgets an
    earlier ``x = ...``.

    Names brought in by ``from module import *`` are resolved by importing the module. This is
    best effort: a module that cannot be imported contributes no names, and relative star
    imports are not resolved. No other import form imports anything.

    Returns:
        ``(used_variables, known_variables)``; names in the first set but not in the second are
        the undefined ones.

    Raises:
        SyntaxError: if the dedented code_str is not valid Python.
    """
    # Partly made by ChatGPT
    code_str = textwrap.dedent(code_str)

    # could raise SyntaxError
    tree: ast.Module = ast.parse(code_str)

    used_variables: set[str] = set()
    defined_variables: set[str] = set()
    imported_names: set[str] = set()
    loop_variables: set[str] = set()

    def collect_variables(node: ast.AST) -> None:
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load):
                used_variables.add(node.id)
            elif isinstance(node.ctx, ast.Store):
                defined_variables.add(node.id)
            elif isinstance(node.ctx, ast.Del):
                defined_variables.discard(node.id)

    def collect_definitions(node: ast.AST) -> None:
        if isinstance(node, ast.Assign):
            # Only bare-name targets carry an `.id`; `a, b = ...`, `obj.x = ...` and `d[k] = ...`
            # are skipped here. Names nested in tuple/list targets still reach defined_variables
            # through their Store context in collect_variables.
            defined_variables.update(target.id for target in node.targets if isinstance(target, ast.Name))

    def collect_imports(node: ast.AST) -> None:
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import x as y` binds y; `import a.b.c` binds only a
                imported_names.add(alias.asname or alias.name.partition(".")[0])
        elif isinstance(node, ast.ImportFrom) and node.module:
            if node.names[0].name != "*":
                imported_names.update(alias.asname or alias.name for alias in node.names)
            elif node.level == 0:
                # Only a star import needs the real module to know which names it provides.
                with contextlib.suppress(ImportError):
                    imported_module = importlib.import_module(node.module)
                    imported_names.update(name for name in dir(imported_module) if not name.startswith("_"))
            # relative star imports (`from .x import *`) cannot be resolved without package context

    def collect_loop_variables(node: ast.AST) -> None:
        if isinstance(node, ast.For) and isinstance(node.target, ast.Name):
            # tracked apart from defined_variables, so a `del` of the loop variable does not remove it
            loop_variables.add(node.target.id)

    def collect_everything(node: ast.AST) -> None:
        collect_variables(node)
        collect_definitions(node)
        collect_imports(node)
        collect_loop_variables(node)

    # manually rewritten (2.19s for 10k):
    traverse_ast(tree, collect_everything)

    all_variables = defined_variables | loop_variables | imported_names | BUILTINS

    return used_variables, all_variables
def find_missing_variables(code: str) -> set[str]:
    """
    Return the names that code reads but never defines, imports, loops over or gets from builtins.
    """
    used, known = find_variables(code)
    return used.difference(known)
def generate_magic_code(missing_vars: set[str]) -> str:
    """
    After finding missing vars, fill them in with an object that does nothing except return itself or an empty string.

    This way, it's least likely to crash (when used as default or validator in pydal, don't use this for running code!).

    The generated source defines ``class Empty`` and one shared ``empty = Empty()`` instance,
    followed by one ``<name> = empty`` line per missing variable. The names are sorted, so the
    output does not depend on set iteration order.
    """
    # Dedent the template on its own so the appended assignments cannot change its margin.
    class_code = textwrap.dedent(
        """
        class Empty:
            # class that does absolutely nothing
            # but can be accessed like an object (obj.something.whatever)
            # or a dict[with][some][keys]
            def __getattribute__(self, _):
                return self

            def __getitem__(self, _):
                return self

            def __get__(self, *_):
                # descriptor protocol passes (instance, owner) when used as a class attribute
                return self

            def __call__(self, *_, **__):
                # also accept keyword arguments, e.g. missing_validator(error_message='...')
                return self

            def __iter__(self):
                # without this, iteration falls back to __getitem__(0), (1), ... and never stops
                return iter(())

            def __str__(self):
                return ''

            def __repr__(self):
                return ''

        # todo: overload more methods
        empty = Empty()
        """
    )
    # sorted: iterating a set of strings gives a different order per interpreter run
    assignments = "".join(f"{variable} = empty\n" for variable in sorted(missing_vars))
    return f"{class_code}\n{assignments}"