Coverage for src/pydal2sql/magic.py: 100%

60 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-21 11:14 +0200

1""" 

2This file has methods to guess which variables are unknown and to potentially (monkey-patch) fix this. 

3""" 

4 

5import ast 

6import builtins 

7import importlib 

8import textwrap 

9import typing 

10 

# Every name the interpreter provides without an import (print, len, None, Exception, ...).
BUILTINS = set(vars(builtins))

12 

13 

def traverse_ast(node: ast.AST, variable_collector: typing.Callable[[ast.AST], None]) -> None:
    """
    Call variable_collector on node and on every node below it, in depth-first pre-order.

    The walk uses an explicit stack instead of recursion, so very deeply nested trees
    (e.g. long chains of binary operators) cannot raise RecursionError.

    Args:
        node: root of the (sub)tree to walk.
        variable_collector: callback invoked once per node, parents before their children,
            children in the order ast.iter_child_nodes yields them.
    """
    stack: list[ast.AST] = [node]
    while stack:
        current = stack.pop()
        variable_collector(current)
        # Push children reversed so the first child is popped (visited) first,
        # matching the order of the previous recursive implementation.
        stack.extend(reversed(list(ast.iter_child_nodes(current))))

21 

22 

def _names_bound_by_import_from(node: ast.ImportFrom) -> set[str]:
    """Return the names a ``from ... import ...`` statement binds (best effort for star imports)."""
    if node.names[0].name != "*":
        # `from m import a as b` binds `b`; this also covers relative imports like `from . import a`.
        return {alias.asname or alias.name for alias in node.names}

    # Star import: the bound names are only knowable by importing (i.e. executing) the module.
    # Relative star imports cannot be resolved without the package context, so they are skipped.
    if node.level or not node.module:
        return set()
    try:
        module = importlib.import_module(node.module)
    except ImportError:
        # Module not available here: its names stay unknown (and will be reported as missing).
        return set()
    public_names = getattr(module, "__all__", None)
    if public_names is None:
        return {name for name in dir(module) if not name.startswith("_")}
    return set(public_names)


def find_missing_variables(code_str: str) -> set[str]:
    """
    Look through the source code in code_str and try to detect, using ast parsing, which variables are undefined.

    This is a flow- and scope-insensitive heuristic. A name counts as defined when it is bound anywhere in
    the code: as an assignment, for/with/comprehension or walrus target, as a function or lambda parameter,
    as a def/class name, as an ``except ... as`` name, or by an import. It counts as used when it is read
    or deleted anywhere. Python builtins are never reported.

    Args:
        code_str: Python source code; common leading indentation is removed before parsing.

    Returns:
        The set of names that are used but never defined.

    Raises:
        SyntaxError: if code_str (after dedenting) is not valid Python.
    """
    code_str = textwrap.dedent(code_str)

    # could raise SyntaxError
    tree: ast.Module = ast.parse(code_str)

    used_variables: set[str] = set()
    defined_variables: set[str] = set()

    # ast.walk is iterative, so deeply nested code cannot hit the recursion limit,
    # and every check below is order-independent.
    for node in ast.walk(tree):
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Store):
                # plain, tuple-unpacking, for/with/comprehension and walrus targets
                defined_variables.add(node.id)
            else:
                # Load, and also Del: `del x` needs x to exist, so it counts as a use.
                used_variables.add(node.id)
        elif isinstance(node, ast.arg):
            # function and lambda parameters (positional, keyword-only, *args, **kwargs)
            defined_variables.add(node.arg)
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            defined_variables.add(node.name)
        elif isinstance(node, ast.ExceptHandler) and node.name:
            defined_variables.add(node.name)
        elif isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b.c` binds `a`; `import a.b as c` binds `c`
                defined_variables.add(alias.asname or alias.name.partition(".")[0])
        elif isinstance(node, ast.ImportFrom):
            defined_variables.update(_names_bound_by_import_from(node))

    # Same names as the module-level BUILTINS set, looked up directly so this function is self-contained.
    builtin_names = vars(builtins)
    return {name for name in used_variables - defined_variables if name not in builtin_names}

100 

101 

102# if __name__ == "__main__": 

103# # Example usage: 

104# code_string = """ 

105# from math import floor 

106# import datetime 

107# from pydal import DAL 

108# a = 1 

109# b = 2 

110# print(a, b + c) 

111# d = e + b 

112# xyz 

113# floor(d) 

114# ceil(d) 

115# ceil(e) 

116# 

117# datetime.utcnow() 

118# 

119# db = DAL() 

120# 

121# db.define_table('...') 

122# 

123# for table in []: 

124# print(table) 

125# 

126# if toble := True: 

127# print(toble) 

128# """ 

129# 

130# # import timeit 

131# # 

132# # print( 

133# # timeit.timeit(lambda: find_missing_variables(code_string), number=10000) 

134# # 

135# # 

136# # ) 

137# 

138# missing_variables = find_missing_variables(code_string) 

139# assert missing_variables == {"c", "xyz", "ceil", "e"}, missing_variables 

140 

141 

def generate_magic_code(missing_vars: set[str]) -> str:
    """
    After finding missing vars, fill them in with an object that does nothing except return itself or an empty string.

    This way, it's least likely to crash (when used as default or validator in pydal, don't use this for running code!).

    Args:
        missing_vars: names to define, typically the result of find_missing_variables.

    Returns:
        Python source that defines an ``Empty`` class and an ``empty`` instance, followed by one
        ``<name> = empty; `` assignment per missing name. The names are sorted so the output is
        deterministic regardless of set iteration order.
    """
    extra_code = textwrap.dedent(
        """
        class Empty:
            # class that does absolutely nothing
            # but can be accessed like an object (obj.something.whatever),
            # a dict[with][some][keys], called with any (keyword) arguments,
            # or stored as a class attribute
            def __getattribute__(self, _):
                return self

            def __getitem__(self, _):
                return self

            def __get__(self, *_):
                # descriptor protocol passes (instance, owner); accept and ignore them
                return self

            def __call__(self, *_, **__):
                return self

            def __str__(self):
                return ''

            def __repr__(self):
                return ''

            # todo: overload more methods
        empty = Empty()

        """
    )
    # str.join avoids quadratic += concatenation; sorting makes the output reproducible.
    assignments = "".join(f"{variable} = empty; " for variable in sorted(missing_vars))
    return extra_code + assignments