Coverage for pysource_codegen/_codegen.py: 74%

917 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-06-17 07:45 +0200

1from __future__ import annotations 

2 

3import ast 

4import inspect 

5import itertools 

6import re 

7import sys 

8import traceback 

9from copy import deepcopy 

10from typing import Any 

11 

12from ._limits import f_string_expr_limit 

13from ._limits import f_string_format_limit 

14from ._utils import ast_dump 

15from ._utils import unparse 

16from .types import BuiltinNodeType 

17from .types import NodeType 

18from .types import UnionNodeType 

19 

# Interpreter feature flags, used to gate version-specific AST validity rules.
py38plus = sys.version_info >= (3, 8)
py39plus = sys.version_info >= (3, 9)
py310plus = sys.version_info >= (3, 10)
py311plus = sys.version_info >= (3, 11)
py312plus = sys.version_info >= (3, 12)

# Cache of node-type descriptions keyed by ASDL type name. It is filled lazily
# by get_info(); on Python < 3.9 it is imported from .static_type_info instead.
type_infos: dict[str, NodeType | BuiltinNodeType | UnionNodeType] = {}

27 

28 

def all_args(args):
    """Return the lists of named parameters of an ``ast.arguments`` node.

    The result is ``(posonlyargs, args, kwonlyargs)`` on Python 3.8+, and
    ``(args, kwonlyargs)`` on older interpreters, which have no positional-only
    parameters. The lists are returned by reference, so callers may edit them
    in place.
    """
    if not py38plus:
        return (args.args, args.kwonlyargs)
    return (args.posonlyargs, args.args, args.kwonlyargs)

34 

35 

def walk_until(node, stop):
    """Yield ``node`` and its descendants in depth-first pre-order, pruning at ``stop``.

    ``stop`` is anything accepted as the second argument of ``isinstance``. A
    node matching it is neither yielded nor descended into. ``node`` may also
    be a list: the list itself is yielded, then each element is walked.
    """
    if isinstance(node, stop):
        return
    yield node
    # Children are looked up only after the yield, so a consumer that edits
    # ``node`` sees its edits reflected in the rest of the walk.
    children = node if isinstance(node, list) else ast.iter_child_nodes(node)
    for child in children:
        yield from walk_until(child, stop)

46 

47 

def walk_function_nodes(node):
    """Yield ``node`` and its descendants without entering function bodies.

    For a ``def``, ``async def`` or ``lambda``, the node itself and its body
    are skipped. Only the parts evaluated in the enclosing scope are walked:
    argument annotations (from ``arguments(node)``), default values, and, for
    ``def``/``async def``, the decorators and the return annotation. A list is
    yielded itself and then each element is walked.
    """
    if not isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.Lambda)):
        yield node
        # Resolved after the yield so consumers may mutate ``node`` meanwhile
        # (fix() clears Return.value while iterating this generator).
        children = node if isinstance(node, list) else ast.iter_child_nodes(node)
        for child in children:
            yield from walk_function_nodes(child)
        return

    for argument in arguments(node):
        if argument.annotation:
            yield from walk_function_nodes(argument.annotation)

    for default in (*node.args.kw_defaults, *node.args.defaults):
        if default is not None:
            yield from walk_function_nodes(default)

    if isinstance(node, ast.Lambda):
        return

    for decorator in node.decorator_list:
        yield from walk_function_nodes(decorator)

    if node.returns is not None:
        yield from walk_function_nodes(node.returns)

72 

73 

def use():
    """Return whether a generator repair decision should be applied.

    It always answers True in normal operation. ``test_valid_source`` mocks
    this function to skip some of the algorithm's decisions. That lets tests
    reach valid source code that would otherwise never be generated, because
    the algorithm wrongly considers it invalid.
    """
    apply_decision = True
    return apply_decision

82 

83 

def equal_ast(lhs, rhs, dump_info=False, t="root"):
    """Structurally compare two ASTs, or lists/values found inside ASTs.

    Only the ``_fields`` of AST nodes are compared, so position attributes
    such as ``lineno`` are ignored. When ``dump_info`` is true, the path
    (starting with ``t``) of a mismatch is printed.
    """

    def report():
        if dump_info:
            print(t, lhs, "!=", rhs)

    if type(lhs) is not type(rhs):
        report()
        return False

    if isinstance(lhs, list):
        if len(lhs) != len(rhs):
            report()
            return False
        return all(
            equal_ast(left, right, dump_info, f"{t}[{index}]")
            for index, (left, right) in enumerate(zip(lhs, rhs))
        )

    if isinstance(lhs, ast.AST):
        return all(
            equal_ast(getattr(lhs, name), getattr(rhs, name), dump_info, f"{t}.{name}")
            for name in lhs._fields
        )

    # plain values (str, int, None, ...)
    if dump_info and lhs != rhs:
        print(t, lhs, "!=", rhs)
    return lhs == rhs

112 

113 

def get_info(name):
    """Return (and cache in ``type_infos``) the description of ASDL type ``name``.

    - The builtin types ``identifier``, ``int``, ``string`` and ``constant``
      map to ``BuiltinNodeType``.
    - Otherwise the docstring of ``ast.<name>`` is parsed:
      - ``"Name"`` gives a field-less ``NodeType``.
      - ``"Name(type field, ...)"`` gives a ``NodeType`` whose ``fields`` map
        each field name to ``(type, quantifier)``; the quantifier is ``""``,
        ``"*"`` or ``"?"``.
      - ``"name = A(...) | B(...)"`` gives a ``UnionNodeType`` with options
        ``["A", "B"]``.
    - Every referenced type is resolved recursively.

    Raises:
        AssertionError: if the class has no docstring or it cannot be parsed.
            This is raised explicitly rather than via ``assert False``, so that
            ``python -O`` does not silently fall through to a ``KeyError``.
    """
    if name in type_infos:
        return type_infos[name]

    if name in ("identifier", "int", "string", "constant"):
        type_infos[name] = BuiltinNodeType(name)
        return type_infos[name]

    doc = (inspect.getdoc(getattr(ast, name)) or "").replace("\n", " ")
    if not doc:
        raise AssertionError("no doc for " + name)

    m = re.fullmatch(r"(\w*)", doc)
    if m:
        # node without fields, e.g. "Pass"
        key = m.group(1)
        type_infos[key] = NodeType(fields={}, ast_type=getattr(ast, name))
        return type_infos[key]

    m = re.fullmatch(r"(\w*)\((.*)\)", doc)
    if m:
        key = m.group(1)
        nt = NodeType(fields={}, ast_type=getattr(ast, name))
        # register before recursing so self-referential types terminate
        type_infos[key] = nt
        for string_field in m.group(2).split(","):
            field_type, field_name = string_field.split()
            quantity = ""
            if field_type[-1] in "*?":
                quantity = field_type[-1]
                field_type = field_type[:-1]

            nt.fields[field_name] = (field_type, quantity)
            get_info(field_type)
        return type_infos[key]

    if doc.startswith(f"{name} = "):
        nt = UnionNodeType(options=[])
        type_infos[name] = nt
        alternatives = doc.split(" = ", 1)[1]
        nt.options = [d.split("(")[0] for d in alternatives.split(" | ")]
        for o in nt.options:
            get_info(o)
        return type_infos[name]

    raise AssertionError("can not parse:" + doc)

160 

161 

162if sys.version_info < (3, 9): 

163 from .static_type_info import type_infos # type: ignore 

164 

165 

166import random 

167 

168 

def only_firstone(l, condition):
    """Delete, in place, all but one element of ``l`` that satisfy ``condition``.

    Despite the name, the element that survives is the *last* match; every
    earlier match is removed. Non-matching elements are left untouched.
    Returns None.
    """
    matches = [index for index, element in enumerate(l) if condition(element)]
    # delete from the back so the remaining indices stay valid
    for index in reversed(matches[:-1]):
        del l[index]

176 

177 

