Coverage for src/pydal2sql/magic.py: 97%

143 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-31 16:22 +0200

1""" 

2This file has methods to guess which variables are unknown and to potentially (monkey-patch) fix this. 

3""" 

4 

5 

6import ast 

7import builtins 

8import contextlib 

9import importlib 

10import textwrap 

11import typing 

12 

13BUILTINS = set(builtins.__dict__.keys()) 

14 

15 

def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], None]) -> None:
    """
    Run variable_collector on node and then on every descendant, parents before their children.

    Children are taken in the order `ast.iter_child_nodes` yields them and the complete subtree of a child
    is handled before its next sibling is fetched.
    """
    variable_collector(node)

    # one (lazy) child iterator per tree level that is still being walked, the deepest level is on top
    open_levels = [ast.iter_child_nodes(node)]
    while open_levels:
        current = next(open_levels[-1], None)
        if current is None:
            # no children left on this level: continue with the remaining siblings of the parent
            open_levels.pop()
            continue

        variable_collector(current)
        open_levels.append(ast.iter_child_nodes(current))

23 

24 

def find_defined_variables(code_str: str) -> set[str]:
    """
    Collect the names that are bound by plain assignment statements anywhere in code_str.

    Only `ast.Assign` statements are looked at (so no imports, loops, annotated or augmented assignments).
    Names in unpacking targets (`a, (b, *c) = ...`) are included; attribute and subscript targets
    (`obj.attr = ...`, `items[0] = ...`) do not bind a name and are skipped.

    Args:
        code_str: Python source code.

    Returns:
        The set of assigned variable names.

    Raises:
        SyntaxError: if code_str can not be parsed.
    """
    variables: set[str] = set()

    for node in ast.walk(ast.parse(code_str)):
        if not isinstance(node, ast.Assign):
            continue

        for target in node.targets:
            # Walk the target instead of reading `target.id`: tuple/list/starred, attribute and subscript
            # targets have no `id`. Only names in Store context are the ones actually being bound.
            variables.update(
                name.id for name in ast.walk(target) if isinstance(name, ast.Name) and isinstance(name.ctx, ast.Store)
            )

    return variables

38 

39 

def remove_specific_variables(code: str, to_remove: typing.Iterable[str] = ("db", "database")) -> str:
    """
    Strip the definitions of specific names (by default 'db' and 'database') from code.

    - Assignments: matching names are dropped from the targets (`db = other = 1` becomes `other = 1`).
      When no target is left, only the value expression remains as a statement.
    - Function (sync and async) and class definitions with a matching name are removed completely,
      wherever they are nested.

    Args:
        code: Python source code.
        to_remove: the names whose definitions should be removed.

    Returns:
        The modified source code, regenerated with `ast.unparse` (so comments and formatting are lost).

    Raises:
        SyntaxError: if code can not be parsed.
    """
    # Materialize once: a one-shot iterable (e.g. a generator) would be used up by the first membership test.
    names = frozenset(to_remove)

    definition_type = typing.Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]

    class RemoveDefinitions(ast.NodeTransformer):
        # NodeTransformer rebuilds every statement list (body, orelse, finalbody, ...) after visiting,
        # so nothing is removed from a list while that same list is being iterated.

        def generic_visit(self, node: ast.AST) -> ast.AST:
            super().generic_visit(node)
            body = getattr(node, "body", None)
            if isinstance(body, list) and not body and not isinstance(node, ast.Module):
                # a block that lost all its statements would unparse to invalid code
                body.append(ast.Pass())
            return node

        def visit_Assign(self, node: ast.Assign) -> ast.AST:
            node.targets = [
                target for target in node.targets if not (isinstance(target, ast.Name) and target.id in names)
            ]
            return self.generic_visit(node)

        def _visit_definition(self, node: definition_type) -> typing.Optional[ast.AST]:
            if node.name in names:
                # returning None makes NodeTransformer drop the statement from its parent
                return None
            return self.generic_visit(node)

        visit_FunctionDef = visit_AsyncFunctionDef = visit_ClassDef = _visit_definition

    new_tree = RemoveDefinitions().visit(ast.parse(code))

    # Generate the modified code from the new AST
    return ast.unparse(new_tree)

