Coverage for src/witchery/__init__.py: 100%
223 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-11-17 16:01 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2023-11-17 16:01 +0100
1"""
2This library has methods to guess which variables are unknown and to potentially (monkey-patch) fix this.
3"""
5# SPDX-FileCopyrightText: 2023-present Robin van der Noord <robinvandernoord@gmail.com>
6#
7# SPDX-License-Identifier: MIT
10import ast
11import builtins
12import contextlib
13import importlib
14import inspect
15import textwrap
16import typing
17import warnings
18from _ast import NamedExpr
19from typing import Any
21from typing_extensions import Self
# Every name that is available without import (taken from the `builtins` module namespace).
BUILTINS: set[str] = set(vars(builtins))
def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], None]) -> None:
    """
    Visit `node` and every node below it depth-first, calling `variable_collector` on each one.

    Nodes are visited in pre-order: a parent before its children, and children in the order
    `ast.iter_child_nodes` yields them. An explicit stack is used instead of recursion.

    Args:
        node (ast.AST): The root of the (sub)tree to walk.
        variable_collector (Callable): Invoked once per visited node.
    """
    pending: list[ast.AST] = [node]
    while pending:
        current = pending.pop()
        variable_collector(current)
        # push in reverse so the first child is popped (and therefore visited) first
        pending.extend(reversed(list(ast.iter_child_nodes(current))))
def find_defined_variables(code_str: str) -> set[str]:
    """
    Parses the given Python code and finds all variables that are defined within.

    A defined variable refers to any variable that is assigned a value in the code through direct assignment
    (e.g. `x = 5`, `x: int = 5`, `a, (b, c) = ...`, `[a, *rest] = ...`; `x[0] = ...` counts the container `x`).
    Other assignments such as through for-loops are ignored.
    Please use `find_variables` if more variable info is needed.

    This function does not account for scope - it will find variables defined anywhere in the provided code string.

    Args:
        code_str (str): A string of Python code.

    Returns:
        set[str]: A set of variable names that are defined within the provided Python code.

    Raises:
        SyntaxError: if `code_str` is not valid Python.
    """
    tree: ast.Module = ast.parse(code_str)

    variables: set[str] = set()

    def handle_target(target: ast.expr) -> None:
        """
        Record the plain name(s) bound by one assignment target, recursing into tuples/lists.
        """
        # `x[0] = ...` and `x[0][1] = ...` are attributed to the container `x`
        while isinstance(target, ast.Subscript):
            target = target.value
        # `a, *rest = ...` still binds `rest`
        if isinstance(target, ast.Starred):
            target = target.value
        if isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                handle_target(element)
        elif isinstance(target, ast.Name):
            variables.add(target.id)

    # Visiting order does not matter here: results are collected into a set.
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                handle_target(target)
        elif isinstance(node, ast.AnnAssign):
            handle_target(node.target)

    return variables
90class IfBlockRemover(ast.NodeTransformer):
91 """
92 Remove if False or if typing.TYPE_CHECKING.
93 """
95 def visit_If(self, node: ast.If) -> ast.AST | None:
96 """
97 Modify if statements.
98 """
99 if isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING":
100 new_node = ast.copy_location(ast.Pass(), node)
101 return ast.copy_location(ast.If(test=node.test, body=[new_node], orelse=node.orelse), node)
103 if (isinstance(node.test, ast.Constant) and node.test.value is False) or (
104 isinstance(node.test, ast.Attribute)
105 and isinstance(node.test.value, ast.Name)
106 and node.test.value.id == "typing"
107 and node.test.attr == "TYPE_CHECKING"
108 ):
109 return None
111 return self.generic_visit(node)
def remove_if_falsey_blocks(code: str) -> str:
    """
    Strip dead `if` blocks (e.g. `if False:` / `if typing.TYPE_CHECKING:`) from `code`, as decided by `IfBlockRemover`.

    Args:
        code (str): Python source code.

    Returns:
        str: The transformed source, regenerated with `ast.unparse`.
    """
    cleaned_tree = IfBlockRemover().visit(ast.parse(code))
    return ast.unparse(cleaned_tree)