def unique_by(l, key):
    """Return a new list holding one element per distinct ``key(e)``.

    For duplicate keys the *last* element wins. It is placed at the position
    where that key first occurred (dict insertion order).
    """
    by_key = {}
    for element in l:
        by_key[key(element)] = element
    return list(by_key.values())

180 

181 

class Invalid(Exception):
    """Raised by ``propability_try`` when a child node type is not allowed at a position."""

184 

185 

def propability(parents, child_name):
    """Return the weight of ``child_name`` below ``parents``, or 0 if it is forbidden.

    This is a wrapper around ``propability_try`` that maps ``Invalid`` to 0.
    """
    try:
        weight = propability_try(parents, child_name)
    except Invalid:
        weight = 0
    return weight

191 

192 

def propability_try(parents, child_name):
    """Return the relative weight of generating ``child_name`` at a position.

    ``parents`` is the list of ``(node_type_name, field_name)`` pairs from the
    root down to the position; the last entry is the direct parent. The result
    is 30 for ``Expr`` and 1 for everything else that is allowed.

    Raises:
        Invalid: when the child would produce code that Python rejects.
    """
    parent_types = [p[0] for p in parents]

    def inside(types, not_types=()):
        """Check ancestors, nearest first.

        Return True when an ancestor matches ``types`` before any ancestor
        matches ``not_types``. A dotted entry ("Type.field") matches that
        field of the ancestor; a plain entry ("Type") matches the ancestor
        type in any field. A bare string is treated as a one-element tuple.
        """
        if not isinstance(types, tuple):
            types = (types,)
        if not isinstance(not_types, tuple):
            not_types = (not_types,)

        def matches(parent, qual_parent, candidates):
            return any(
                qual_parent == t if "." in t else parent == t for t in candidates
            )

        for parent, arg in reversed(parents):
            qual_parent = f"{parent}.{arg}"
            if matches(parent, qual_parent, types):
                return True
            if matches(parent, qual_parent, not_types):
                return False
        return False

    if child_name in ("Store", "Del", "Load"):
        return 1

    if child_name == "Slice" and not (
        parents[-1] == ("Subscript", "slice")
        or parents[-2:] == [("Subscript", "slice"), ("Tuple", "elts")]
    ):
        raise Invalid

    if child_name == "ExtSlice" and parents[-1] == ("ExtSlice", "dims"):
        # SystemError('extended slice invalid in nested slice')
        raise Invalid

    # f-string
    if parents[-1] == ("JoinedStr", "values") and child_name not in (
        "Constant",
        "FormattedValue",
    ):
        raise Invalid

    if 0:
        # disabled check, kept for reference
        if (
            not py312plus
            and parents[-1] == ("FormattedValue", "value")
            and child_name != "Constant"
        ):
            # TODO: WHY?
            raise Invalid

    if parents[-1] == ("FormattedValue", "format_spec") and child_name != "JoinedStr":
        raise Invalid

    if (
        child_name == "JoinedStr"
        and parents.count(("FormattedValue", "format_spec")) > f_string_format_limit
    ):
        raise Invalid

    if (
        child_name == "JoinedStr"
        and parents.count(("FormattedValue", "value")) > f_string_expr_limit
    ):
        raise Invalid

    if child_name == "FormattedValue" and parents[-1][0] != "JoinedStr":
        # TODO: doc says this should be valid, maybe a bug in the python doc
        # see https://github.com/python/cpython/issues/111257
        raise Invalid

    if inside(
        ("Delete.targets",), ("Subscript.value", "Subscript.slice", "Attribute.value")
    ) and child_name not in (
        "Name",
        "Attribute",
        "Subscript",
        "List",
        "Tuple",
    ):
        raise Invalid

    # function statements: return/yield need an enclosing function body
    if child_name in (
        "Return",
        "Yield",
        "YieldFrom",
    ) and not inside(
        ("FunctionDef.body", "AsyncFunctionDef.body", "Lambda.body"), ("ClassDef.body",)
    ):
        raise Invalid

    # nonlocal is rejected outside function/class bodies
    if child_name in ("Nonlocal",) and not inside(
        ("FunctionDef.body", "AsyncFunctionDef.body", "Lambda.body", "ClassDef.body")
    ):
        raise Invalid

    if (
        not py38plus
        and child_name == "Continue"
        and inside(
            ("Try.finalbody", "TryStar.finalbody"),
            ("FunctionDef.body", "AsyncFunctionDef.body"),
        )
    ):
        # `continue` inside `finally` is only allowed since Python 3.8
        raise Invalid

    if parents[-1] == ("MatchMapping", "keys") and child_name != "Constant":
        # TODO: find all allowed key types
        raise Invalid

    if child_name == "MatchStar" and parent_types[-1] != "MatchSequence":
        raise Invalid

    if child_name == "Starred" and parents[-1] not in (
        ("Tuple", "elts"),
        ("Call", "args"),
        ("List", "elts"),
        ("Set", "elts"),
        ("ClassDef", "bases"),
    ):
        raise Invalid

    assign_target = ("Subscript", "Attribute", "Name", "Starred", "List", "Tuple")

    # the nearest ancestor that is not a (possibly nested) unpacking container
    assign_context = [p for p in parents if p[0] not in ("Tuple", "List", "Starred")]

    if assign_context and assign_context[-1] in [
        ("For", "target"),
        ("AsyncFor", "target"),
        ("AnnAssign", "target"),
        ("AugAssign", "target"),
        ("Assign", "targets"),
        ("withitem", "optional_vars"),
        ("comprehension", "target"),
    ]:
        if child_name not in assign_target:
            raise Invalid

    if parents[-1] in [("AugAssign", "target"), ("AnnAssign", "target")]:
        if child_name in ("Starred", "List", "Tuple"):
            raise Invalid

    if inside(("AnnAssign.target",)) and child_name == "Starred":
        # TODO this might be a cpython bug
        raise Invalid

    if parents[-1] in [("AnnAssign", "target")]:
        if child_name not in ("Name", "Attribute", "Subscript"):
            raise Invalid

    if parents[-1] in [("NamedExpr", "target")] and child_name != "Name":
        raise Invalid

    in_async_code = inside(
        ("AsyncFunctionDef.body", "GeneratorExp.elt"),
        ("FunctionDef.body", "Lambda.body", "ClassDef.body"),
    )

    if child_name in ("AsyncFor", "Await", "AsyncWith") and not in_async_code:
        raise Invalid

    if child_name in ("YieldFrom",) and in_async_code:
        raise Invalid

    in_loop = inside(
        ("For.body", "While.body", "AsyncFor.body"),
        ("FunctionDef.body", "Lambda.body", "AsyncFunctionDef.body", "ClassDef.body"),
    )

    if child_name in ("Break", "Continue") and not in_loop:
        raise Invalid

    if inside("TryStar.handlers") and child_name in ("Break", "Continue", "Return"):
        # SyntaxError: 'break', 'continue' and 'return' cannot appear in an except* block
        raise Invalid

    if inside(("MatchValue",)) and child_name not in (
        "Attribute",
        "Name",
        "Constant",
        "UnaryOp",
        "USub",
    ):
        raise Invalid

    if (
        inside("MatchValue.value")
        and inside("Attribute.value")
        and child_name not in ("Attribute", "Name")
    ):
        raise Invalid

    if (
        inside(("MatchValue",))
        and inside(("UnaryOp",))
        and child_name in ("Name", "UnaryOp", "Attribute")
    ):
        raise Invalid

    if parents[-1] == ("MatchValue", "value") and child_name == "Name":
        raise Invalid

    if inside("MatchClass.cls"):
        if child_name not in ("Name", "Attribute"):
            raise Invalid

    if parents[-1] == ("comprehension", "iter") and child_name == "NamedExpr":
        raise Invalid

    comprehensions = ("GeneratorExp", "ListComp", "SetComp", "DictComp")

    if inside(comprehensions) and child_name in ("Yield", "YieldFrom"):
        # SyntaxError: 'yield' inside list comprehension
        raise Invalid

    if (
        inside(comprehensions)
        # TODO restrict to comprehension inside ClassDef
        and inside(
            "ClassDef.body",
            ("FunctionDef.body", "AsyncFunctionDef.body", "Lambda.body"),
        )
        and child_name == "NamedExpr"
    ):
        # SyntaxError: assignment expression within a comprehension cannot be used in a class body
        raise Invalid

    if not py39plus and any(p[1] == "decorator_list" for p in parents):
        # restricted decorators
        # see https://peps.python.org/pep-0614/

        # path from the decorator expression's root down to the current position
        deco_parents = list(
            itertools.takewhile(lambda a: a[1] != "decorator_list", reversed(parents))
        )[::-1]

        def valid_deco_parents(deco_path):
            # Call?,Attribute*
            deco_path = list(deco_path)
            if deco_path and deco_path[0] == ("Call", "func"):
                # drop the *leading* Call.func (`pop()` would drop the last entry)
                deco_path.pop(0)
            return all(p == ("Attribute", "value") for p in deco_path)

        if valid_deco_parents(deco_parents) and child_name != "Name":
            raise Invalid

    # type alias
    if py312plus:
        if parents[-1] == ("TypeAlias", "name") and child_name != "Name":
            raise Invalid

        if (
            child_name == "Lambda"
            and inside("TypeAlias.value")
            and inside("ClassDef.body")
            and sys.version_info < (3, 13)
        ):
            # SyntaxError('Cannot use lambda in annotation scope within class scope')
            raise Invalid

        if child_name in (
            # "NamedExpr",
            "Yield",
            "YieldFrom",
            "Await",
            # "DictComp",
            # "ListComp",
            # "SetComp",
        ) and inside(
            (
                "ClassDef.bases",
                "ClassDef.keywords",
                "FunctionDef.returns",
                "AsyncFunctionDef.returns",
                "arg.annotation",
                "TypeAlias.value",
                "TypeVar.bound",
            )
        ):
            # todo this should only be invalid in type scopes (when the class/def has type parameters)
            # and only for async comprehensions
            raise Invalid

        if child_name in ("NamedExpr",) and inside(
            ("TypeAlias.value", "TypeVar.bound")
        ):
            # todo this should only be invalid in type scopes (when the class/def has type parameters)
            # and only for async comprehensions
            raise Invalid

    if child_name == "Await" and inside("AnnAssign.annotation"):
        raise Invalid

    if child_name == "Expr":
        return 30

    # the AST class is `Nonlocal` (was misspelled "NonLocal", so this never matched)
    if child_name == "Nonlocal" and parents[-1] == ("Module", "body"):
        raise Invalid

    return 1