78 

79 

def find_local_imports(code: str) -> bool:
    """
    Tell whether code contains at least one relative import (`from . import x`, `from ..pkg import y`).

    The whole tree is searched, so imports nested in functions or classes count as well.

    Args:
        code: Python source code.

    Returns:
        True if any `from ... import ...` statement has a level above 0, otherwise False.

    Raises:
        SyntaxError: if code can not be parsed.
    """
    # A single pass over all nodes is enough; a level above 0 means it's a relative import.
    return any(isinstance(node, ast.ImportFrom) and node.level > 0 for node in ast.walk(ast.parse(code)))

90 

91 

def remove_import(code: str, module_name: typing.Optional[str]) -> str:
    """
    Remove the top-level imports of module_name from code.

    - `from <module_name> import ...` statements are removed completely.
    - `import <module_name>` is removed; in `import a, b` only the matching alias is dropped,
      the other modules stay imported.

    Only statements directly in the module body are considered and the module name has to match exactly
    (`import os.path` is not touched when removing 'os').

    Args:
        code: Python source code.
        module_name: the module to remove. With None only `from . import ...` statements (which have no
            module name) match.

    Returns:
        The modified source code, regenerated with `ast.unparse`.

    Raises:
        SyntaxError: if code can not be parsed.
    """
    tree = ast.parse(code)

    kept: list[ast.stmt] = []
    for node in tree.body:
        if isinstance(node, ast.ImportFrom):
            if node.module == module_name:
                continue
        elif isinstance(node, ast.Import):
            remaining = [alias for alias in node.names if alias.name != module_name]
            if not remaining:
                continue
            # keep the unrelated modules of a combined `import a, b` statement
            node.names = remaining

        kept.append(node)

    tree.body = kept
    return ast.unparse(tree)

103 

104 

def remove_local_imports(code: str) -> str:
    """
    Remove every relative import (`from . import x`, `from ..pkg import y`) from code, at any nesting level.

    When a block (function, class, if, ...) only consisted of relative imports, a `pass` is left behind
    so the result is still valid Python.

    Args:
        code: Python source code.

    Returns:
        The modified source code, regenerated with `ast.unparse`.

    Raises:
        SyntaxError: if code can not be parsed.
    """

    class RemoveLocalImports(ast.NodeTransformer):
        def visit_ImportFrom(self, node: ast.ImportFrom) -> typing.Optional[ast.ImportFrom]:
            if node.level > 0:  # This means it's a relative import
                return None  # Remove the node
            return node  # Keep the node

        def generic_visit(self, node: ast.AST) -> ast.AST:
            super().generic_visit(node)
            body = getattr(node, "body", None)
            if isinstance(body, list) and not body and not isinstance(node, ast.Module):
                # the block lost all of its statements, which would unparse to e.g. a bare `def f():`
                body.append(ast.Pass())
            return node

    tree = RemoveLocalImports().visit(ast.parse(code))
    return ast.unparse(tree)

115 

116 

117# def find_function_to_call(code: str, target: str) -> typing.Optional[ast.FunctionDef]: 

118# tree = ast.parse(code) 

119# for node in ast.walk(tree): 

120# if isinstance(node, ast.FunctionDef) and node.name == target: 

121# return node 

122 

