Coverage for src/witchery/__init__.py: 100%

207 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-11-17 14:16 +0100

"""
Guess which variables a piece of Python code uses without ever defining them, and optionally
"fix" this (monkey-patch style) by defining them as do-nothing placeholder objects.
"""

4 

5# SPDX-FileCopyrightText: 2023-present Robin van der Noord <robinvandernoord@gmail.com> 

6# 

7# SPDX-License-Identifier: MIT 

8 

9 

10import ast 

11import builtins 

12import contextlib 

13import importlib 

14import inspect 

15import textwrap 

16import typing 

17import warnings 

18from _ast import NamedExpr 

19from typing import Any 

20 

21from typing_extensions import Self 

22 

# Every name the `builtins` module provides (print, len, Exception, ...);
# find_variables treats these as defined when `with_builtins` is True.
BUILTINS = set(vars(builtins))

24 

25 

def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], None]) -> None:
    """
    Walk an AST depth-first (pre-order) and hand every node to `variable_collector`.

    A parent is always passed to the callback before any of its descendants, and the
    children of a node are visited in the order `ast.iter_child_nodes` yields them.

    Args:
        node (ast.AST): Root of the (sub)tree to walk.
        variable_collector (Callable): Callback invoked once for every node.
    """
    pending: list[ast.AST] = [node]
    while pending:
        current = pending.pop()
        variable_collector(current)
        # push the children reversed so the first child is popped (and fully explored) first
        pending.extend(reversed(list(ast.iter_child_nodes(current))))

37 

38 

def find_defined_variables(code_str: str) -> set[str]:
    """
    Parses the given Python code and finds all variables that are defined within.

    A defined variable refers to any variable that is assigned a value in the code through direct assignment
    (e.g. `x = 5`, `x: int = 5`, `a, (b, c) = ...`, `first, *rest = ...`, `[p, q] = ...`).
    Other assignments such as through for-loops are ignored.
    Please use `find_variables` if more variable info is needed.

    For item assignments like `x[0] = 5` (or `x[0][1] = 5`) the subscripted base name `x` is reported.
    Attribute targets such as `self.x = 5` do not define a plain variable and are skipped.

    This function does not account for scope - it will find variables defined anywhere in the provided code string.

    Args:
        code_str (str): A string of Python code.

    Returns:
        set[str]: A set of variable names that are defined within the provided Python code.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
    """
    tree: ast.Module = ast.parse(code_str)

    variables: set[str] = set()

    # visiting order is irrelevant for building a set, so a plain ast.walk suffices
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            pending: list[ast.expr] = list(node.targets)
        elif isinstance(node, ast.AnnAssign):
            pending = [node.target]
        else:
            # only look for variable definitions here!
            continue

        while pending:
            target = pending.pop()
            # report the base name for item assignments and for starred targets (`a, *rest = ...`)
            while isinstance(target, (ast.Subscript, ast.Starred)):
                target = target.value

            if isinstance(target, (ast.Tuple, ast.List)):
                # unpacking: inspect every element, however deeply nested
                pending.extend(target.elts)
            elif isinstance(target, ast.Name):
                variables.add(target.id)
            # anything else (e.g. `obj.attr = ...`) does not define a variable

    return variables

87 

88 

