Coverage for src/witchery/__init__.py: 100%

223 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-11-17 16:01 +0100

1""" 

2This library has methods to guess which variables are unknown and to potentially (monkey-patch) fix this. 

3""" 

4 

5# SPDX-FileCopyrightText: 2023-present Robin van der Noord <robinvandernoord@gmail.com> 

6# 

7# SPDX-License-Identifier: MIT 

8 

9 

10import ast 

11import builtins 

12import contextlib 

13import importlib 

14import inspect 

15import textwrap 

16import typing 

17import warnings 

18from _ast import NamedExpr 

19from typing import Any 

20 

21from typing_extensions import Self 

22 

# Every name exposed by the `builtins` module; `find_variables` can treat these as already defined.
BUILTINS = set(builtins.__dict__.keys())

24 

25 

def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], None]) -> None:
    """
    Walk the AST rooted at `node` and call `variable_collector` on every node that is reached.

    Nodes are visited depth-first in pre-order: a parent before its children, children in the order
    `ast.iter_child_nodes` yields them. The walk keeps its own stack instead of recursing,
    so a deeply nested tree cannot exhaust Python's recursion limit.

    Args:
        node (ast.AST): The AST node to start from.
        variable_collector (Callable): The function to apply on each node.
    """
    pending: list[ast.AST] = [node]
    while pending:
        current = pending.pop()
        variable_collector(current)
        # Push children reversed so that the first child is popped (and therefore visited) first.
        pending.extend(reversed(list(ast.iter_child_nodes(current))))

37 

38 

def find_defined_variables(code_str: str) -> set[str]:
    """
    Parses the given Python code and finds all variables that are defined within.

    A defined variable refers to any variable that is assigned a value in the code through direct assignment
    (e.g. `x = 5` or `x: int = 5`). Tuple/list unpacking (`a, (b, c) = ...`, `[a, b] = ...`), starred targets
    (`a, *rest = ...`) and item assignment (`d['k'] = ...`, which reports `d`) are resolved to the names involved.
    Attribute targets (`obj.attr = ...`) and other assignments such as through for-loops are ignored.
    Please use `find_variables` if more variable info is needed.

    This function does not account for scope - it will find variables defined anywhere in the provided code string.

    Args:
        code_str (str): A string of Python code.

    Returns:
        set[str]: A set of variable names that are defined within the provided Python code.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
    """
    defined: set[str] = set()

    def collect_target(target: ast.expr) -> None:
        # Unpacking targets nest arbitrarily, so resolve them recursively.
        if isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                collect_target(element)
        elif isinstance(target, (ast.Subscript, ast.Starred)):
            # `x[0] = ...` and `*x` both ultimately refer to the wrapped expression.
            collect_target(target.value)
        elif isinstance(target, ast.Name):
            defined.add(target.id)

    # The result is an unordered set, so the visiting order of ast.walk is irrelevant here.
    for node in ast.walk(ast.parse(code_str)):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                collect_target(target)
        elif isinstance(node, ast.AnnAssign):
            collect_target(node.target)

    return defined

88 

89 

90class IfBlockRemover(ast.NodeTransformer): 

91 """ 

92 Remove if False or if typing.TYPE_CHECKING. 

93 """ 

94 

95 def visit_If(self, node: ast.If) -> ast.AST | None: 

96 """ 

97 Modify if statements. 

98 """ 

99 if isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING": 

100 new_node = ast.copy_location(ast.Pass(), node) 

101 return ast.copy_location(ast.If(test=node.test, body=[new_node], orelse=node.orelse), node) 

102 

103 if (isinstance(node.test, ast.Constant) and node.test.value is False) or ( 

104 isinstance(node.test, ast.Attribute) 

105 and isinstance(node.test.value, ast.Name) 

106 and node.test.value.id == "typing" 

107 and node.test.attr == "TYPE_CHECKING" 

108 ): 

109 return None 

110 

111 return self.generic_visit(node) 

112 

113 