490 

491 

def fix(node, parents):
    """Repair a generated AST node so that the resulting source is valid Python.

    ``parents`` is the list of ``(node_type_name, field_name)`` pairs leading
    to ``node``. Most repairs mutate ``node`` in place. Some return a
    different node (for example, a ``MatchValue`` of ``None`` becomes a
    ``MatchSingleton``), so callers must use the return value. Many repairs
    are guarded by ``use()`` so that tests can switch them off.
    """
    if isinstance(node, ast.ImportFrom):
        if use() and not py310plus and node.level is None:
            node.level = 0

        # `from import x` does not exist: a module-less import must be relative
        if use() and node.module is None and (node.level is None or node.level == 0):
            node.level = 1

    if isinstance(node, ast.ExceptHandler):
        # a bare `except:` cannot bind a name
        if use() and node.type is None:
            node.name = None

    if (
        sys.version_info < (3, 11)
        and isinstance(node, ast.Tuple)
        and parents[-1] == ("Subscript", "slice")
    ):
        # a[(a:b,*c)] <- not valid
        # TODO check this
        # allow only the first Slice or Starred
        found = False
        new_elts = []
        for e in node.elts:
            if isinstance(e, (ast.Starred, ast.Slice)):
                if found:
                    continue
                found = True
            new_elts.append(e)
        node.elts = new_elts

    if (
        use()
        and isinstance(node, ast.AnnAssign)
        and not isinstance(node.target, ast.Name)
    ):
        # `simple` may only be set for a plain name target
        node.simple = 0

    if isinstance(node, ast.Constant):
        # Constant.kind can be "u" for unicode strings (u"..."), None otherwise
        allowed_kind: list[str | None] = [None]
        if isinstance(node.value, str):
            allowed_kind.append("u")
        # This used to be an `elif`, which left string constants with arbitrary
        # kinds unchecked. The choice is deterministic on purpose: hash() of a
        # str is salted per process.
        if node.kind not in allowed_kind:
            node.kind = allowed_kind[len(node.kind) % len(allowed_kind)]

        if (
            use()
            and parents
            and parents[-1] == ("JoinedStr", "values")
            and not isinstance(node.value, str)
        ):
            # TODO: better format string generation
            node.value = str(node.value)

    if isinstance(node, ast.FormattedValue):
        # -1: no conversion; 115/114/97: ord("s"), ord("r"), ord("a")
        valid_conversion = (-1, 115, 114, 97)
        if use() and not py310plus and node.conversion is None:
            node.conversion = 5
        if use() and node.conversion not in valid_conversion:
            node.conversion = valid_conversion[node.conversion % 4]

    # the nearest ancestor that is not a (possibly nested) unpacking container
    assign_context = [p for p in parents if p[0] not in ("Tuple", "List", "Starred")]
    store_positions = (
        ("Assign", "targets"),
        ("AnnAssign", "target"),
        ("AugAssign", "target"),
        ("NamedExpr", "target"),
        ("TypeAlias", "name"),
        ("For", "target"),
        ("AsyncFor", "target"),
        ("withitem", "optional_vars"),
        ("comprehension", "target"),
    )

    if hasattr(node, "ctx"):
        if use() and assign_context and assign_context[-1] == ("Delete", "targets"):
            node.ctx = ast.Del()
        elif use() and assign_context and assign_context[-1] in store_positions:
            node.ctx = ast.Store()
        else:
            node.ctx = ast.Load()

    if (
        use()
        and isinstance(node, (ast.List, ast.Tuple))
        and isinstance(node.ctx, ast.Store)
    ):
        # only one starred target is allowed in an unpacking assignment
        only_firstone(node.elts, lambda e: isinstance(e, ast.Starred))

    if use() and isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.Lambda)):
        # unique argument names
        seen = set()
        for args in all_args(node.args):
            for i, arg in reversed(list(enumerate(args))):
                if arg.arg in seen:
                    del args[i]
                    if node.args.defaults:
                        del node.args.defaults[0]
                seen.add(arg.arg)

        for arg_name in ("kwarg", "vararg"):
            arg = getattr(node.args, arg_name)
            if arg:
                if arg.arg in seen:
                    setattr(node.args, arg_name, None)
                seen.add(arg.arg)

        # (named fn_args so it does not shadow the module-level `arguments`)
        fn_args = node.args
        # kwonlyargs and kw_defaults have to have the same size
        min_kw_size = min(len(fn_args.kwonlyargs), len(fn_args.kw_defaults))
        fn_args.kwonlyargs = fn_args.kwonlyargs[:min_kw_size]
        fn_args.kw_defaults = fn_args.kw_defaults[:min_kw_size]

    if use() and isinstance(node, ast.AsyncFunctionDef):
        # an async generator may not `return <value>`
        if any(
            isinstance(n, (ast.Yield, ast.YieldFrom))
            for n in walk_function_nodes(node.body)
        ):
            for n in walk_function_nodes(node.body):
                if isinstance(n, ast.Return):
                    n.value = None

    if use() and isinstance(node, (ast.ClassDef, ast.Call)):
        # unique keyword argument names (the last occurrence wins)
        seen = set()
        for i, kw in reversed(list(enumerate(node.keywords))):
            if kw.arg:
                if kw.arg in seen:
                    del node.keywords[i]
                seen.add(kw.arg)

    if use() and isinstance(node, ast.Try):
        # a bare `except:` is only allowed as the last handler
        node.handlers[:-1] = [
            handler for handler in node.handlers[:-1] if handler.type is not None
        ]
        if use() and not node.handlers:
            node.orelse = []

    if use() and isinstance(
        node, (ast.GeneratorExp, ast.ListComp, ast.DictComp, ast.SetComp)
    ):
        # SyntaxError: assignment expression cannot rebind comprehension iteration variable 'name_3'
        comprehension_names = {
            n.id
            for c in node.generators
            for n in ast.walk(c.target)
            if isinstance(n, ast.Name)
        } | {
            n.id
            for c in node.generators
            for n in ast.walk(c.iter)
            if isinstance(n, ast.Name)
        }

        class Transformer(ast.NodeTransformer):
            def visit_NamedExpr(self, node: ast.NamedExpr):
                if use() and node.target.id in comprehension_names:
                    return self.visit(node.value)
                return self.generic_visit(node)

        node = Transformer().visit(node)

    # pattern matching
    if sys.version_info >= (3, 10):

        def match_wildcard(node):
            # patterns treated as irrefutable by the repairs below
            if isinstance(node, ast.MatchAs):
                return (
                    node.pattern is None
                    or match_wildcard(node.pattern)
                    or node.name is None
                )
            if isinstance(node, ast.MatchOr):
                return any(match_wildcard(p) for p in node.patterns)

        if isinstance(node, ast.Match):
            # Only the last case may be irrefutable: drop all of them and
            # re-append the last one at the end.
            found = False
            new_last = None
            for i, case_ in reversed(list(enumerate(node.cases))):
                if match_wildcard(case_.pattern) and case_.guard is None:
                    if not found:
                        new_last = node.cases[i]
                        found = True
                    del node.cases[i]
            if new_last is not None:
                node.cases.append(new_last)

        if (
            isinstance(node, ast.MatchValue)
            and isinstance(node.value, ast.UnaryOp)
            and isinstance(node.value.operand, ast.Constant)
            and type(node.value.operand.value) not in (int, float)
        ):
            # only numbers may be negated in a value pattern
            node.value = node.value.operand

        if (
            isinstance(node, ast.MatchValue)
            and isinstance(node.value, ast.Constant)
            and any(node.value.value is v for v in (None, True, False))
        ):
            return ast.MatchSingleton(value=node.value.value)

        if isinstance(node, ast.MatchSingleton) and not any(
            node.value is v for v in (None, True, False)
        ):
            return ast.MatchValue(value=ast.Constant(value=node.value))

        def names(node):
            # names bound by a pattern (for `|`, only those bound by every alternative)
            if isinstance(node, ast.MatchAs) and node.name:
                yield node.name
            elif isinstance(node, ast.MatchStar) and node.name:
                yield node.name
            elif isinstance(node, ast.MatchMapping) and node.rest:
                yield node.rest
            elif isinstance(node, ast.MatchOr):
                yield from set.intersection(
                    *[set(names(pattern)) for pattern in node.patterns]
                )
            else:
                for child in ast.iter_child_nodes(node):
                    yield from names(child)

        class RemoveName(ast.NodeVisitor):
            # unbinds MatchAs names / MatchMapping rests for which `condition` holds
            def __init__(self, condition):
                self.condition = condition

            def visit_MatchAs(self, node):
                if self.condition(node.name):
                    node.name = None

            def visit_MatchMapping(self, node):
                if self.condition(node.rest):
                    node.rest = None

        class RemoveNameCleanup(ast.NodeTransformer):
            # `<pattern> as <nothing>` is not valid: keep only the inner pattern
            def visit_MatchAs(self, node):
                if node.name is None and node.pattern is not None:
                    return self.visit(node.pattern)
                return self.generic_visit(node)

        class FixPatternNames(ast.NodeTransformer):
            def __init__(self, used=None, allowed=None):
                # variables which are already used
                self.used = set() if used is None else used
                # variables which are allowed in a MatchOr
                self.allowed = allowed

            def is_allowed(self, name):
                return (
                    name is None
                    or name not in self.used
                    and (name in self.allowed if self.allowed is not None else True)
                )

            def visit_MatchAs(self, node):
                if not self.is_allowed(node.name):
                    return ast.MatchSingleton(value=None)
                elif node.name is not None:
                    self.used.add(node.name)
                return self.generic_visit(node)

            def visit_MatchStar(self, node):
                if not self.is_allowed(node.name):
                    return ast.MatchSingleton(value=None)
                elif node.name is not None:
                    self.used.add(node.name)
                return self.generic_visit(node)

            def visit_MatchMapping(self, node):
                if not self.is_allowed(node.rest):
                    return ast.MatchSingleton(value=None)
                elif node.rest is not None:
                    self.used.add(node.rest)
                return self.generic_visit(node)

            def visit_MatchOr(self, node: ast.MatchOr):
                allowed = set.intersection(
                    *[set(names(pattern)) for pattern in node.patterns]
                )
                allowed -= self.used

                node.patterns = [
                    FixPatternNames(set(self.used), allowed).visit(child)
                    for child in node.patterns
                ]

                self.used |= allowed

                return node

        if isinstance(node, ast.match_case):
            node.pattern = FixPatternNames().visit(node.pattern)

        if isinstance(node, ast.MatchMapping):

            def can_literal_eval(node):
                # keys must be literals with a hashable value
                try:
                    hash(ast.literal_eval(node))
                except (ValueError, TypeError):
                    # TypeError: unhashable literal such as a list
                    return False
                return True

            node.keys = [k for k in node.keys if can_literal_eval(k)]

            node.keys = unique_by(node.keys, ast.literal_eval)
            del node.patterns[len(node.keys) :]

            seen = set()
            for pattern in node.patterns:
                RemoveName(lambda name: name in seen).visit(pattern)
                seen |= {*names(pattern)}

        if isinstance(node, ast.MatchOr):
            # every alternative must bind the same names
            var_names = set.intersection(
                *[set(names(pattern)) for pattern in node.patterns]
            )

            RemoveName(lambda name: name not in var_names).visit(node)

            for i, pattern in enumerate(node.patterns):
                if match_wildcard(pattern):
                    node.patterns = node.patterns[: i + 1]
                    break

            if len(node.patterns) == 1:
                return node.patterns[0]

        if isinstance(node, ast.Match):
            for i, case in enumerate(node.cases):
                # An unguarded default case `case _:` makes every later case
                # unreachable. The guard check used to apply to the MatchOr
                # alternative only (and/or precedence), so `case _ if g:`
                # wrongly truncated the following cases.
                pattern = case.pattern
                is_default = (
                    isinstance(pattern, ast.MatchAs) and pattern.name is None
                ) or (
                    isinstance(pattern, ast.MatchOr)
                    and isinstance(pattern.patterns[-1], ast.MatchAs)
                    and pattern.patterns[-1].name is None
                )
                if is_default and case.guard is None:
                    node.cases = node.cases[: i + 1]
                    break

        if isinstance(node, ast.MatchSequence):
            only_firstone(node.patterns, lambda e: isinstance(e, ast.MatchStar))

            seen = set()
            for pattern in node.patterns:
                RemoveName(lambda name: name in seen).visit(pattern)
                seen |= {*names(pattern)}

        if isinstance(node, ast.MatchClass):
            node.kwd_attrs = unique_by(node.kwd_attrs, lambda e: e)
            del node.kwd_patterns[len(node.kwd_attrs) :]

            seen = set()
            for pattern in [*node.patterns, *node.kwd_patterns]:
                RemoveName(lambda name: name in seen).visit(pattern)
                seen |= {*names(pattern)}

        if isinstance(node, ast.Match):
            node = RemoveNameCleanup().visit(node)

    # async nodes

    in_async_code = False
    for parent, attr in reversed(parents):
        if parent == "AsyncFunctionDef" and attr == "body":
            in_async_code = True
            break
        if parent in ("FunctionDef", "Lambda", "ClassDef", "TypeAlias"):
            break

        if not py311plus and parent in (
            "ListComp",
            "DictComp",
            "SetComp",
            "GeneratorExp",
        ):
            break

    if isinstance(node, (ast.ListComp, ast.SetComp, ast.DictComp)):
        # `async for` in these comprehensions needs an async context
        if use() and not in_async_code:
            for comp in node.generators:
                comp.is_async = 0

    if isinstance(node, ast.Raise):
        # `raise from x` needs an exception
        if use() and not node.exc:
            node.cause = None

    if use() and isinstance(node, ast.Lambda):
        # no annotation for lambda arguments
        for args in all_args(node.args):
            for arg in args:
                arg.annotation = None

        if use() and node.args.vararg:
            node.args.vararg.annotation = None

        if use() and node.args.kwarg:
            node.args.kwarg.annotation = None

    if sys.version_info >= (3, 12):
        # type scopes
        if use() and hasattr(node, "type_params"):
            node.type_params = unique_by(node.type_params, lambda p: p.name)

        def cleanup_annotation(annotation):
            # remove expressions that are rejected inside type scopes
            class Transformer(ast.NodeTransformer):
                def visit_NamedExpr(self, node: ast.NamedExpr):
                    if not use():
                        return self.generic_visit(node)
                    return self.visit(node.value)

                def visit_Yield(self, node: ast.Yield) -> Any:
                    if not use():
                        return self.generic_visit(node)
                    if node.value is None:
                        return ast.Constant(value=None)
                    return self.visit(node.value)

                def visit_YieldFrom(self, node: ast.YieldFrom) -> Any:
                    if not use():
                        return self.generic_visit(node)
                    return self.visit(node.value)

            return Transformer().visit(annotation)

        if (
            isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.type_params
        ):
            for arg in [
                *node.args.posonlyargs,
                *node.args.args,
                *node.args.kwonlyargs,
                node.args.vararg,
                node.args.kwarg,
            ]:
                if use() and arg is not None and arg.annotation:
                    arg.annotation = cleanup_annotation(arg.annotation)

            if use() and node.returns is not None:
                node.returns = cleanup_annotation(node.returns)

        if isinstance(node, ast.ClassDef) and node.type_params:
            node.bases = [cleanup_annotation(b) for b in node.bases]
            for kw in node.keywords:
                if use():
                    kw.value = cleanup_annotation(kw.value)

        for n in ast.walk(node):
            if use() and isinstance(n, ast.TypeAlias):
                n.value = cleanup_annotation(n.value)

        if isinstance(node, ast.ClassDef):
            for n in ast.walk(node):
                if use() and isinstance(n, ast.TypeVar) and n.bound is not None:
                    n.bound = cleanup_annotation(n.bound)

        if use() and isinstance(node, ast.AnnAssign):
            node.annotation = cleanup_annotation(node.annotation)

    if sys.version_info >= (3, 13):
        if hasattr(node, "type_params"):
            # SyntaxError: non-default type parameter 'name_1' follows default type parameter
            # Walking from the end: once a parameter *without* a default is seen,
            # every earlier parameter must lose its default as well. (This used
            # to test `default_value != None`, which stripped every default.)
            no_default = False
            for child in reversed(node.type_params):
                if child.default_value is None:
                    no_default = True
                if use() and no_default:
                    child.default_value = None

    return node