def remove_specific_variables(code: str, to_remove: typing.Iterable[str] = ("db", "database")) -> str:
    """
    Removes specific variables from the given code.

    The following statements are deleted, at any nesting depth (function/class bodies, `if`/`else`,
    loops, `try`/`except`/`finally`, `with`, ...):
    - assignments where at least one target is a plain name in `to_remove` (`db = ...`, `db: DAL = ...`);
      for chained assignments such as `x = db = ...` the whole statement is removed;
    - `def`, `async def` and `class` statements whose name is in `to_remove`.

    If a removal empties a block that Python requires to be non-empty (e.g. a function body),
    a `pass` statement is inserted so the result is still valid Python. An emptied `else:` branch
    is simply dropped.

    Args:
        code (str): The code from which to remove variables.
        to_remove (Iterable): An iterable of variable names to be removed. It is read exactly once,
            so one-shot iterators such as generators work too.

    Returns:
        str: The code after removing the specified variables.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    # Materialize once: repeated `in` checks against a generator would consume it.
    names = frozenset(to_remove)

    # Parse the code into an Abstract Syntax Tree (AST)
    tree = ast.parse(code)

    def should_remove(statement: ast.stmt) -> bool:
        if isinstance(statement, ast.Assign):
            return any(isinstance(target, ast.Name) and target.id in names for target in statement.targets)
        if isinstance(statement, ast.AnnAssign):
            return isinstance(statement.target, ast.Name) and statement.target.id in names
        if isinstance(statement, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            return statement.name in names
        return False

    def prune(node: ast.AST) -> None:
        # Look at every field, not only `body`: statements also live in `orelse`,
        # `finalbody`, exception handlers and `match` cases.
        for field, value in ast.iter_fields(node):
            if isinstance(value, list):
                kept = [item for item in value if not (isinstance(item, ast.stmt) and should_remove(item))]
                for item in kept:
                    if isinstance(item, ast.AST):
                        prune(item)
                # An empty block would unparse to invalid code (e.g. `def f():` with nothing after it).
                if value and not kept and not isinstance(node, ast.Module) and field != "orelse":
                    kept = [ast.Pass()]
                setattr(node, field, kept)
            elif isinstance(value, ast.AST):
                prune(value)

    prune(tree)

    # Generate the modified code from the new AST
    return ast.unparse(tree)

133 

134 

def has_local_imports(code: str) -> bool:
    """
    Checks if the given code has local (relative) imports.

    A local import is a `from` import with one or more leading dots, e.g. `from . import x`
    or `from ..pkg import y`, anywhere in the code (including inside functions, classes or
    other blocks). Plain `import x` statements are never relative.

    Args:
        code (str): The code to check for local imports.

    Returns:
        bool: True if local imports are found, False otherwise.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    tree = ast.parse(code)
    # One pass over all nodes. (Dispatching a NodeVisitor on every node yielded by ast.walk
    # made generic_visit re-traverse each subtree, i.e. quadratic work.)
    return any(isinstance(node, ast.ImportFrom) and node.level > 0 for node in ast.walk(tree))

155 

156 

157class ImportRemover(ast.NodeTransformer): 

158 """ 

159 Node visitor to remove imports (even in blocks). 

160 """ 

161 

162 def __init__(self, module_name: str) -> None: 

163 """ 

164 Set the module name to remove. 

165 """ 

166 self.module_name = module_name 

167 

168 def visit_Import(self, node: ast.Import) -> ast.Import | ast.Pass: 

169 """ 

170 Removes `import module_name`. 

171 """ 

172 node.names = [alias for alias in node.names if alias.name != self.module_name] 

173 return node if node.names else ast.Pass() 

174 

175 def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | ast.Pass: 

176 """ 

177 Removes `from module_name import xyz`. 

178 """ 

179 if node.module == self.module_name: 

180 return ast.Pass() 

181 return node 

182 

183 

def remove_import(code: str, module_name: str) -> str:
    """
    Strip every import of `module_name` from `code`, including ones nested in functions, classes and other blocks.

    `import module_name` and `from module_name import ...` statements become `pass`; in `import a, b` only the
    matching alias is dropped. Submodule imports such as `import module_name.sub` are left alone.

    Args:
        code (str): Source code to clean up.
        module_name (str): Exact name of the module whose imports must go.

    Returns:
        str: The rewritten source code.
    """
    stripped = ImportRemover(module_name).visit(ast.parse(code))
    return ast.unparse(stripped)

199 

200 