def find_function_to_call(code: str, function_call_hint: str) -> typing.Optional[str]:
    """
    Check whether the function named in function_call_hint is defined (with `def`) somewhere in code.

    Args:
        code: Python source code to search in.
        function_call_hint: a function name, optionally followed by a call signature,
            e.g. 'define_tables' or 'define_tables(db)'.

    Returns:
        The function name (without arguments or surrounding whitespace) if it is defined, otherwise None.

    Raises:
        SyntaxError: if code can not be parsed.
    """
    # Everything before the first '(' is the name; strip it so a hint like 'define_tables (db)' is found too.
    function_name = function_call_hint.partition("(")[0].strip()

    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.FunctionDef) and node.name == function_name:
            return function_name

    return None

130 

131# def add_function_call(code: str, function_name: str) -> str: 

132# tree = ast.parse(code) 

133# # Create a function call node 

134# func_call = ast.Expr( 

135# value=ast.Call( 

136# func=ast.Name(id=function_name, ctx=ast.Load()), 

137# args=[ast.Name(id='db', ctx=ast.Load())], 

138# keywords=[] 

139# ) 

140# ) 

141# 

142# # Insert the function call right after the function definition 

143# for i, node in enumerate(tree.body): 

144# if isinstance(node, ast.FunctionDef) and node.name == function_name: 

145# tree.body.insert(i + 1, func_call) 

146# break 

147# 

148# return ast.unparse(tree) 

149 

150# def extract_function_details(function_call): 

151# function_name = function_call.split('(')[0] # Extract function name from hint 

152# if '(' in function_call: 

153# try: 

154# tree = ast.parse(function_call) 

155# for node in ast.walk(tree): 

156# if isinstance(node, ast.Call): 

157# return node.func.id, [ast.unparse(arg) for arg in node.args] 

158# except SyntaxError: 

159# pass 

160# return function_name, [] 

161 

def extract_function_details(function_call: str) -> tuple[typing.Optional[str], list[str]]:
    """
    Split a function call hint into the function name and the source text of its positional arguments.

    When the hint has no arguments ('setup' or 'setup()'), 'db' is used as the only argument.
    Keyword arguments in the hint are not returned.

    Args:
        function_call: e.g. 'define_tables', 'define_tables()' or 'define_tables(db, "other")'.

    Returns:
        (function name, list of argument expressions as source code).
        (None, []) when the hint contains a '(' but is not valid Python or contains no call.
    """
    function_call = function_call.strip()  # leading whitespace would be seen as unexpected indentation
    function_name = function_call.partition("(")[0].strip()  # Extract function name from hint
    if "(" not in function_call:
        return function_name, ["db"]

    with contextlib.suppress(SyntaxError):
        for node in ast.walk(ast.parse(function_call)):
            if not isinstance(node, ast.Call):
                continue

            if isinstance(node.func, ast.Name):
                function_name = node.func.id
            # else: the callee is not a plain name (e.g. 'models.define'), keep the text before the '('

            if not node.args:
                # If no arguments are given, add 'db' automatically
                return function_name, ["db"]

            return function_name, [ast.unparse(arg) for arg in node.args]

    return None, []

177 

def add_function_call(code, function_call):
    """
    Add a call to a function of code directly below the (top-level) definition of that function.

    function_call is interpreted by `extract_function_details`, which supplies the function name and the
    source text of the arguments to call it with. If code has no top-level `def` with that name,
    nothing is inserted. The result is the code regenerated with `ast.unparse`.
    """
    function_name, args = extract_function_details(function_call)
    tree = ast.parse(code)

    # every argument is the source of one expression: parse it back and take that expression node
    call_args = []
    for arg in args:
        call_args.append(ast.parse(arg).body[0].value)

    call_statement = ast.Expr(
        value=ast.Call(
            func=ast.Name(id=function_name, ctx=ast.Load()),
            args=call_args,
            keywords=[],
        )
    )

    # index of the first matching top-level function definition (if there is one)
    position = next(
        (
            index
            for index, statement in enumerate(tree.body)
            if isinstance(statement, ast.FunctionDef) and statement.name == function_name
        ),
        None,
    )
    if position is not None:
        tree.body.insert(position + 1, call_statement)

    return ast.unparse(tree)