124def remove_specific_variables(code: str, to_remove: typing.Iterable[str] = ("db", "database")) -> str:
125 """
126 Removes specific variables from the given code.
128 Args:
129 code (str): The code from which to remove variables.
130 to_remove (Iterable): An iterable of variable names to be removed.
132 Returns:
133 str: The code after removing the specified variables.
134 """
135 # Parse the code into an Abstract Syntax Tree (AST)
136 tree = ast.parse(code)
138 # Function to check if a variable name is 'db' or 'database'
139 def should_remove(var_name: str) -> bool:
140 return var_name in to_remove
142 # Function to recursively traverse the AST and remove lines with 'db' or 'database' definitions
143 def remove_desired_variable_refs(node: ast.AST) -> typing.Optional[ast.AST]:
144 if isinstance(node, ast.Assign):
145 # Check if any of the assignment targets contain 'db' or 'database'
146 if any(isinstance(target, ast.Name) and should_remove(target.id) for target in node.targets):
147 return None
149 elif isinstance(node, (ast.FunctionDef, ast.ClassDef)) and should_remove(node.name):
150 return None
152 # doesn't work well without list() !!!
153 for child_node in list(ast.iter_child_nodes(node)):
154 new_child_node = remove_desired_variable_refs(child_node)
155 if new_child_node is None and hasattr(node, "body"):
156 node.body.remove(child_node)
158 return node
160 # Traverse the AST to remove 'db' and 'database' definitions
161 new_tree = remove_desired_variable_refs(tree)
163 if not new_tree: # pragma: no cover
164 return ""
166 # Generate the modified code from the new AST
167 return ast.unparse(new_tree)
def has_local_imports(code: str) -> bool:
    """
    Checks if the given code has local imports.

    A local import is a relative `from` import (`from . import x`, `from ..pkg import y`) at any nesting depth.

    Args:
        code (str): The code to check for local imports.

    Returns:
        bool: True if local imports are found, False otherwise.

    Raises:
        SyntaxError: if `code` is not valid Python.
    """
    tree = ast.parse(code)
    # One flat pass over all nodes; `level > 0` means the import is relative.
    return any(isinstance(node, ast.ImportFrom) and node.level > 0 for node in ast.walk(tree))
192class ImportRemover(ast.NodeTransformer):
193 """
194 Node visitor to remove imports (even in blocks).
195 """
197 def __init__(self, module_name: str) -> None:
198 """
199 Set the module name to remove.
200 """
201 self.module_name = module_name
203 def visit_Import(self, node: ast.Import) -> ast.Import | ast.Pass:
204 """
205 Removes `import module_name`.
206 """
207 node.names = [alias for alias in node.names if alias.name != self.module_name]
208 return node if node.names else ast.Pass()
210 def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | ast.Pass:
211 """
212 Removes `from module_name import xyz`.
213 """
214 if node.module == self.module_name:
215 return ast.Pass()
216 return node
def remove_import(code: str, module_name: str) -> str:
    """
    Removes the import of a specific module from the given code, including inner scopes.

    Both `import module_name` and `from module_name import ...` are handled (via `ImportRemover`).
    An empty `module_name` leaves the code untouched and emits a warning.

    Args:
        code (str): The code from which to remove the import.
        module_name (str): The name of the module to remove.

    Returns:
        str: The code after removing the import of the specified module.
    """
    if module_name:
        transformed = ImportRemover(module_name).visit(ast.parse(code))
        return ast.unparse(transformed)

    # nothing to remove
    warnings.warn("`remove_import` called without module name!")
    return code