def remove_local_imports(code: str) -> str:
    """
    Removes all local imports from the given code.

    A local import is a relative `from` import (`from . import x`, `from ..pkg import y`); it is removed
    wherever it occurs. If that leaves a compound statement with an empty block (e.g. a function whose
    body only contained such an import), a `pass` statement is inserted so the result is valid Python.
    Module-level imports and `else:` branches are simply dropped.

    Args:
        code (str): The code from which to remove local imports.

    Returns:
        str: The code after removing all local imports.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """

    class RemoveLocalImports(ast.NodeTransformer):
        def visit_ImportFrom(self, node: ast.ImportFrom) -> typing.Optional[ast.ImportFrom]:
            if node.level > 0:  # This means it's a relative import
                return None  # Remove the node
            return node  # Keep the node

        def generic_visit(self, node: ast.AST) -> ast.AST:
            # remember which lists held statements before the children were visited
            non_empty = [field for field, value in ast.iter_fields(node) if isinstance(value, list) and value]
            super().generic_visit(node)
            # An emptied block would unparse to invalid code (`def f():` with nothing after it).
            if not isinstance(node, ast.Module):
                for field in non_empty:
                    if field != "orelse" and not getattr(node, field):
                        setattr(node, field, [ast.Pass()])
            return node

    tree = ast.parse(code)
    tree = RemoveLocalImports().visit(tree)
    return ast.unparse(tree)

221 

222 

def find_function_to_call(code: str, function_call_hint: str) -> typing.Optional[str]:
    """
    Look up the function named by a call hint inside `code`.

    The hint may be a bare name (`"setup"`) or a call (`"setup(db)"`); everything from the first
    `(` onwards is ignored. A plain `def` with that name anywhere in the code counts (methods and
    nested functions included); `async def` does not.

    Args:
        code (str): The code in which to find the function.
        function_call_hint (str): The hint for the function call.

    Returns:
        str, optional: The function name when such a `def` exists, otherwise None.
    """
    wanted = function_call_hint.partition("(")[0]
    defined_here = any(
        isinstance(candidate, ast.FunctionDef) and candidate.name == wanted
        for candidate in ast.walk(ast.parse(code))
    )
    return wanted if defined_here else None

240 

241 

242DEFAULT_ARGS = ("db",) 

243 

244 

245def extract_function_details( 

246 function_call: str, default_args: typing.Iterable[str] = DEFAULT_ARGS 

247) -> tuple[str | None, list[str]]: 

248 """ 

249 Extracts the function name and arguments from the function call string. 

250 

251 Args: 

252 function_call (str): The function call string. 

253 default_args (Iterable, optional): The default arguments for the function. 

254 

255 Returns: 

256 tuple: A tuple containing the function name and a list of arguments. 

257 """ 

258 function_name = function_call.split("(")[0] # Extract function name from hint 

259 if "(" not in function_call: 

260 return function_name, list(default_args) 

261 

262 with contextlib.suppress(SyntaxError): 

263 tree = ast.parse(function_call) 

264 for node in ast.walk(tree): 

265 if isinstance(node, ast.Call): 

266 if len(node.args) == 0: 

267 # If no arguments are given, add 'db' automatically 

268 return function_name, list(default_args) 

269 

270 func = typing.cast(ast.Name, node.func) 

271 return func.id, [ast.unparse(arg) for arg in node.args] 

272 

273 return None, [] 

274 

275 