def remove_if_falsey_blocks(code: str) -> str:
    """
    Run `IfBlockRemover` over `code` and return the resulting source.

    Args:
        code (str): Python source to process.

    Returns:
        str: The source regenerated by `ast.unparse` after the transformer has been applied.
    """
    return ast.unparse(IfBlockRemover().visit(ast.parse(code)))

122 

123 

124def remove_specific_variables(code: str, to_remove: typing.Iterable[str] = ("db", "database")) -> str: 

125 """ 

126 Removes specific variables from the given code. 

127 

128 Args: 

129 code (str): The code from which to remove variables. 

130 to_remove (Iterable): An iterable of variable names to be removed. 

131 

132 Returns: 

133 str: The code after removing the specified variables. 

134 """ 

135 # Parse the code into an Abstract Syntax Tree (AST) 

136 tree = ast.parse(code) 

137 

138 # Function to check if a variable name is 'db' or 'database' 

139 def should_remove(var_name: str) -> bool: 

140 return var_name in to_remove 

141 

142 # Function to recursively traverse the AST and remove lines with 'db' or 'database' definitions 

143 def remove_desired_variable_refs(node: ast.AST) -> typing.Optional[ast.AST]: 

144 if isinstance(node, ast.Assign): 

145 # Check if any of the assignment targets contain 'db' or 'database' 

146 if any(isinstance(target, ast.Name) and should_remove(target.id) for target in node.targets): 

147 return None 

148 

149 elif isinstance(node, (ast.FunctionDef, ast.ClassDef)) and should_remove(node.name): 

150 return None 

151 

152 # doesn't work well without list() !!! 

153 for child_node in list(ast.iter_child_nodes(node)): 

154 new_child_node = remove_desired_variable_refs(child_node) 

155 if new_child_node is None and hasattr(node, "body"): 

156 node.body.remove(child_node) 

157 

158 return node 

159 

160 # Traverse the AST to remove 'db' and 'database' definitions 

161 new_tree = remove_desired_variable_refs(tree) 

162 

163 if not new_tree: # pragma: no cover 

164 return "" 

165 

166 # Generate the modified code from the new AST 

167 return ast.unparse(new_tree) 

168 

169 

def has_local_imports(code: str) -> bool:
    """
    Checks if the given code has local imports.

    A local import is a relative `from ... import ...` statement, i.e. one whose module path starts with
    one or more dots (`from . import x`, `from ..pkg import y`). Imports in nested scopes are found as well.

    Args:
        code (str): The code to check for local imports.

    Returns:
        bool: True if local imports are found, False otherwise.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    # `level` counts the leading dots of a `from` import; anything above zero is relative.
    return any(isinstance(node, ast.ImportFrom) and node.level > 0 for node in ast.walk(ast.parse(code)))

190 

191 

192class ImportRemover(ast.NodeTransformer): 

193 """ 

194 Node visitor to remove imports (even in blocks). 

195 """ 

196 

197 def __init__(self, module_name: str) -> None: 

198 """ 

199 Set the module name to remove. 

200 """ 

201 self.module_name = module_name 

202 

203 def visit_Import(self, node: ast.Import) -> ast.Import | ast.Pass: 

204 """ 

205 Removes `import module_name`. 

206 """ 

207 node.names = [alias for alias in node.names if alias.name != self.module_name] 

208 return node if node.names else ast.Pass() 

209 

210 def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | ast.Pass: 

211 """ 

212 Removes `from module_name import xyz`. 