987 

988 

def fix_result(node):
    """Apply the whole-tree fix-ups to a generated AST and return the result.

    At the moment the only whole-tree fix-up is ``fix_nonlocal``.
    """
    fixed = fix_nonlocal(node)
    return fixed

991 

992 

def is_valid_ast(tree) -> bool:
    """Check whether *tree* is an AST this generator considers valid.

    Two checks are performed:

    1. ``is_valid`` walks the tree and rejects it when a node sits where
       ``propability`` gives it probability 0, when grouped list fields
       (``same_length``) differ in length, when a union-typed field holds a
       node of the wrong type, when a list is shorter than
       ``min_attr_length``, when list/non-list fields are mixed up, or when
       ``None`` is not allowed. Diagnostics are printed to stdout.
    2. A deep copy of the tree is passed bottom-up through ``fix`` and then
       through ``fix_result``. The tree is valid only if ``equal_ast``
       reports the fixed copy equal to the original.

    Returns True when both checks pass, otherwise False.
    """

    def is_valid(node: ast.AST, parents):
        # ``node`` is either an AST node or a plain field value (str, int,
        # None, ...). ``parents`` is a list of (node type name, field name)
        # pairs from the root down to ``node``.
        type_name = node.__class__.__name__
        if (
            isinstance(node, (ast.AST))
            and parents
            and propability(
                parents,
                type_name,
            )
            == 0
        ):
            print("invalid node with:")
            print("parents:", parents)
            print("node:", node)

            # Re-run the check via propability_try to report where the
            # node was rejected. Frame [0] of the traceback is this
            # function; frame [1] is the first frame inside propability_try.
            try:
                propability_try(
                    parents,
                    node.__class__.__name__,
                )
            except Invalid:
                frame = traceback.extract_tb(sys.exc_info()[2])[1]
                print("file:", f"{frame.filename}:{frame.lineno}")

            return False

        # Fields grouped in ``same_length`` must all have the same length.
        if type_name in same_length:
            attrs = same_length[type_name]
            if len({len(v) for k, v in ast.iter_fields(node) if k in attrs}) != 1:
                return False

        if isinstance(node, (ast.AST)):
            info = get_info(type_name)
            assert isinstance(info, NodeType)

            for attr_name, value in ast.iter_fields(node):
                # attr_info is (type name, quantifier); quantifier is
                # "" (exactly one), "?" (optional) or "*" (list).
                attr_info = info.fields[attr_name]
                if attr_info[1] == "":
                    value_info = get_info(attr_info[0])
                    if isinstance(value_info, UnionNodeType):
                        if type(value).__name__ not in value_info.options:
                            print(
                                f"{type(node).__name__}.{attr_name} {value} is not one type of {value_info.options}"
                            )
                            print("parents are:", parents)
                            return False

                if isinstance(value, list) and len(value) < min_attr_length(
                    type_name, attr_name
                ):
                    print("invalid arg length", type_name, attr_name)
                    return False

                # Lists are only allowed for "*" fields, and "*" fields must be lists.
                if isinstance(value, list) != (info.fields[attr_name][1] == "*"):
                    print("no list", value)
                    return False
                if value is None:
                    # None is accepted for optional fields (unless
                    # none_allowed forbids it here) and for "constant" fields.
                    if not (
                        (
                            info.fields[attr_name][1] == "?"
                            and none_allowed(parents + [(type_name, attr_name)])
                        )
                        or info.fields[attr_name][0] == "constant"
                    ):
                        print("none not allowed", parents, type_name, attr_name)
                        return False

            # Recurse into every field value, including list elements.
            for field in node._fields:
                value = getattr(node, field)
                if isinstance(value, list):
                    if not all(
                        is_valid(e, parents + [(type_name, field)]) for e in value
                    ):
                        return False
                else:
                    if not is_valid(value, parents + [(type_name, field)]):
                        return False
        return True

    if not is_valid(tree, []):
        return False

    tree_copy = deepcopy(tree)

    def fix_tree(node: ast.AST, parents):
        # Rebuild the copy bottom-up: children are fixed before ``fix`` is
        # applied to their parent node.
        for field in node._fields:
            value = getattr(node, field)
            if isinstance(value, ast.AST):
                setattr(
                    node,
                    field,
                    fix_tree(value, parents + [(node.__class__.__name__, field)]),
                )
            if isinstance(value, list):
                setattr(
                    node,
                    field,
                    [
                        (
                            fix_tree(v, parents + [(node.__class__.__name__, field)])
                            if isinstance(v, ast.AST)
                            else v
                        )
                        for v in value
                    ],
                )

        return fix(node, parents)

    tree_copy = fix_tree(tree_copy, [])
    tree_copy = fix_result(tree_copy)

    # A valid tree must be a fixed point of the fix-ups.
    result = equal_ast(tree_copy, tree, dump_info=True)

    # Always-on debug output: print a diff when fixing changed the tree.
    # NOTE(review): the 3.9 guard presumably exists because ast_dump needs
    # a 3.9+ ast feature; confirm.
    if 1:
        if sys.version_info >= (3, 9) and not result:
            dump_copy = ast_dump(tree_copy).splitlines()
            dump = ast_dump(tree).splitlines()
            import difflib

            print("ast was changed by during fixing:")

            print("\n".join(difflib.unified_diff(dump, dump_copy, "original", "fixed")))

    return result