def add_function_call(
    code: str, function_call: str, args: typing.Iterable[str] = DEFAULT_ARGS, multiple: bool = False
) -> str:
    """
    Insert a call to a function directly after its module-level definition.

    `function_call` is a hint such as `"setup"` or `"setup(db, other)"`; the name and argument
    sources come from `extract_function_details`, with `args` as the fallback when the hint gives none.
    Only top-level `def` statements are matched (nested functions and `async def` are not).

    Args:
        code (str): The code to which to add the function call.
        function_call (str): The function call string.
        args (Iterable, optional): The arguments for the function call.
        multiple (bool, optional): If True, add a call after every function with the specified name.

    Returns:
        str: The code after adding the function call.
    """
    function_name, call_args = extract_function_details(function_call, default_args=args)

    def to_expression(source: str) -> ast.Name:
        # parse the argument source as an expression statement and keep only its expression
        statement = typing.cast(ast.Expr, ast.parse(source).body[0])
        return typing.cast(ast.Name, statement.value)

    module = ast.parse(code)
    call_statement = ast.Expr(
        value=ast.Call(
            func=ast.Name(id=function_name, ctx=ast.Load()),
            args=[to_expression(source) for source in call_args] if call_args else [],
            keywords=[],
        )
    )

    # Rebuild the module body, placing the call right after the matching definition(s).
    new_body: list[ast.stmt] = []
    already_added = False
    for statement in module.body:
        new_body.append(statement)
        is_target = isinstance(statement, ast.FunctionDef) and statement.name == function_name
        if is_target and (multiple or not already_added):
            new_body.append(call_statement)
            already_added = True
    module.body = new_body

    return ast.unparse(module)

315 

316 