213 """ 

214 if node.module == self.module_name: 

215 return ast.Pass() 

216 return node 

217 

218 

def remove_import(code: str, module_name: str) -> str:
    """
    Strip every import of `module_name` from `code`, nested scopes included.

    Args:
        code (str): The code from which to remove the import.
        module_name (str): The name of the module to remove.

    Returns:
        str: The regenerated code without imports of the module. When `module_name` is empty,
            a warning is emitted and `code` is returned untouched.
    """
    if module_name:
        cleaned_tree = ImportRemover(module_name).visit(ast.parse(code))
        return ast.unparse(cleaned_tree)

    # nothing to remove
    warnings.warn("`remove_import` called without module name!")
    return code

239 

240 

def remove_local_imports(code: str) -> str:
    """
    Removes all local imports from the given code.

    A local import is a relative `from ... import ...` statement (its module path starts with a dot).
    When removing one leaves a block without statements (for example the `try:` of a
    `try: from .x import y / except ImportError: ...` construct), a `pass` is inserted so that
    the returned code remains valid Python.

    Args:
        code (str): The code from which to remove local imports.

    Returns:
        str: The code after removing all local imports.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """

    class RemoveLocalImports(ast.NodeTransformer):
        def generic_visit(self, node: ast.AST) -> ast.AST:
            super().generic_visit(node)
            body = getattr(node, "body", None)
            # Only a module may have an empty body; every other block needs at least one statement.
            if isinstance(body, list) and not body and not isinstance(node, ast.Module):
                body.append(ast.Pass())
            return node

        def visit_ImportFrom(self, node: ast.ImportFrom) -> typing.Optional[ast.ImportFrom]:
            if node.level > 0:  # This means it's a relative import
                return None  # Remove the node
            return node  # Keep the node

    return ast.unparse(RemoveLocalImports().visit(ast.parse(code)))

261 

262 

def find_function_to_call(code: str, function_call_hint: str) -> typing.Optional[str]:
    """
    Look up the function named by `function_call_hint` among the `def` statements in `code`.

    Args:
        code (str): The code in which to find the function.
        function_call_hint (str): The hint for the function call; everything from the first `(` on is ignored.

    Returns:
        str, optional: The function name when a matching (non-async) `def` exists anywhere in the code,
            None otherwise.
    """
    # Everything before the first "(" is taken as the function name.
    function_name, _, _ = function_call_hint.partition("(")

    for candidate in ast.walk(ast.parse(code)):
        if isinstance(candidate, ast.FunctionDef) and candidate.name == function_name:
            return function_name

    return None

280 

281 

282DEFAULT_ARGS = ("db",) 

283 

284 

285def extract_function_details( 

286 function_call: str, default_args: typing.Iterable[str] = DEFAULT_ARGS 

287) -> tuple[str | None, list[str]]: 

288 """ 

289 Extracts the function name and arguments from the function call string. 

290 

291 Args: 

292 function_call (str): The function call string. 

293 default_args (Iterable, optional): The default arguments for the function. 

294 

295 Returns: 

296 tuple: A tuple containing the function name and a list of arguments. 

