Coverage for src/witchery/__init__.py: 100%
207 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-11-17 14:16 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2023-11-17 14:16 +0100
1"""
2This library has methods to guess which variables are unknown and to potentially (monkey-patch) fix this.
3"""
5# SPDX-FileCopyrightText: 2023-present Robin van der Noord <robinvandernoord@gmail.com>
6#
7# SPDX-License-Identifier: MIT
10import ast
11import builtins
12import contextlib
13import importlib
14import inspect
15import textwrap
16import typing
17import warnings
18from _ast import NamedExpr
19from typing import Any
21from typing_extensions import Self
# Names exposed by the `builtins` module (e.g. `print`, `len`); these never need an explicit definition.
BUILTINS = set(vars(builtins))
def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], None]) -> None:
    """
    Visit the given AST node and all of its descendants, calling the collector on each one.

    Nodes are visited depth-first with every parent handled before its children, and siblings
    in source order. An explicit stack is used instead of recursion so that very deeply nested
    trees do not hit Python's recursion limit.

    Args:
        node (ast.AST): The AST node to start traversing from.
        variable_collector (Callable): The function to apply on each visited node.
    """
    pending: list[ast.AST] = [node]
    while pending:
        current = pending.pop()
        variable_collector(current)
        # children are looked up only after the collector ran on their parent;
        # pushing them reversed makes the stack pop them in source order.
        pending.extend(reversed(list(ast.iter_child_nodes(current))))
def find_defined_variables(code_str: str) -> set[str]:
    """
    Parses the given Python code and finds all variables that are defined within.

    A defined variable refers to any variable that is assigned a value in the code through direct
    assignment (e.g. `x = 5`) or annotated assignment (e.g. `x: int = 5`). Tuple/list unpacking and
    starred targets are followed, and for a subscript target (`x[0] = 1`) the subscripted name
    is reported. Targets that do not bind a plain name (e.g. `obj.attr = 1`) are skipped.
    Other assignments such as through for-loops are ignored.
    Please use `find_variables` if more variable info is needed.

    This function does not account for scope - it will find variables defined anywhere in the
    provided code string.

    Args:
        code_str (str): A string of Python code.

    Returns:
        set[str]: A set of variable names that are defined within the provided Python code.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
    """
    tree: ast.Module = ast.parse(code_str)
    variables: set[str] = set()

    def collect_target(target: ast.expr) -> None:
        # unpacking: `a, (b, c) = ...` or `[a, b] = ...`
        if isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                collect_target(element)
        # `*rest` and `x[0]` both wrap the expression that actually holds the name
        elif isinstance(target, (ast.Starred, ast.Subscript)):
            collect_target(target.value)
        elif isinstance(target, ast.Name):
            variables.add(target.id)
        # anything else (e.g. attribute targets) does not define a variable name

    # visiting order is irrelevant because the results end up in a set
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                collect_target(target)
        elif isinstance(node, ast.AnnAssign):
            collect_target(node.target)

    return variables