1119 

1120 

1121def arguments( 

1122 node: ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda, 

1123) -> list[ast.arg]: 

1124 args = node.args 

1125 l = [ 

1126 *args.args, 

1127 args.vararg, 

1128 *args.kwonlyargs, 

1129 args.kwarg, 

1130 ] 

1131 

1132 l += args.posonlyargs 

1133 

1134 return [arg for arg in l if arg is not None] 

1135 

1136 

def fix_nonlocal(node):
    """Remove ``nonlocal``/``global`` declarations that would be invalid in *node*.

    Works in two passes:

    * ``NonLocalFixer`` is run over the top-level scope of *node*.
    * ``FunctionTransformer`` then walks the tree. For every class and
      function body it runs a fresh ``NonLocalFixer`` seeded with what the
      enclosing scopes provide, then recurses into nested definitions.

    A ``nonlocal``/``global`` statement whose name list becomes empty is
    replaced by ``ast.Pass()``. The tree is modified in place (some nodes
    may be replaced) and the resulting root node is returned.
    """

    class NonLocalFixer(ast.NodeTransformer):
        """
        removes invalid Nonlocals from the class/function
        """

        # A fixer handles a single scope. Nested function/class definitions
        # are recorded as name bindings. Only their defaults, decorators,
        # annotations, bases and keywords are visited; their bodies are not.

        def __init__(self, locals, nonlocals, globals, type_params, parent_globals):
            # names bound in this scope (parameters + bindings seen so far)
            self.locals = set(locals)
            # names bound or read in this scope so far
            self.used_names = set(locals)
            # type parameter names of the enclosing/current definitions
            self.type_params = set(type_params)

            # nonlocals from the parent scope
            self.nonlocals = set(nonlocals)
            # names kept in ``nonlocal`` statements of this scope
            self.used_nonlocals = set()

            # globals from the global scope
            # (stored, but not consulted by the filters below)
            self.globals = set(globals)
            # names kept in ``global`` statements of this scope
            self.used_globals = set()
            # names declared ``global`` in the enclosing function scope
            self.parent_globals = parent_globals

        def name_assigned(self, name):
            self.locals.add(name)
            self.used_names.add(name)

        def visit_Name(self, node: ast.Name) -> Any:
            if isinstance(node.ctx, (ast.Store, ast.Del)):
                self.name_assigned(node.id)
            else:
                self.used_names.add(node.id)
            return node

        if sys.version_info >= (3, 10):

            def visit_MatchAs(self, node: ast.MatchAs) -> Any:
                if node.pattern:
                    self.visit(node.pattern)
                self.name_assigned(node.name)
                return node

        def search_walrus(self, node):
            # Walrus targets inside a comprehension bind in the enclosing
            # scope, i.e. in the scope handled by this fixer.
            for n in ast.walk(node):
                if isinstance(n, ast.NamedExpr):
                    self.visit(n.target)

        # Comprehensions: only the first iterable is evaluated in this scope.
        # The rest has its own scope, except for walrus targets.

        def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
            self.visit(node.generators[0].iter)
            self.search_walrus(node)
            return node

        def visit_ListComp(self, node: ast.ListComp) -> Any:
            self.visit(node.generators[0].iter)
            self.search_walrus(node)
            return node

        def visit_DictComp(self, node: ast.DictComp) -> Any:
            self.visit(node.generators[0].iter)
            self.search_walrus(node)
            return node

        def visit_SetComp(self, node: ast.SetComp) -> Any:
            self.visit(node.generators[0].iter)
            self.search_walrus(node)
            return node

        def visit_Nonlocal(self, node: ast.Nonlocal) -> Any:
            # TODO: research __class__ seems to be defined in the class scope
            # but it is also not
            # class A:
            #     print(locals())  # no __class__
            #     def f():
            #         nonlocal __class__ # is A
            #
            # Keep a name only if it is not bound or used earlier in this
            # scope, is bound in an enclosing function scope, and is neither
            # a type parameter nor declared global (here or in the parent).
            # ``__class__`` is always kept: the trailing ``or`` binds loosest.
            node.names = [
                name
                for name in node.names
                if name not in self.locals
                and name in self.nonlocals
                and name not in self.used_names
                and name not in self.type_params
                and name not in self.parent_globals
                and name not in self.used_globals
                or name in ("__class__",)
            ]
            self.used_nonlocals |= set(node.names)

            if not node.names:
                return ast.Pass()

            return node

        def visit_Global(self, node: ast.Global) -> Any:
            # Drop names already bound, used, or declared nonlocal earlier
            # in this scope.
            node.names = [
                name
                for name in node.names
                if name not in self.locals
                and name not in self.used_names
                and name not in self.used_nonlocals
            ]
            self.used_globals |= set(node.names)

            if not node.names:
                return ast.Pass()

            return node

        def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
            # Annotating a name declared global/nonlocal is a SyntaxError.
            # Turn it into a plain assignment, or drop it when it has no value.
            if isinstance(node.target, ast.Name) and (
                node.target.id in self.used_globals
                or node.target.id in self.used_nonlocals
            ):
                if node.value:
                    return self.generic_visit(
                        ast.Assign(
                            targets=[node.target], value=node.value, type_comment=None
                        )
                    )
                else:
                    return ast.Pass()
            return self.generic_visit(node)

        def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
            # The function name is bound in this scope. Defaults, decorators
            # and annotations are visited here; the body is not.
            self.name_assigned(node.name)

            all_nodes = [
                *node.args.defaults,
                *node.args.kw_defaults,
                *node.decorator_list,
                node.returns,
            ]

            all_nodes += [arg.annotation for arg in arguments(node)]

            for default in all_nodes:
                if default is not None:
                    self.visit(default)

            return node

        def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
            # Same handling as visit_FunctionDef.
            self.name_assigned(node.name)

            all_nodes = [
                *node.args.defaults,
                *node.args.kw_defaults,
                *node.decorator_list,
                node.returns,
            ]

            all_nodes += [arg.annotation for arg in arguments(node)]

            for default in all_nodes:
                if default is not None:
                    self.visit(default)
            return node

        def visit_ClassDef(self, node: ast.ClassDef) -> Any:
            # Keywords, bases and decorators are visited before the class
            # name is bound; the body is not visited.
            for expr in [
                *[k.value for k in node.keywords],
                *node.bases,
                *node.decorator_list,
            ]:
                if expr is not None:
                    self.visit(expr)

            self.name_assigned(node.name)

            return node

        # pattern matching
        if sys.version_info >= (3, 10):

            def visit_MatchMapping(self, node: ast.MatchMapping) -> Any:
                if node.rest is not None:
                    self.name_assigned(node.rest)
                return self.generic_visit(node)

        if sys.version_info >= (3, 13):

            def visit_MatchStar(self, node: ast.MatchStar) -> Any:
                self.name_assigned(node.name)
                return self.generic_visit(node)

        def visit_ExceptHandler(self, handler):
            if handler.name:
                self.name_assigned(handler.name)
            return self.generic_visit(handler)

        def visit_Lambda(self, node: ast.Lambda) -> Any:
            # Only the defaults are evaluated in this scope.
            for default in [*node.args.defaults, *node.args.kw_defaults]:
                if default is not None:
                    self.visit(default)
            return node

        if sys.version_info < (3, 13):

            def visit_Try(self, node: ast.Try) -> Any:
                # work around for https://github.com/python/cpython/issues/111123
                # Children are visited in the order body, orelse, handlers,
                # finalbody. The rebuilt node carries no location attributes.
                args = {}
                for k in ("body", "orelse", "handlers", "finalbody"):
                    args[k] = [self.visit(x) for x in getattr(node, k)]

                return ast.Try(**args)

            if sys.version_info >= (3, 11):

                def visit_TryStar(self, node: ast.TryStar) -> Any:
                    # work around for https://github.com/python/cpython/issues/111123
                    # Same visiting order as visit_Try.
                    args = {}
                    for k in ("body", "orelse", "handlers", "finalbody"):
                        args[k] = [self.visit(x) for x in getattr(node, k)]

                    return ast.TryStar(**args)

    class FunctionTransformer(ast.NodeTransformer):
        """
        - transforms a class/function: fixes its body with a NonLocalFixer
          and recurses into nested definitions
        """

        def __init__(self, nonlocals, globals, type_params, parent_globals):
            # names bound by the enclosing function scopes
            self.nonlocals = set(nonlocals)
            self.globals = set(globals)
            self.type_params = type_params
            self.parent_globals = parent_globals

        def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
            return self.handle_function(node)

        def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
            return self.handle_function(node)

        def visit_Lambda(self, node: ast.Lambda) -> Any:
            # there are no globals/nonlocals/functiondefs in lambdas
            return node

        def visit_ClassDef(self, node: ast.ClassDef) -> Any:
            type_params = set(self.type_params)
            if sys.version_info >= (3, 12):
                type_params |= {typ.name for typ in node.type_params}  # type: ignore

            fixer = NonLocalFixer(
                [], self.nonlocals, self.globals, type_params, self.parent_globals
            )
            node.body = [fixer.visit(stmt) for stmt in node.body]

            # Class-level names are not visible in nested scopes, so nested
            # definitions only get the enclosing functions' names
            # (``self.nonlocals``), not this class's locals.
            ft = FunctionTransformer(
                self.nonlocals, self.globals, type_params, self.parent_globals
            )
            node.body = [ft.visit(stmt) for stmt in node.body]

            return node

        def handle_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> Any:
            # Parameters are the initial locals of the function scope.
            names = {arg.arg for arg in arguments(node)}

            type_params = set(self.type_params)
            if sys.version_info >= (3, 12):
                type_params |= {typ.name for typ in node.type_params}  # type: ignore

            fixer = NonLocalFixer(
                names, self.nonlocals, self.globals, type_params, self.parent_globals
            )
            node.body = [fixer.visit(stmt) for stmt in node.body]

            # Nested scopes may refer to this function's locals and to
            # everything the enclosing functions bind. Names declared global
            # here become the nested scopes' parent globals.
            ft = FunctionTransformer(
                fixer.locals | self.nonlocals,
                self.globals,
                type_params,
                fixer.used_globals,
            )
            node.body = [ft.visit(stmt) for stmt in node.body]

            return node

    # Top-level scope: no enclosing function, so nothing can be ``nonlocal``.
    fixer = NonLocalFixer([], [], [], [], [])
    node = fixer.visit(node)

    node = FunctionTransformer([], [], [], []).visit(node)
    return node