241def remove_local_imports(code: str) -> str:
242 """
243 Removes all local imports from the given code.
245 Args:
246 code (str): The code from which to remove local imports.
248 Returns:
249 str: The code after removing all local imports.
250 """
252 class RemoveLocalImports(ast.NodeTransformer):
253 def visit_ImportFrom(self, node: ast.ImportFrom) -> typing.Optional[ast.ImportFrom]:
254 if node.level > 0: # This means it's a relative import
255 return None # Remove the node
256 return node # Keep the node
258 tree = ast.parse(code)
259 tree = RemoveLocalImports().visit(tree)
260 return ast.unparse(tree)
def find_function_to_call(code: str, function_call_hint: str) -> typing.Optional[str]:
    """
    Look up the function named by `function_call_hint` among the `def`s in `code`.

    The name is everything before the first `(` of the hint. Functions at any depth count
    (e.g. methods), but only regular `def`s, not `async def`s.

    Args:
        code (str): Python source to search.
        function_call_hint (str): A call such as `"setup(db)"` or a bare name.

    Returns:
        str, optional: The function name when a matching `def` exists, otherwise None.
    """
    wanted, _, _ = function_call_hint.partition("(")  # Extract function name from hint
    syntax_tree = ast.parse(code)
    is_defined = any(
        isinstance(candidate, ast.FunctionDef) and candidate.name == wanted for candidate in ast.walk(syntax_tree)
    )
    return wanted if is_defined else None
282DEFAULT_ARGS = ("db",)
285def extract_function_details(
286 function_call: str, default_args: typing.Iterable[str] = DEFAULT_ARGS
287) -> tuple[str | None, list[str]]:
288 """
289 Extracts the function name and arguments from the function call string.
291 Args:
292 function_call (str): The function call string.
293 default_args (Iterable, optional): The default arguments for the function.
295 Returns:
296 tuple: A tuple containing the function name and a list of arguments.
297 """
298 function_name = function_call.split("(")[0] # Extract function name from hint
299 if "(" not in function_call:
300 return function_name, list(default_args)
302 with contextlib.suppress(SyntaxError):
303 tree = ast.parse(function_call)
304 for node in ast.walk(tree):
305 if isinstance(node, ast.Call):
306 if len(node.args) == 0:
307 # If no arguments are given, add 'db' automatically
308 return function_name, list(default_args)
310 func = typing.cast(ast.Name, node.func)
311 return func.id, [ast.unparse(arg) for arg in node.args]
313 return None, []
def add_function_call(
    code: str, function_call: str, args: typing.Iterable[str] = DEFAULT_ARGS, multiple: bool = False
) -> str:
    """
    Adds a function call to the given code.

    The call statement is placed directly after a top-level `def` whose name matches the callee returned by
    `extract_function_details`. Only the first such `def` is used unless `multiple` is set. Nested and `async`
    functions are not considered. The whole module is regenerated with `ast.unparse`, so the original formatting
    and comments are not preserved.

    Args:
        code (str): The code to which to add the function call.
        function_call (str): The function call string, e.g. `"setup(db)"` or just `"setup"`.
        args (Iterable, optional): Arguments to use when `function_call` does not specify positional ones.
        multiple (bool, optional): If True, add a call after every function with the specified name.

    Returns:
        str: The code after adding the function call.
    """
    function_name, call_args = extract_function_details(function_call, default_args=args)

    def arg_value(arg: str) -> ast.expr:
        # `arg` is parsed as a module. For an expression, `body[0]` is an `ast.Expr` wrapping it.
        # An assignment such as `x=1` also has `.value`, which then yields `1`.
        statement = ast.parse(arg).body[0]
        return typing.cast(ast.Expr, statement).value

    tree = ast.parse(code)
    # Create the call statement once (arguments are parsed eagerly, so invalid ones fail fast).
    # The tree is only unparsed afterwards, so reusing the node for several insertions is harmless.
    func_call = ast.Expr(
        value=ast.Call(
            func=ast.Name(id=function_name, ctx=ast.Load()),
            args=[arg_value(arg) for arg in call_args],
            keywords=[],
        )
    )

    # Build a new statement list instead of inserting into `tree.body` while enumerating it.
    new_body: list[ast.stmt] = []
    inserted = False
    for node in tree.body:
        new_body.append(node)
        if (multiple or not inserted) and isinstance(node, ast.FunctionDef) and node.name == function_name:
            # Insert the function call right after the function definition
            new_body.append(func_call)
            inserted = True
    tree.body = new_body

    return ast.unparse(tree)