def remove_specific_variables(code: str, to_remove: typing.Iterable[str] = ("db", "database")) -> str:
    """
    Removes the definitions of specific variables from the given code.

    A definition is an assignment (`db = ...`), an annotated assignment (`db: DAL = ...`),
    a function (`def db(): ...`, also async) or a class (`class db: ...`) whose name is in `to_remove`.
    Definitions are removed wherever they occur (module level, function bodies, else-branches, ...).
    When removing a statement leaves a block empty, a `pass` is inserted so the result stays valid Python.

    Args:
        code (str): The code from which to remove variables.
        to_remove (Iterable): An iterable of variable names to be removed (a single string means one name).

    Returns:
        str: The code after removing the specified variables.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    # a bare string is treated as one name, not as an iterable of characters
    names = {to_remove} if isinstance(to_remove, str) else set(to_remove)

    class DefinitionRemover(ast.NodeTransformer):
        def visit_Assign(self, node: ast.Assign) -> typing.Optional[ast.AST]:
            if any(isinstance(target, ast.Name) and target.id in names for target in node.targets):
                return None  # drop the whole statement
            return self.generic_visit(node)

        def visit_AnnAssign(self, node: ast.AnnAssign) -> typing.Optional[ast.AST]:
            if isinstance(node.target, ast.Name) and node.target.id in names:
                return None
            return self.generic_visit(node)

        def visit_FunctionDef(self, node: typing.Any) -> typing.Optional[ast.AST]:
            if node.name in names:
                return None
            return self.generic_visit(node)

        # async functions and classes are matched by name in exactly the same way
        visit_AsyncFunctionDef = visit_FunctionDef
        visit_ClassDef = visit_FunctionDef

        def generic_visit(self, node: ast.AST) -> ast.AST:
            # remember which statement blocks had content before children were removed
            populated = [
                field
                for field in ("body", "finalbody")
                if isinstance(getattr(node, field, None), list) and getattr(node, field)
            ]
            super().generic_visit(node)

            # an empty module is valid, any other emptied block needs a placeholder
            if not isinstance(node, ast.Module):
                for field in populated:
                    block = getattr(node, field)
                    if not block:
                        block.append(ast.Pass())
            return node

    new_tree = DefinitionRemover().visit(ast.parse(code))

    # Generate the modified code from the new AST
    return ast.unparse(new_tree)
def has_local_imports(code: str) -> bool:
    """
    Checks if the given code has local (relative) imports, such as `from . import x`.

    Imports in nested scopes (functions, classes, conditionals) are found as well.

    Args:
        code (str): The code to check for local imports.

    Returns:
        bool: True if local imports are found, False otherwise.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    # level > 0 means one or more leading dots, i.e. a relative import;
    # ast.walk already yields every node once, so no visitor is needed.
    return any(isinstance(node, ast.ImportFrom) and node.level > 0 for node in ast.walk(ast.parse(code)))
157class ImportRemover(ast.NodeTransformer):
158 """
159 Node visitor to remove imports (even in blocks).
160 """
162 def __init__(self, module_name: str) -> None:
163 """
164 Set the module name to remove.
165 """
166 self.module_name = module_name
168 def visit_Import(self, node: ast.Import) -> ast.Import | ast.Pass:
169 """
170 Removes `import module_name`.
171 """
172 node.names = [alias for alias in node.names if alias.name != self.module_name]
173 return node if node.names else ast.Pass()
175 def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | ast.Pass:
176 """
177 Removes `from module_name import xyz`.
178 """
179 if node.module == self.module_name:
180 return ast.Pass()
181 return node
def remove_import(code: str, module_name: str) -> str:
    """
    Strips every import of one module from the given code, nested scopes included.

    Args:
        code (str): The code from which to remove the import.
        module_name (str): The name of the module to remove.

    Returns:
        str: The code after removing the import of the specified module.
    """
    cleaned_tree = ImportRemover(module_name).visit(ast.parse(code))
    return ast.unparse(cleaned_tree)
def remove_local_imports(code: str) -> str:
    """
    Removes all local (relative) imports from the given code.

    When a relative import was the only statement of a block (e.g. a function body),
    a `pass` is left behind so the returned code is still valid Python.

    Args:
        code (str): The code from which to remove local imports.

    Returns:
        str: The code after removing all local imports.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """

    class RemoveLocalImports(ast.NodeTransformer):
        def visit_ImportFrom(self, node: ast.ImportFrom) -> typing.Optional[ast.ImportFrom]:
            # level > 0 means it's a relative import: remove the node
            return None if node.level > 0 else node

        def generic_visit(self, node: ast.AST) -> ast.AST:
            # remember which statement blocks had content before children were removed
            populated = [
                field
                for field in ("body", "finalbody")
                if isinstance(getattr(node, field, None), list) and getattr(node, field)
            ]
            super().generic_visit(node)

            # an empty module is valid, any other emptied block needs a placeholder
            if not isinstance(node, ast.Module):
                for field in populated:
                    block = getattr(node, field)
                    if not block:
                        block.append(ast.Pass())
            return node

    tree = RemoveLocalImports().visit(ast.parse(code))
    return ast.unparse(tree)