1414 

1415 

def min_attr_length(node_type, attr_name):
    """Return the minimum number of elements list field *node_type.attr_name* needs.

    Combinations not listed below may be empty (minimum 0).
    """
    # Every ``body`` needs at least one statement, except a module's.
    if attr_name == "body":
        return 0 if node_type == "Module" else 1
    # Comprehensions always need at least one ``for`` clause.
    if attr_name == "generators":
        return 1

    minimums = {
        ("MatchOr", "patterns"): 2,
        ("BoolOp", "values"): 2,
        ("BinOp", "values"): 1,
        ("Import", "names"): 1,
        ("ImportFrom", "names"): 1,
        ("With", "items"): 1,
        ("AsyncWith", "items"): 1,
        ("Try", "handlers"): 1,
        ("TryStar", "handlers"): 1,
        ("Delete", "targets"): 1,
        ("Match", "cases"): 1,
        ("ExtSlice", "dims"): 1,
        ("Compare", "ops"): 1,
        ("Compare", "comparators"): 1,
        ("Assign", "targets"): 1,
    }
    if sys.version_info < (3, 9):
        # older unparsers cannot express an empty set display
        minimums[("Set", "elts")] = 1

    return minimums.get((node_type, attr_name), 0)

1453 

1454 

def none_allowed(parents):
    """Return whether an optional (``?``) field may be left as ``None``.

    *parents* is the list of ``(node type, field name)`` pairs leading to the
    field. The only rejected position is the ``type`` of an ``except*``
    handler (``TryStar.handlers`` -> ``ExceptHandler.type``), because
    ``except*`` requires an exception type.
    """
    trystar_handler_type = [("TryStar", "handlers"), ("ExceptHandler", "type")]
    return parents[-2:] != trystar_handler_type

