Coverage for pysource_minimize/_minimize_base.py: 94%

211 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-03-25 08:14 +0100

1import ast 

2import copy 

3import sys 

4from typing import List 

5from typing import Union 

6 

7 

# Module-wide debug switch. When True, get_ast() and try_with() run extra
# internal consistency assertions and coverage_required() raises
# CoverageRequired.
TESTING = False

9 

10 

def is_block(nodes):
    """Return True if ``nodes`` is a non-empty list made only of statements.

    Such a list is a statement body; lists of expressions, empty lists and
    non-list values are not blocks.  Always returns a real ``bool`` (the
    previous version returned the empty list itself for ``[]``).
    """
    return (
        isinstance(nodes, list)
        and len(nodes) > 0
        and all(isinstance(n, ast.stmt) for n in nodes)
    )

17 

18 

class StopMinimization(Exception):
    """Raised to end the minimization early.

    MinimizeBase catches it around the whole run and sets its ``stop`` flag;
    when a checker raises it, the candidate being checked is still kept.
    """

21 

22 

class CoverageRequired(Exception):
    """Raised by coverage_required() while the TESTING flag is enabled."""

25 

26 

def coverage_required():
    """Signal, in testing mode only, that this code path was reached.

    Raises CoverageRequired when the module-level TESTING flag is set and
    does nothing otherwise.
    """
    if not TESTING:
        return
    raise CoverageRequired()

30 

31 

def equal_ast(lhs, rhs):
    """Return True if ``lhs`` and ``rhs`` are structurally equal.

    Accepts AST nodes, lists (compared element-wise) and plain field values
    (compared with ``==``).  Types must match exactly.  For nodes only the
    entries of ``_fields`` are compared, so location attributes such as line
    numbers are ignored, and the ``ctx`` field (Load/Store/Del) is skipped.
    """
    if type(lhs) is not type(rhs):
        return False

    if isinstance(lhs, list):
        return len(lhs) == len(rhs) and all(
            equal_ast(left, right) for left, right in zip(lhs, rhs)
        )

    if isinstance(lhs, ast.AST):
        # NOTE(review): nodes without _fields (e.g. ValueWrapper) always
        # compare equal here, whatever they wrap -- confirm that is intended.
        return all(
            equal_ast(getattr(lhs, field), getattr(rhs, field))
            for field in lhs._fields
            if field != "ctx"
        )

    return lhs == rhs

51 

52 

class ValueWrapper(ast.AST):
    """AST node that carries a plain, non-node field value (None, str, ...).

    Wrapping such values lets them be indexed and replaced like ordinary
    nodes.  Equality compares the wrapped value with the other operand
    itself, so ``ValueWrapper(1) == 1`` is True.
    """

    # Defining __eq__ already makes instances unhashable; say so explicitly.
    __hash__ = None

    def __init__(self, value=None):
        self.value = value

    def __repr__(self):
        return "ValueWrapper(" + repr(self.value) + ")"

    def __eq__(self, other):
        return self.value == other

62 

63 

def arguments(
    node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda]
) -> List[ast.arg]:
    """Collect the ``ast.arg`` nodes declared by a function or lambda.

    Order: regular args, ``*vararg``, keyword-only args, ``**kwarg`` and, on
    Python 3.8+, positional-only args last.  Missing ``*``/``**`` slots are
    left out.
    """
    spec = node.args
    candidates = list(spec.args)
    candidates.append(spec.vararg)
    candidates.extend(spec.kwonlyargs)
    candidates.append(spec.kwarg)

    if sys.version_info >= (3, 8):
        candidates.extend(spec.posonlyargs)

    return [candidate for candidate in candidates if candidate is not None]

74 

75 