297 """ 

298 function_name = function_call.split("(")[0] # Extract function name from hint 

299 if "(" not in function_call: 

300 return function_name, list(default_args) 

301 

302 with contextlib.suppress(SyntaxError): 

303 tree = ast.parse(function_call) 

304 for node in ast.walk(tree): 

305 if isinstance(node, ast.Call): 

306 if len(node.args) == 0: 

307 # If no arguments are given, add 'db' automatically 

308 return function_name, list(default_args) 

309 

310 func = typing.cast(ast.Name, node.func) 

311 return func.id, [ast.unparse(arg) for arg in node.args] 

312 

313 return None, [] 

314 

315 

def add_function_call(
    code: str, function_call: str, args: typing.Iterable[str] = DEFAULT_ARGS, multiple: bool = False
) -> str:
    """
    Insert a call to a function directly after that function's top-level definition.

    Args:
        code (str): The code to which to add the function call.
        function_call (str): The function call string; name and arguments are taken from it
            via `extract_function_details`.
        args (Iterable, optional): The arguments to use when `function_call` does not specify any.
        multiple (bool, optional): If True, add a call after every function with the specified name;
            otherwise only after the first one.

    Returns:
        str: The regenerated code including the function call(s). Only definitions at module level are
            considered; when none matches, the code is regenerated without additions.
    """
    function_name, call_args = extract_function_details(function_call, default_args=args)

    tree = ast.parse(code)

    # Every argument is a source snippet; parse it and take the expression out of its statement wrapper.
    argument_nodes = []
    for argument in call_args:
        wrapper = typing.cast(ast.Expr, ast.parse(argument).body[0])
        argument_nodes.append(wrapper.value)

    call_statement = ast.Expr(
        value=ast.Call(
            func=ast.Name(id=function_name, ctx=ast.Load()),
            args=argument_nodes,
            keywords=[],
        )
    )

    # Rebuild the module body, placing the call right behind the matching definition(s).
    rebuilt_body = []
    inserted = False
    for statement in tree.body:
        rebuilt_body.append(statement)
        if inserted and not multiple:
            continue
        if isinstance(statement, ast.FunctionDef) and statement.name == function_name:
            rebuilt_body.append(call_statement)
            inserted = True
    tree.body = rebuilt_body

    return ast.unparse(tree)

355 

356 

def find_variables(code_str: str, with_builtins: bool = True) -> tuple[set[str], set[str]]:
    """
    Finds all used and defined variables in the given code string.

    "Used" means every name that is read (load context). "Defined" means every name that is:

    - stored to (assignment, loop/with/comprehension targets, walrus) and not deleted afterwards with `del`;
    - the base of an item assignment (`x[0] = ...` defines `x`);
    - bound by `import a.b` (binds `a`), `import a as b` (binds `b`) or `from a import b [as c]`;
    - exported by a module that is star-imported (`from a import *`); this is the only case in which
      the module is actually imported, and a failing import is ignored;
    - a Python builtin, when `with_builtins` is set.

    Function/class names and function parameters are not recorded as definitions, and scope is not
    taken into account. `from . import x` (a relative import without module name) is skipped.

    Args:
        code_str (str): The code string to parse for variables (it is dedented first).
        with_builtins (bool): include Python builtins?

    Returns:
        tuple: A tuple containing sets of used and defined variables.

    Raises:
        SyntaxError: If the code can not be parsed.
    """
    code_str = textwrap.dedent(code_str)

    # could raise SyntaxError
    tree: ast.Module = ast.parse(code_str)

    used_variables: set[str] = set()
    defined_variables: set[str] = set()
    imported_modules: set[str] = set()
    imported_names: set[str] = set()
    loop_variables: set[str] = set()

    def collect_variables(node: ast.AST) -> None:
        """
        Collect or remove variables based on load/store and delete statements.
        """
        if not isinstance(node, ast.Name):
            return

        if isinstance(node.ctx, ast.Load):
            used_variables.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            defined_variables.add(node.id)
        elif isinstance(node.ctx, ast.Del):
            defined_variables.discard(node.id)

    def collect_target(target: ast.expr) -> None:
        """
        Resolve an assignment target (possibly nested) to the names it defines.
        """
        if isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                collect_target(element)
        elif isinstance(target, (ast.Subscript, ast.Starred)):
            # `x[0] = ...` has `x` in load context, so collect_variables alone would miss it.
            collect_target(target.value)
        elif isinstance(target, ast.Name):
            defined_variables.add(target.id)

    def collect_definitions(node: ast.AST) -> None:
        """
        Collect variable definitions made through (annotated) assignment statements.
        """
        if isinstance(node, ast.Assign):
            for target in node.targets:
                collect_target(target)
        elif isinstance(node, ast.AnnAssign):
            collect_target(node.target)

    def collect_imports(node: ast.AST) -> None:
        """
        Get defined variables via `import ...` and `from ... import ...`.
        """
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b` binds `a`; `import a.b as c` binds `c`.
                imported_modules.add(alias.asname or alias.name.split(".")[0])
            return

        if not (isinstance(node, ast.ImportFrom) and node.module):
            return

        # note: for a star import this also records the literal "*", which is harmless.
        imported_names.update(alias.asname or alias.name for alias in node.names)

        # Only a star import requires looking inside the module; an absolute module name is needed for that.
        if node.level == 0 and node.names[0].name == "*":
            with contextlib.suppress(ImportError):
                imported_module = importlib.import_module(node.module)
                imported_names.update(name for name in dir(imported_module) if not name.startswith("_"))

    def collect_loop_variables(node: ast.AST) -> None:
        """
        Get variables defined in a loop (for var in ...).
        """
        if isinstance(node, ast.For) and isinstance(node.target, ast.Name):
            loop_variables.add(node.target.id)

    def collect_everything(node: ast.AST) -> None:
        """
        Run the functions above to get all variables from the code.
        """
        collect_variables(node)
        collect_definitions(node)
        collect_imports(node)
        collect_loop_variables(node)

    # Pre-order traversal: the order matters for `del`, which discards an earlier definition.
    traverse_ast(tree, collect_everything)

    all_variables = (
        defined_variables | imported_modules | loop_variables | imported_names | (BUILTINS if with_builtins else set())
    )

    return used_variables, all_variables

472 

473 

def find_missing_variables(code: str) -> set[str]:
    """
    Determine which names the given code reads without them being defined.

    Args:
        code (str): The code to check for missing variables.

    Returns:
        set: The names `find_variables` reports as used but not as defined (builtins count as defined).
    """
    used, known = find_variables(code)
    return used - known

486 

487 

# Generic type used by `Empty.__add__`, which hands back its right-hand operand unchanged.
T = typing.TypeVar("T", bound=Any)

489 

490 

class Empty:
    """
    Class that does absolutely nothing.

    It can be accessed like an object (obj.something.whatever)
    or a dict[with][some][keys], it can be called, iterated over and added to;
    it is falsey and both `str()` and `repr()` give an empty string.

    NOTE: `generate_magic_code` copies the source of this class (via `inspect.getsource`) into the code it
    generates, so the class body may only rely on names that function's prelude provides:
    `typing`, `Any`, `Self` and `T`.
    """

    # todo: overload more methods

    def __init__(self, *_: Any, **__: Any) -> None:
        """
        Accept (and ignore) any positional and keyword arguments.
        """

    def __bool__(self) -> bool:
        """
        An `empty` object is False so it can be `or`-ed.
        """
        return False

    def __getattribute__(self, _: str) -> Self:
        """
        Accessing .something: every attribute name gives back the object itself.
        """
        return self

    def __getitem__(self, _: str) -> Self:
        """
        Accessing ['something']: every key gives back the object itself.
        """
        return self

    def __iter__(self) -> typing.Generator[Self, Any, None]:
        """
        Allows `for _ in Empty():`.

        Only yields one item, itself.
        """
        # fix set(empty)
        yield self  # once

    def __get__(self, *_: Any) -> Self:
        """
        Called when empty is set as a property on another class; the lookup gives back the object itself.
        """
        return self

    def __call__(self, *_: Any, **__: Any) -> Self:
        """
        When an instance gets called, any arguments are ignored and the object itself is returned.

        empty = Empty()
        empty()
        """
        return self

    def __str__(self) -> str:
        """
        String conversion gives an empty string.
        """
        return ""

    def __repr__(self) -> str:
        """
        The representation is an empty string as well.
        """
        return ""

    def __add__(self, other: T) -> T:
        """
        Overloads +: the right-hand operand is returned unchanged.

        empty + [] = []
        """
        return other

567 

568 

def generate_magic_code(missing_vars: set[str]) -> str:
    """
    Generates code to define missing variables with a do-nothing object.

    After finding missing vars, fill them in with an object that does nothing except return itself or an empty string.
    This way, it's least likely to crash (when used as default or validator in pydal, don't use this for running code!).

    The generated code consists of the imports the `Empty` class needs, the source of `Empty` itself,
    one shared `empty = Empty()` instance and an assignment of that instance to every missing variable.
    Variables are assigned in sorted order, so the same input always yields the same code.

    Args:
        missing_vars (set): The set of missing variable names.

    Returns:
        str: The generated code.
    """
    # Names the source of `Empty` refers to; they must exist before the class definition is executed.
    prelude = (
        "import typing; from typing import Any; "
        "from typing_extensions import Self; "
        "T = typing.TypeVar('T', bound=Any); "
        "\n"
    )

    # sorted(): a set has no stable iteration order, which would make the output differ between runs.
    assignments = "".join(f"{variable} = empty; " for variable in sorted(missing_vars))

    extra_code = "".join(
        (
            prelude,
            inspect.getsource(Empty),
            "\n\n",
            "empty = Empty()",
            "\n",
            assignments,
        )
    )

    return textwrap.dedent(extra_code)