1459 

1460 

# Node types whose listed list-fields must all have the same length.
# is_valid_ast rejects nodes where they differ. While generating, the length
# chosen for the first field of each group is reused for the other fields.
same_length = dict(
    MatchClass=["kwd_attrs", "kwd_patterns"],
    MatchMapping=["patterns", "keys"],
    arguments=["kw_defaults", "kwonlyargs"],
    Compare=["ops", "comparators"],
    Dict=["keys", "values"],
)

1468 

1469 

class AstGenerator:
    """Random AST generator driven by the type descriptions from ``get_info``.

    All random choices are drawn from a ``random.Random`` seeded with *seed*.
    *depth_limit* and *node_limit* bound the size of the generated tree.
    """

    def __init__(self, seed, node_limit, depth_limit):
        self.rand = random.Random(seed)
        # number of generate_impl calls so far (compared with node_limit)
        self.nodes = 0
        self.node_limit = node_limit
        self.depth_limit = depth_limit

    def cnd(self):
        """Return a random boolean."""
        return self.rand.choice([True, False])

    def generate(self, name: str, parents=(), depth=0):
        """Generate a value of type *name* and apply ``fix_result`` to it."""
        result = self.generate_impl(name, parents, depth)
        result = fix_result(result)
        return result

    def generate_impl(self, name: str, parents=(), depth=0):
        """Recursively generate a value for the type *name*.

        *parents* is the chain of ``(node type, field name)`` pairs leading to
        this value; *depth* is the caller's recursion depth. Once
        ``depth_limit`` or ``node_limit`` is exceeded, list fields are kept at
        (close to) their minimum length and union types prefer the options
        ``Name``/``MatchValue``/``Pass`` when they are allowed.

        Raises RecursionError if the recursion depth exceeds 100.
        """
        depth += 1
        self.nodes += 1

        if depth > 100:
            # Hard safety limit. Raise instead of terminating the interpreter
            # (``exit()`` would end the process with status 0).
            raise RecursionError(
                f"AST generation exceeded a depth of 100 while generating {name!r}"
            )

        stop = depth > self.depth_limit or self.nodes > self.node_limit

        info = get_info(name)

        if isinstance(info, NodeType):
            # Length chosen per list field of this node, cached so grouped
            # fields (see same_length) can reuse it.
            ranges = {}

            def attr_length(child, attr_name):
                """Return how many elements to generate for list field *attr_name*."""
                if name == "Module":
                    return 20

                if name in same_length:
                    attrs = same_length[name]
                    if attr_name in attrs[1:]:
                        # must match the first field of the group
                        return attr_length(child, attrs[0])

                if child == "arguments" and attr_name == "defaults":
                    # at most one default per positional parameter
                    lo = 0
                    hi = attr_length(child, "posonlyargs") + attr_length(child, "args")
                    ranges[attr_name] = self.rand.randint(lo, hi)

                elif attr_name not in ranges:
                    lo = min_attr_length(child, attr_name)
                    # no extra elements past the limits, few when nested deeply
                    hi = lo if stop else lo + 1 if depth > 10 else lo + 5
                    ranges[attr_name] = self.rand.randint(lo, hi)

                return ranges[attr_name]

            def child_node(n, t, q, parents):
                # q is the field quantifier: "" (one), "*" (list), "?" (optional)
                if q == "":
                    return self.generate_impl(t, parents, depth)
                elif q == "*":
                    return [
                        self.generate_impl(t, parents, depth)
                        for _ in range(attr_length(parents[-1][0], n))
                    ]
                elif q == "?":
                    return (
                        self.generate_impl(t, parents, depth)
                        if not none_allowed(parents) or self.cnd()
                        else None
                    )
                else:
                    raise AssertionError(f"unknown field quantifier {q!r}")

            attributes = {
                n: child_node(n, t, q, [*parents, (name, n)])
                for n, (t, q) in info.fields.items()
            }

            result = info.ast_type(**attributes)
            result = fix(result, parents)
            return result

        if isinstance(info, UnionNodeType):
            options_list = [
                (option, propability(parents, option)) for option in info.options
            ]

            # An option with probability 0 is still chosen when use()
            # returns False for it.
            invalid_option = [
                option for (option, prop) in options_list if prop == 0 and not use()
            ]

            assert len(invalid_option) in (0, 1), invalid_option

            if len(invalid_option) == 1:
                # NOTE(review): unlike the other recursive calls this drops
                # ``parents`` and ``depth``, so the subtree is generated as if
                # it were at the root. Confirm that this is intended.
                return self.generate_impl(invalid_option[0])

            options = dict(options_list)
            if stop:
                for final in ("Name", "MatchValue", "Pass"):
                    if options.get(final, 0) != 0:
                        options = {final: 1}
                        break

            if sum(options.values()) == 0:
                # TODO: better handling of `type?`
                return None

            return self.generate_impl(
                self.rand.choices(*zip(*options.items()))[0], parents, depth
            )
        if isinstance(info, BuiltinNodeType):
            if info.kind == "identifier":
                return f"name_{self.rand.randint(0,5)}"
            elif info.kind == "int":
                return self.rand.randint(0, 5)
            elif info.kind == "string":
                return self.rand.choice(["some text", ""])
            elif info.kind == "constant":
                return self.rand.choice(
                    [
                        None,
                        b"some bytes",
                        "some const text",
                        b"",
                        "",
                        "'\"'''\"\"\"{}\\",
                        b"'\"'''\"\"\"{}\\",
                        self.rand.randint(0, 20),
                        self.rand.uniform(0, 20),
                        True,
                        False,
                    ]
                )

            else:
                raise AssertionError("unknown kind: " + info.kind)

        raise AssertionError(f"unexpected type info for {name!r}: {info!r}")