def find_variables(code_str: str, with_builtins: bool = True) -> tuple[set[str], set[str]]:
    """
    Finds all used and defined variables in the given code string.

    "Used" names are all names that are read somewhere (`ast.Name` in a load context).
    "Defined" names are those bound by:
    - assignment-like targets (`=`, `:=`, `for`, comprehensions, `with ... as`), plus the base
      name of item assignments like `x[0] = ...`;
    - `def`, `async def` and `class` statements;
    - function and lambda parameters;
    - `except ... as name`;
    - imports: `import a.b` binds `a`, `import a as b` binds `b`, `from m import x as y` binds `y`,
      and `from m import *` binds the names `m` exports (its `__all__`, or else its public names).
    A `del x` visited later in the traversal removes `x` again from the assignment/declaration names.

    Scope is not taken into account: a name defined anywhere counts as defined everywhere.

    Note: resolving `from m import *` really imports module `m`, so its top-level code runs.
    When `m` cannot be imported, or the star import is relative, it contributes no names.

    Args:
        code_str (str): The code string to parse for variables (it is dedented first).
        with_builtins (bool): include Python builtins?

    Returns:
        tuple: A tuple containing sets of used and defined variables.

    Raises:
        SyntaxError: If `code_str` is not valid Python.
    """
    # Partly made by ChatGPT
    code_str = textwrap.dedent(code_str)

    # could raise SyntaxError
    tree: ast.Module = ast.parse(code_str)

    used_variables: set[str] = set()
    defined_variables: set[str] = set()
    imported_names: set[str] = set()
    loop_variables: set[str] = set()

    def collect_variables(node: ast.AST) -> None:
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load):
                used_variables.add(node.id)
            elif isinstance(node.ctx, ast.Store):
                defined_variables.add(node.id)
            elif isinstance(node.ctx, ast.Del):
                defined_variables.discard(node.id)

    def collect_definitions(node: ast.AST) -> None:
        # Store-context names are already handled above; this adds the base name of item
        # assignments (`x[0] = ...`), where `x` itself only appears in a load context.
        if isinstance(node, ast.Assign):
            pending: list[ast.expr] = list(node.targets)
        elif isinstance(node, ast.AnnAssign):
            pending = [node.target]
        else:
            return

        while pending:
            target = pending.pop()
            while isinstance(target, (ast.Subscript, ast.Starred)):
                target = target.value
            if isinstance(target, (ast.Tuple, ast.List)):
                pending.extend(target.elts)
            elif isinstance(target, ast.Name):
                defined_variables.add(target.id)

    def collect_declarations(node: ast.AST) -> None:
        # names bound by a statement or parameter rather than by an `ast.Name` target
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            defined_variables.add(node.name)
        elif isinstance(node, ast.arg):
            defined_variables.add(node.arg)
        elif isinstance(node, ast.ExceptHandler) and node.name:
            defined_variables.add(node.name)

    def star_import_names(node: ast.ImportFrom) -> set[str]:
        # A relative module can't be resolved without knowing the importing package.
        if node.level or not node.module:
            return set()
        try:
            module = importlib.import_module(node.module)
        except ImportError:
            return set()
        # `from m import *` binds exactly `m.__all__` when it is defined
        exported = getattr(module, "__all__", None)
        if exported is not None:
            return set(exported)
        return {name for name in dir(module) if not name.startswith("_")}

    def collect_imports(node: ast.AST) -> None:
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b` binds `a`; `import a.b as c` binds `c`
                imported_names.add(alias.asname or alias.name.partition(".")[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name == "*":
                    imported_names.update(star_import_names(node))
                else:
                    imported_names.add(alias.asname or alias.name)

    def collect_loop_variables(node: ast.AST) -> None:
        if isinstance(node, ast.For) and isinstance(node.target, ast.Name):
            loop_variables.add(node.target.id)

    def collect_everything(node: ast.AST) -> None:
        collect_variables(node)
        collect_definitions(node)
        collect_declarations(node)
        collect_imports(node)
        collect_loop_variables(node)

    # manually rewritten (2.19s for 10k):
    traverse_ast(tree, collect_everything)

    all_variables = defined_variables | loop_variables | imported_names | (BUILTINS if with_builtins else set())

    return used_variables, all_variables

411 

412 

def find_missing_variables(code: str) -> set[str]:
    """
    Return the names that `code` reads but never defines.

    Names are classified by `find_variables` with its defaults, so builtins count as defined.

    Args:
        code (str): The code to check for missing variables.

    Returns:
        set: A set of names of missing variables.
    """
    used, known = find_variables(code)
    return used - known

425 

426 

427T = typing.TypeVar("T", bound=Any) 

428 

429 

430class Empty: 

431 # todo: overload more methods 

432 # class that does absolutely nothing 

433 # but can be accessed like an object (obj.something.whatever) 

434 # or a dict[with][some][keys] 

435 def __init__(self, *_: Any, **__: Any) -> None: 

436 ... 

437 

438 def __bool__(self) -> bool: 

439 return False 

440 

441 def __getattribute__(self, _: str) -> Self: 

442 return self 

443 

444 def __getitem__(self, _: str) -> Self: 

445 return self 

446 

447 def __iter__(self) -> typing.Generator[Self, Any, None]: 

448 # fix set(empty) 

449 yield self # once 

450 

451 def __get__(self, *_: Any) -> Self: 

452 return self 

453 

454 def __call__(self, *_: Any, **__: Any) -> Self: 

455 return self 

456 

457 def __str__(self) -> str: 

458 return "" 

459 

460 def __repr__(self) -> str: 

461 return "" 

462 

463 def __add__(self, other: T) -> T: 

464 return other 

465 

466 

def generate_magic_code(missing_vars: set[str]) -> str:
    """
    Generates code to define missing variables with a do-nothing object.

    After finding missing vars, fill them in with an object that does nothing except return itself or an empty string.
    This way, it's least likely to crash (when used as default or validator in pydal, don't use this for running code!).

    The generated code consists of:
    - the imports and the `T` TypeVar that the source of `Empty` refers to;
    - the source of `Empty` itself, copied with `inspect.getsource` (so the source of this module must be
      available; `inspect.getsource` raises OSError otherwise);
    - `empty = Empty()`;
    - one line assigning `empty` to every missing variable, in sorted order.

    Args:
        missing_vars (set): The set of missing variable names.

    Returns:
        str: The generated code.
    """
    header = (
        "import typing; from typing import Any; "
        "from typing_extensions import Self; "
        "T = typing.TypeVar('T', bound=Any); "
        "\n"
    )

    # Set iteration order of strings depends on per-process hash randomization;
    # sorting makes the same input produce the same code on every run.
    assignments = "".join(f"{variable} = empty; " for variable in sorted(missing_vars))

    extra_code = header + inspect.getsource(Empty) + "\n\n" + "empty = Empty()" + "\n" + assignments

    return textwrap.dedent(extra_code)