197 

def find_variables(code_str: str) -> tuple[set[str], set[str]]:
    """
    Look through the source code in code_str and try to detect using ast parsing which variables are undefined.

    Scopes are not tracked: a name that is bound anywhere in the code counts as known for all of the code.

    Args:
        code_str: Python source code; common leading indentation is removed before parsing.

    Returns:
        A tuple of (names that are read, names that are known).
        The known names are the names that are stored to (assignments, loop targets, `with ... as`, ...),
        the names bound by imports and the builtins.

    Raises:
        SyntaxError: if the code can not be parsed.
        ImportError: if the module of a star import (`from x import *`) can not be imported.
    """
    # Partly made by ChatGPT
    code_str = textwrap.dedent(code_str)

    # could raise SyntaxError
    tree: ast.Module = ast.parse(code_str)

    used_variables: set[str] = set()
    defined_variables: set[str] = set()
    imported_names: set[str] = set()
    loop_variables: set[str] = set()

    def collect_variables(node: ast.AST) -> None:
        # Every binding through a Name (also in unpacking targets) shows up as a Name in Store context,
        # so assignment statements need no separate handling.
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load):
                used_variables.add(node.id)
            elif isinstance(node.ctx, ast.Store):
                defined_variables.add(node.id)
            elif isinstance(node.ctx, ast.Del):
                defined_variables.discard(node.id)

    def collect_imports(node: ast.AST) -> None:
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b` binds 'a' and `import a.b as c` binds 'c'
                imported_names.add(alias.asname or alias.name.partition(".")[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name != "*":
                    imported_names.add(alias.asname or alias.name)
                elif node.module:
                    # only a star import requires actually importing the module to learn the names
                    imported_module = importlib.import_module(node.module)
                    imported_names.update(name for name in dir(imported_module) if not name.startswith("_"))

    def collect_loop_variables(node: ast.AST) -> None:
        # kept separately from defined_variables so a later `del` does not make a loop variable unknown
        if isinstance(node, ast.For) and isinstance(node.target, ast.Name):
            loop_variables.add(node.target.id)

    def collect_everything(node: ast.AST) -> None:
        collect_variables(node)
        collect_imports(node)
        collect_loop_variables(node)

    # manually rewritten (2.19s for 10k):
    traverse_ast(tree, collect_everything)

    all_variables = defined_variables | loop_variables | imported_names | BUILTINS

    return used_variables, all_variables

263 

264 

def find_missing_variables(code: str) -> set[str]:
    """
    List the names that are read in code but are not among the known names reported by `find_variables`.
    """
    used, known = find_variables(code)
    return used - known

268 

269 

def generate_magic_code(missing_vars: set[str]) -> str:
    """
    After finding missing vars, fill them in with an object that does nothing except return itself or an empty string.

    This way, it's least likely to crash (when used as default or validator in pydal, don't use this for running code!).

    Args:
        missing_vars: the names that should be defined.

    Returns:
        Python source code that defines the class `Empty`, one instance `empty` and an assignment of that
        instance to every missing variable (in sorted order, so the output is the same on every run).
    """
    # Only the fixed template is dedented; the assignments are appended afterwards so their (lack of)
    # indentation can never influence how far the class definition is dedented.
    template = textwrap.dedent(
        """
        class Empty:
            # class that does absolutely nothing
            # but can be accessed like an object (obj.something.whatever)
            # or a dict[with][some][keys]
            def __getattribute__(self, _):
                return self

            def __getitem__(self, _):
                return self

            def __get__(self, *_):
                return self

            def __call__(self, *_, **__):
                return self

            def __str__(self):
                return ''

            def __repr__(self):
                return ''

        # todo: overload more methods
        empty = Empty()
        """
    )

    assignments = "".join(f"{variable} = empty; " for variable in sorted(missing_vars))

    return f"{template}\n\n{assignments}"