1603 

1604 

1605import warnings 

1606 

1607 

def check(tree):
    """Assert that every ``ast.arguments`` node in *tree* is self-consistent.

    There may be no more ``defaults`` than positional parameters, and every
    keyword-only parameter needs an entry in ``kw_defaults``.
    Raises AssertionError otherwise.
    """
    signatures = (n for n in ast.walk(tree) if isinstance(n, ast.arguments))
    for signature in signatures:
        positional_count = len(signature.posonlyargs) + len(signature.args)
        assert positional_count >= len(signature.defaults), ast_dump(signature)
        assert len(signature.kwonlyargs) == len(signature.kw_defaults)

1614 assert len(node.kwonlyargs) == len(node.kw_defaults) 

1615 

1616 

def generate_ast(
    seed: int,
    *,
    node_limit: int = 10000000,
    depth_limit: int = 8,
    root_node: str = "Module",
) -> ast.AST:
    """Generate a random AST whose root has the node type *root_node*.

    *seed* seeds the random generator; *node_limit* and *depth_limit* bound
    the tree size. SyntaxWarnings emitted during generation are suppressed.
    Missing location attributes are filled in before the tree is returned.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", SyntaxWarning)
        builder = AstGenerator(seed, depth_limit=depth_limit, node_limit=node_limit)
        tree = builder.generate(root_node)
        check(tree)

    return ast.fix_missing_locations(tree)

1633 

1634 

def generate(
    seed: int,
    *,
    node_limit: int = 10000000,
    depth_limit: int = 8,
    root_node: str = "Module",
) -> str:
    """Generate random Python source code.

    Same parameters as ``generate_ast``; returns the unparsed source of
    the generated tree.
    """
    return unparse(
        generate_ast(
            seed,
            node_limit=node_limit,
            depth_limit=depth_limit,
            root_node=root_node,
        )
    )

1646 

1647 

1648# next algo 

1649 

1650# design targets: 

1651# * enumerate "all" possible ast-node combinations 

1652# * check if probability 0 would produce incorrect code

1653# * the algo should be able to generate every possible syntax combination for every python version. 

1654# * hypothesis integration 

1655# * do not use compile() in the implementation 

1656# * generation should be customizable (custom probabilities and random values)

1657 

1658# features: 

1659# * node-context: function-scope async-scope type-scope class-scope ... 

1660# * names: nonlocal global 

1661 

1662from dataclasses import dataclass 

1663 

1664 

@dataclass
class ParentRef:
    """Link from a child node to the slot it occupies in its parent.

    Attribute lookups for names starting with ``ctx_`` that are not real
    fields are forwarded to the parent ``node``.
    """

    # the parent node
    node: "PartialNode"
    # name of the parent field holding the child
    attr_name: str
    # position of the child within that field
    index: int
    _context: dict

    def __getattr__(self, name):
        # Only called when normal lookup fails. The original referenced an
        # undefined global ``node`` here, which raised NameError.
        if name.startswith("ctx_"):
            return getattr(self.node, name)
        raise AttributeError(name)

1676 

1677 

1678# (d:=[n] | q_parent("Delete.targets")) and len(d.targets)==1 

1679 

1680 

1681@dataclass 

1682class PartialValue: 

1683 value: int | str | bool 

1684 

1685 

@dataclass
class PartialNode:
    """An AST node under construction whose fields are filled in step by step.

    Fields assigned so far live in ``_defined_attrs`` and can be read as
    attributes. Reading a field that is not assigned yet raises RuntimeError.
    ``ctx_*`` lookups are forwarded to the parent node. A root node
    (``parent_ref is None``) raises AttributeError for them.
    """

    _node_type_name: str
    parent_ref: "ParentRef | None"
    _defined_attrs: dict
    _context: dict

    def inside(self, spec) -> "PartialNode | None": ...

    @property
    def parent(self):
        return self.parent_ref.node

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name.startswith("ctx_"):
            # The original referenced an undefined global ``node`` (NameError).
            if self.parent_ref is None:
                raise AttributeError(name)
            return getattr(self.parent_ref.node, name)

        # Private and dunder names are never node fields. Without this guard,
        # probes such as ``getattr(x, "__deepcopy__", None)`` raised
        # RuntimeError. On an uninitialised instance (e.g. during
        # copy.deepcopy) the ``self._defined_attrs`` lookup below also
        # recursed forever.
        if name.startswith("_"):
            raise AttributeError(name)

        if name not in self._defined_attrs:
            raise RuntimeError(f"{self._node_type_name}.{name} is not defined jet")

        return self._defined_attrs[name]

1707 

1708 

def gen(node: PartialNode):
    """Stub for the planned PartialNode-based generator; not implemented yet.

    Always returns ``None``.
    """
    # parents [(node,attr_name)]
    return None