def find_function_to_call(code: str, function_call_hint: str) -> typing.Optional[str]:
    """
    Finds the function to call in the given code based on the function call hint.

    Only the part of the hint before the first `(` is used, with surrounding whitespace ignored.
    Both regular and async function definitions are considered, at any nesting level.

    Args:
        code (str): The code in which to find the function.
        function_call_hint (str): The hint for the function call, e.g. `main` or `main(db)`.

    Returns:
        str, optional: The name of the function to call if found, None otherwise.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    function_name = function_call_hint.split("(")[0].strip()  # Extract function name from hint

    for node in ast.walk(ast.parse(code)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
            return function_name

    return None
242DEFAULT_ARGS = ("db",)
245def extract_function_details(
246 function_call: str, default_args: typing.Iterable[str] = DEFAULT_ARGS
247) -> tuple[str | None, list[str]]:
248 """
249 Extracts the function name and arguments from the function call string.
251 Args:
252 function_call (str): The function call string.
253 default_args (Iterable, optional): The default arguments for the function.
255 Returns:
256 tuple: A tuple containing the function name and a list of arguments.
257 """
258 function_name = function_call.split("(")[0] # Extract function name from hint
259 if "(" not in function_call:
260 return function_name, list(default_args)
262 with contextlib.suppress(SyntaxError):
263 tree = ast.parse(function_call)
264 for node in ast.walk(tree):
265 if isinstance(node, ast.Call):
266 if len(node.args) == 0:
267 # If no arguments are given, add 'db' automatically
268 return function_name, list(default_args)
270 func = typing.cast(ast.Name, node.func)
271 return func.id, [ast.unparse(arg) for arg in node.args]
273 return None, []
def add_function_call(
    code: str, function_call: str, args: typing.Iterable[str] = DEFAULT_ARGS, multiple: bool = False
) -> str:
    """
    Inserts a call to a top-level function directly after that function's definition.

    Args:
        code (str): The code to which to add the function call.
        function_call (str): The function call string (e.g. `main` or `main(db)`).
        args (Iterable, optional): The arguments to use when the call string does not supply any.
        multiple (bool, optional): If True, add a call after every function with the specified name.

    Returns:
        str: The code after adding the function call.
    """
    function_name, call_args = extract_function_details(function_call, default_args=args)

    def to_expression(source: str) -> ast.Name:
        # the argument is parsed as a statement; its value is the expression we need (casts are for mypy)
        statement = typing.cast(ast.Expr, ast.parse(source).body[0])
        return typing.cast(ast.Name, statement.value)

    tree = ast.parse(code)

    # Build the `function_name(arg, ...)` statement once; it is reused for every insertion
    call_statement = ast.Expr(
        value=ast.Call(
            func=ast.Name(id=function_name, ctx=ast.Load()),
            args=[to_expression(source) for source in call_args],
            keywords=[],
        )
    )

    # Rebuild the module body, placing the call right behind the matching definition(s)
    rebuilt_body: list[ast.stmt] = []
    already_added = False
    for statement in tree.body:
        rebuilt_body.append(statement)

        if already_added and not multiple:
            continue

        if isinstance(statement, ast.FunctionDef) and statement.name == function_name:
            rebuilt_body.append(call_statement)
            already_added = True

    tree.body = rebuilt_body
    return ast.unparse(tree)
def find_variables(code_str: str, with_builtins: bool = True) -> tuple[set[str], set[str]]:
    """
    Finds all used and defined variables in the given code string.

    "Defined" covers names bound by assignment (also loop targets, `with ... as`, etc.) and the names
    bound by imports: `import a.b` binds `a`, `import a as b` binds `b`, `from m import x as y` binds `y`.
    For `from m import *` the module is imported to discover its public names.
    Scope is not taken into account.

    Args:
        code_str (str): The code string to parse for variables (common indentation is removed first).
        with_builtins (bool): include Python builtins in the defined names?

    Returns:
        tuple: A tuple containing sets of used and defined variables.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
    """
    # could raise SyntaxError
    tree: ast.Module = ast.parse(textwrap.dedent(code_str))

    used_variables: set[str] = set()
    defined_variables: set[str] = set()
    imported_names: set[str] = set()
    loop_variables: set[str] = set()

    def collect_name(node: ast.Name) -> None:
        if isinstance(node.ctx, ast.Load):
            used_variables.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            defined_variables.add(node.id)
        elif isinstance(node.ctx, ast.Del):
            # `del x` makes the name unknown again (relies on parents-first, source-order traversal)
            defined_variables.discard(node.id)

    def collect_target(target: ast.expr) -> None:
        # `x[0] = 1` counts as a definition of `x`
        if isinstance(target, ast.Subscript):
            target = target.value

        if isinstance(target, ast.Tuple):
            for element in target.elts:
                collect_target(element)
        elif var := getattr(target, "id", None):
            defined_variables.add(var)

    def collect_import(node: ast.Import) -> None:
        for alias in node.names:
            # `import a.b.c` only binds the top-level package name `a`
            imported_names.add(alias.asname or alias.name.split(".")[0])

    def collect_import_from(node: ast.ImportFrom) -> None:
        for alias in node.names:
            if alias.name != "*":
                imported_names.add(alias.asname or alias.name)
            elif node.module and node.level == 0:
                # NOTE: this executes the imported module; it is the only way to know what `*` provides.
                with contextlib.suppress(ImportError):
                    imported_module = importlib.import_module(node.module)
                    imported_names.update(name for name in dir(imported_module) if not name.startswith("_"))

    def collect_everything(node: ast.AST) -> None:
        if isinstance(node, ast.Name):
            collect_name(node)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                collect_target(target)
        elif isinstance(node, ast.AnnAssign):
            collect_target(node.target)
        elif isinstance(node, ast.Import):
            collect_import(node)
        elif isinstance(node, ast.ImportFrom):
            collect_import_from(node)
        elif isinstance(node, ast.For) and isinstance(node.target, ast.Name):
            # kept separately so a later `del` does not make the loop variable unknown
            loop_variables.add(node.target.id)

    traverse_ast(tree, collect_everything)

    all_variables = defined_variables | loop_variables | imported_names
    if with_builtins:
        all_variables |= BUILTINS

    return used_variables, all_variables
def find_missing_variables(code: str) -> set[str]:
    """
    Determines which variables are used in the given code without being defined.

    Args:
        code (str): The code to check for missing variables.

    Returns:
        set: A set of names of missing variables.
    """
    used, known = find_variables(code)
    # everything that is read somewhere but never defined, imported or built in
    return used - known
# Generic type used by `Empty.__add__`: adding anything to an Empty gives back that other value.
T = typing.TypeVar("T", bound=Any)


class Empty:
    # todo: overload more methods
    # class that does absolutely nothing
    # but can be accessed like an object (obj.something.whatever)
    # or a dict[with][some][keys]
    #
    # NOTE: `generate_magic_code` copies this class via `inspect.getsource(Empty)` into generated code,
    #       so it may only use names that the generated prelude defines (typing, Any, Self, T).
    def __init__(self, *_: Any, **__: Any) -> None:
        # accepts and ignores any arguments
        ...

    def __bool__(self) -> bool:
        # an Empty is always falsy
        return False

    def __getattribute__(self, _: str) -> Self:
        # every attribute lookup on an instance (obj.anything) yields the instance itself
        return self

    def __getitem__(self, _: str) -> Self:
        # every item lookup (obj[key]) yields the instance itself
        return self

    def __iter__(self) -> typing.Generator[Self, Any, None]:
        # fix set(empty)
        yield self  # once

    def __get__(self, *_: Any) -> Self:
        # descriptor access also yields the instance itself
        return self

    def __call__(self, *_: Any, **__: Any) -> Self:
        # calling with any arguments yields the instance itself
        return self

    def __str__(self) -> str:
        # renders as an empty string
        return ""

    def __repr__(self) -> str:
        # renders as an empty string
        return ""

    def __add__(self, other: T) -> T:
        # `empty + other` evaluates to `other`
        return other
def generate_magic_code(missing_vars: set[str]) -> str:
    """
    Generates code to define missing variables with a do-nothing object.

    After finding missing vars, fill them in with an object that does nothing except return itself or an empty string.
    This way, it's least likely to crash (when used as default or validator in pydal, don't use this for running code!).

    The variables are emitted in sorted order, so the same input always produces the same code.

    Args:
        missing_vars (set): The set of missing variable names.

    Returns:
        str: The generated code.
    """
    # names the copied `Empty` class source depends on
    prelude = (
        "import typing; from typing import Any; "
        "from typing_extensions import Self; "
        "T = typing.TypeVar('T', bound=Any); "
        "\n"
    )

    # one shared instance is assigned to every missing variable
    assignments = "".join(f"{variable} = empty; " for variable in sorted(missing_vars))

    extra_code = prelude + inspect.getsource(Empty) + "\n\n" + "empty = Empty()" + "\n" + assignments

    return textwrap.dedent(extra_code)