def find_variables(code_str: str, with_builtins: bool = True) -> tuple[set[str], set[str]]:
    """
    Finds all used and defined variables in the given code string.

    Scope is not taken into account: a name bound anywhere in the code counts as available everywhere.
    A name counts as available when it is
    - stored to (assignment, augmented assignment, walrus, `with ... as`, comprehension or loop target),
    - bound as a function/lambda parameter, a `def`/`class` name, an `except ... as` name or a `match` capture,
    - imported (`import a.b` binds `a`, `import a as b` binds `b`, `from m import x as y` binds `y`;
      the public names of an absolute star-imported module are found by importing it),
    - or, if requested, a builtin.

    `del x` removes `x` from the stored/bound names again. Nodes are visited depth-first in pre-order,
    so a `del` after a definition undoes it.

    Args:
        code_str (str): The code string to parse for variables (it is dedented first).
        with_builtins (bool): include Python builtins?

    Returns:
        tuple: A tuple containing the set of used variable names and the set of available (defined) names.

    Raises:
        SyntaxError: if `code_str` is not valid Python.
    """
    # Partly made by ChatGPT
    code_str = textwrap.dedent(code_str)

    # could raise SyntaxError
    tree: ast.Module = ast.parse(code_str)

    used_variables: set[str] = set()
    defined_variables: set[str] = set()
    imported_names: set[str] = set()
    loop_variables: set[str] = set()

    def collect_variables(node: ast.AST) -> None:
        """
        Collect or remove plain names based on their load/store/delete context.
        """
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load):
                used_variables.add(node.id)
            elif isinstance(node.ctx, ast.Store):
                defined_variables.add(node.id)
            elif isinstance(node.ctx, ast.Del):
                defined_variables.discard(node.id)

    def handle_target(target: ast.expr) -> None:
        """
        Record the plain name(s) bound by one assignment target, recursing into tuples/lists.
        """
        # `x[0] = ...` is attributed to the container `x`
        while isinstance(target, ast.Subscript):
            target = target.value
        if isinstance(target, ast.Starred):
            target = target.value
        if isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                handle_target(element)
        elif isinstance(target, ast.Name):
            defined_variables.add(target.id)

    def collect_definitions(node: ast.AST) -> None:
        """
        Collect variable definitions from assignment statements (including `x[0] = ...`).
        """
        if isinstance(node, ast.Assign):
            for target in node.targets:
                handle_target(target)
        elif isinstance(node, ast.AnnAssign):
            handle_target(node.target)

    def collect_bindings(node: ast.AST) -> None:
        """
        Collect names bound by constructs whose names are plain strings rather than `ast.Name` nodes.
        """
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            defined_variables.add(node.name)
        elif isinstance(node, ast.arg):
            # function and lambda parameters, including *args / **kwargs
            defined_variables.add(node.arg)
        elif isinstance(node, (ast.ExceptHandler, ast.MatchAs, ast.MatchStar)) and node.name:
            defined_variables.add(node.name)
        elif isinstance(node, ast.MatchMapping) and node.rest:
            defined_variables.add(node.rest)

    def collect_imports(node: ast.AST) -> None:
        """
        Collect names bound by `import` and `from ... import` statements.
        """
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b` binds `a`; `import a.b as c` binds `c`
                imported_names.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name != "*":
                    # also covers relative imports such as `from . import x` (where `node.module` is None)
                    imported_names.add(alias.asname or alias.name)
                elif node.module and not node.level:
                    # The names of a star import can only be known by importing the (absolute) module.
                    # NOTE(review): this executes module code named by the analyzed source; don't feed untrusted code.
                    with contextlib.suppress(ImportError):
                        imported_module = importlib.import_module(node.module)
                        imported_names.update(name for name in dir(imported_module) if not name.startswith("_"))

    def collect_loop_variables(node: ast.AST) -> None:
        """
        Get variables defined in a loop (for var in ...).
        """
        if isinstance(node, (ast.For, ast.AsyncFor)) and isinstance(node.target, ast.Name):
            loop_variables.add(node.target.id)

    def collect_everything(node: ast.AST) -> None:
        """
        Run the functions above to get all variables from the code.
        """
        collect_variables(node)
        collect_definitions(node)
        collect_bindings(node)
        collect_imports(node)
        collect_loop_variables(node)

    # traverse_ast visits depth-first in pre-order, which the `del` handling above relies on
    traverse_ast(tree, collect_everything)

    all_variables = defined_variables | loop_variables | imported_names | (BUILTINS if with_builtins else set())

    return used_variables, all_variables
def find_missing_variables(code: str) -> set[str]:
    """
    Return the names that `code` uses but that `find_variables` does not report as available.

    Args:
        code (str): Python source to inspect.

    Returns:
        set: Names that are read somewhere in the code without being available.
    """
    used, available = find_variables(code)
    return used - available
488T = typing.TypeVar("T", bound=Any)
491class Empty:
492 """
493 Class that does absolutely nothing.
495 but can be accessed like an object (obj.something.whatever)
496 or a dict[with][some][keys]
497 """
499 # todo: overload more methods
501 def __init__(self, *_: Any, **__: Any) -> None:
502 """
503 Can be passed any vars.
504 """
506 def __bool__(self) -> bool:
507 """
508 An `empty` object is False so it can be `or`-ed.
509 """
510 return False
512 def __getattribute__(self, _: str) -> Self:
513 """
514 Accessing .something.
515 """
516 return self
518 def __getitem__(self, _: str) -> Self:
519 """
520 Accessing ['something'].
521 """
522 return self
524 def __iter__(self) -> typing.Generator[Self, Any, None]:
525 """
526 Allows `for _ in Empty():`.
528 Only yields one item, itself.
529 """
530 # fix set(empty)
531 yield self # once
533 def __get__(self, *_: Any) -> Self:
534 """
535 Called when empty is set as a property on another class.
536 """
537 return self
539 def __call__(self, *_: Any, **__: Any) -> Self:
540 """
541 When an instance gets called.
543 empty = Empty()
544 empty()
545 """
546 return self
548 def __str__(self) -> str:
549 """
550 Empty string represent.
551 """
552 return ""
554 def __repr__(self) -> str:
555 """
556 Empty string represent.
557 """
558 return ""
560 def __add__(self, other: T) -> T:
561 """
562 Overlaods +.
564 empty + [] = []
565 """
566 return other
def generate_magic_code(missing_vars: set[str]) -> str:
    """
    Generates code to define missing variables with a do-nothing object.

    After finding missing vars, fill them in with an object that does nothing except return itself or an empty string.
    This way, it's least likely to crash (when used as default or validator in pydal, don't use this for running code!).

    The generated code consists of the imports `Empty` needs, the source of `Empty`, an `empty` instance,
    and a `name = empty; ` assignment per variable. Variables are emitted in sorted order so the output is
    identical for identical input (iterating a set of strings directly varies with hash randomization).

    Args:
        missing_vars (set): The set of missing variable names.

    Returns:
        str: The generated code.
    """
    extra_code = (
        "import typing; from typing import Any; "
        "from typing_extensions import Self; "
        "T = typing.TypeVar('T', bound=Any); "
        "\n"
    )

    extra_code += inspect.getsource(Empty)

    extra_code += "\n\n"
    extra_code += "empty = Empty()"
    extra_code += "\n"

    extra_code += "".join(f"{variable} = empty; " for variable in sorted(missing_vars))

    return textwrap.dedent(extra_code)