class MinimizeBase:
    """Shared machinery for AST minimizers.

    Subclasses provide ``minimize_stmt`` (invoked from ``__init__``) and
    ``minimize``, and propose edits through the ``try_*`` helpers. An edit is
    kept only if ``checker`` still accepts the resulting tree.

    Every node of the (deep-copied) original tree is tagged with a unique
    integer index. Accepted edits are recorded in ``self.replaced``, which maps
    either a node index or a ``(node index, field name)`` pair to its
    replacement. The current tree is always rebuilt from the untouched
    original by ``get_ast``.
    """

    # When False (and TESTING is on), try_with() asserts that no key of
    # self.replaced is mapped a second time.
    allow_multiple_mappings = False

    # Node classes that must not appear as a ``del`` target (``del None``).
    # ast.NameConstant is only needed before Python 3.8, where None/True/False
    # parsed to it.  It is deprecated since 3.8 and removed in Python 3.14, so
    # it must not be referenced on newer versions.
    _constant_node_types = (
        (ast.Constant,)
        if sys.version_info >= (3, 8)
        else (ast.Constant, ast.NameConstant)
    )

    def __init__(self, original_ast, checker, progress_callback):
        """Prepare ``original_ast`` and immediately run the minimization.

        Args:
            original_ast: tree to minimize; it is deep-copied, never mutated.
            checker: callable receiving a candidate tree and returning a truthy
                value if the candidate is acceptable. It may raise
                StopMinimization to end the search early.
            progress_callback: called as
                ``progress_callback(current_node_count, original_node_count)``
                after every accepted edit.

        Raises:
            ValueError: if ``checker`` rejects the unmodified tree.
        """
        self.checker = checker
        self.progress_callback = progress_callback
        self.stop = False

        # Duplicate shared field-less nodes like ast.Load() so every node
        # object in the tree is distinct and can carry its own index.
        class UniqueObj(ast.NodeTransformer):
            def visit(self, node):
                if not node._fields:
                    # copy_location keeps the position of statements such as
                    # `pass`/`break`; it is a no-op for ast.Load() & co.
                    return ast.copy_location(type(node)(), node)
                return super().visit(node)

        self.original_ast = UniqueObj().visit(copy.deepcopy(original_ast))

        # Counted before the ValueWrapper nodes below are inserted, so that it
        # is comparable with the unwrapped trees produced by get_ast().
        self.original_nodes_number = self.nodes_of(self.original_ast)

        def wrap(value):
            # Turn plain values into ValueWrapper nodes so they can be indexed
            # and replaced like any other node.
            if isinstance(value, ast.AST):
                return value
            elif isinstance(value, list):
                return [wrap(e) for e in value]
            elif isinstance(value, (type(None), int, str, bytes)):
                return ValueWrapper(value)
            else:
                assert False

        # List fields whose elements may be plain values (None defaults,
        # identifier strings) instead of AST nodes.
        wrapped_fields = {
            ("arguments", "kw_defaults"),
            ("Nonlocal", "names"),
            ("Global", "names"),
            ("MatchClass", "kwd_attrs"),
        }
        for node in ast.walk(self.original_ast):
            for name, value in ast.iter_fields(node):
                if (type(node).__name__, name) in wrapped_fields:
                    setattr(node, name, wrap(value))

        for i, node in enumerate(ast.walk(self.original_ast)):
            node.__index = i

        self.replaced = {}

        try:
            if not self.checker(self.get_ast(self.original_ast)):
                raise ValueError("checker return False: nothing to minimize here")

            self.minimize_stmt(self.original_ast)
        except StopMinimization:
            self.stop = True

    def index_of(self, node):
        """Return the unique integer index assigned to ``node`` in __init__."""
        return node.__index

    def get_ast(self, node, replaced=None):
        """Build a fresh copy of ``node`` with all replacements applied.

        Args:
            node: an indexed node of ``self.original_ast`` (usually the root).
            replaced: optional extra, not yet accepted, replacements layered on
                top of ``self.replaced``. Keys are node indices or
                ``(node index, field name)`` pairs.

        Returns:
            A new tree; ``node`` itself is not modified. ValueWrapper nodes
            are unwrapped back to their plain values.
        """
        replaced = {**self.replaced, **(replaced or {})}

        tmp_ast = copy.deepcopy(node)
        node_map = {n.__index: n for n in ast.walk(tmp_ast)}

        if TESTING:
            for a, b in zip(ast.walk(tmp_ast), ast.walk(node)):
                assert a.__index == b.__index

            unique_index = {}
            for n in ast.walk(tmp_ast):
                assert n.__index not in unique_index, (n, unique_index[n.__index])
                unique_index[n.__index] = n

            for n in ast.walk(tmp_ast):
                for field in n._fields:
                    assert hasattr(
                        n, field
                    ), f"{n.__class__.__name__}.{field} is not defined"

        def replaced_node(node):
            # Resolve the value of a single-node field through the chain of
            # replacements.
            if not isinstance(node, ast.AST):
                return node
            if not hasattr(node, "_MinimizeBase__index"):
                # created by a minimizer, not part of the original tree
                return node
            i = node.__index
            # Only ints are node indices. Any other value (None, a new AST
            # node, a ValueWrapper) ends the chain. Probing a ValueWrapper
            # with `in` would raise TypeError because it is unhashable.
            while isinstance(i, int) and i in replaced:
                i = replaced[i]
                assert isinstance(i, (int, type(None), ast.AST)), (node, i)
            if i is None:
                return None
            if isinstance(i, ValueWrapper):
                return i.value
            if isinstance(i, ast.AST):
                # NOTE(review): replacement nodes are returned as-is (not
                # copied) and then mutated by map_node(); they are shared
                # between calls -- confirm this aliasing is intended.
                return i
            result = node_map[i]

            if isinstance(result, ValueWrapper):
                result = result.value

            return result

        def replaced_nodes(nodes, name):
            # Resolve the list field `name`. An index may map to another
            # index, a list of indices, a new AST node or None.
            def replace(indices):
                for i in indices:
                    if i not in replaced:
                        yield node_map[i]
                    else:
                        next_i = replaced[i]
                        if isinstance(next_i, int):
                            yield from replace([next_i])
                        elif isinstance(next_i, list):
                            yield from replace(next_i)
                        elif isinstance(next_i, ast.AST):
                            yield next_i
                        elif next_i is None:
                            yield None
                        else:
                            raise TypeError(type(next_i))

            if not all(isinstance(n, ast.AST) for n in nodes):
                return nodes

            block = is_block(nodes)

            result = list(replace([n.__index for n in nodes]))
            result = [e.value if isinstance(e, ValueWrapper) else e for e in result]

            if not result and block and name not in ("orelse", "finalbody"):
                # a required statement body must not become empty
                return [ast.Pass()]

            if block:
                # expressions moved into a statement list need an Expr wrapper
                result = [ast.Expr(r) if isinstance(r, ast.expr) else r for r in result]

            return result

        def map_node(node):
            # Apply the replacements to every field of `node`, then recurse.
            for name, child in ast.iter_fields(node):
                if (
                    hasattr(node, "_MinimizeBase__index")
                    and (node.__index, name) in replaced
                ):
                    # whole-field replacement (see try_attr)
                    setattr(node, name, replaced[(node.__index, name)])
                elif isinstance(child, list):
                    setattr(node, name, replaced_nodes(child, name))
                else:
                    setattr(node, name, replaced_node(child))
            for child in ast.iter_child_nodes(node):
                map_node(child)

        # TODO: this could be optimized (copy all, reduce) -> (generate new ast nodes)
        map_node(tmp_ast)

        if TESTING:
            for n in ast.walk(tmp_ast):
                for field in n._fields:
                    assert hasattr(
                        n, field
                    ), f"{n.__class__.__name__}.{field} is not defined"

                for field, value in ast.iter_fields(n):
                    if isinstance(value, list):
                        assert not any(isinstance(e, ValueWrapper) for e in value)
                    else:
                        assert not isinstance(value, ValueWrapper)

                if isinstance(n, ast.arguments):
                    assert len(n.kw_defaults) == len(n.kwonlyargs)
                    if sys.version_info >= (3, 8):
                        assert len(n.defaults) <= len(n.posonlyargs) + len(n.args)

        return tmp_ast

    def get_current_node(self, ast_node):
        """Return a copy of ``ast_node`` with the accepted replacements applied."""
        return self.get_ast(ast_node)

    def get_current_tree(self, replaced):
        """Return the whole tree with accepted plus ``replaced`` edits applied.

        Missing line/column information is filled in, so the tree can be
        passed to ``compile``.
        """
        tree = self.get_ast(self.original_ast, replaced)
        ast.fix_missing_locations(tree)
        return tree

    @staticmethod
    def nodes_of(tree):
        """Return the number of nodes reachable from ``tree``."""
        return len(list(ast.walk(tree)))

    def try_with(self, replaced=None):
        """
        returns True if the minimization was successful

        Applies ``replaced`` tentatively and asks ``checker`` about the
        result. On success the replacements are merged into
        ``self.replaced`` and ``progress_callback`` is notified. Candidates
        that would contain ``del <constant>`` are rejected without calling the
        checker.

        Raises:
            StopMinimization: propagated from ``checker``. The replacements
                are still recorded before it propagates.
        """
        if replaced is None:
            replaced = {}

        if TESTING and not self.allow_multiple_mappings:
            double_defined = self.replaced.keys() & replaced.keys()
            assert (
                not double_defined
            ), f"the keys {double_defined} are mapped a second time"

        tree = self.get_current_tree(replaced)

        for node in ast.walk(tree):
            if isinstance(node, ast.Delete) and any(
                isinstance(target, self._constant_node_types)
                for target in node.targets
            ):
                # code like:
                # del None
                return False

        valid_minimization = False

        try:
            valid_minimization = self.checker(tree)
        except StopMinimization:
            valid_minimization = True
            raise
        finally:
            if valid_minimization:
                self.replaced.update(replaced)
                self.progress_callback(self.nodes_of(tree), self.original_nodes_number)

        return valid_minimization

    def try_attr(self, node, attr_name, new_attr):
        """Try replacing the whole field ``attr_name`` of ``node`` by ``new_attr``."""
        return self.try_with({(node.__index, attr_name): new_attr})

    def try_node(self, old_node, new_node):
        """Try replacing ``old_node`` by ``new_node``."""
        return self.try_with({old_node.__index: new_node})

    def try_without(self, nodes):
        """Try removing all of ``nodes`` at once (each mapped to an empty list)."""
        return self.try_with({n.__index: [] for n in nodes})

    def try_none(self, node):
        """Try replacing ``node`` by None; trivially True if it already is None."""
        if node is None:
            return True
        return self.try_with({node.__index: None})

    def try_only(self, node, *childs) -> bool:
        """Try replacing ``node`` by each candidate in ``childs``, in order.

        A candidate may be a node, a list of nodes (resolved only when
        ``node`` sits inside a list field) or None (skipped). Returns True at
        the first accepted candidate and False if none is accepted.
        """
        for child in childs:
            if isinstance(child, list):
                if self.try_with({node.__index: [c.__index for c in child]}):
                    return True
            elif child is None:
                continue
            else:
                if self.try_with({node.__index: child.__index}):
                    return True
        return False

    def try_only_minimize(self, node, *childs):
        """Try replacing ``node`` by one of ``childs``, then keep minimizing.

        On the first accepted candidate, that child is minimized further and
        True is returned. If no candidate is accepted, every (non-None) child
        is minimized and False is returned.
        """
        childs = [child for child in childs if child is not None]

        for child in childs:
            if self.try_only(node, child):
                self.minimize(child)
                return True

        for child in childs:
            self.minimize(child)
        return False

    def minimize(self, o):
        """Minimize ``o``; must be implemented by subclasses."""
        raise NotImplementedError