Coverage for pysource_codegen/_codegen.py: 74%
917 statements
coverage.py v7.5.3, created at 2024-06-17 07:45 +0200
1from __future__ import annotations
3import ast
4import inspect
5import itertools
6import re
7import sys
8import traceback
9from copy import deepcopy
10from typing import Any
12from ._limits import f_string_expr_limit
13from ._limits import f_string_format_limit
14from ._utils import ast_dump
15from ._utils import unparse
16from .types import BuiltinNodeType
17from .types import NodeType
18from .types import UnionNodeType
# Feature flags for the running interpreter; each is True when the interpreter
# is at least that version.
py38plus = sys.version_info >= (3, 8)
py39plus = sys.version_info >= (3, 9)
py310plus = sys.version_info >= (3, 10)
py311plus = sys.version_info >= (3, 11)
py312plus = sys.version_info >= (3, 12)

# Cache of parsed ast type descriptions, keyed by ast type name (filled by get_info).
type_infos: dict[str, NodeType | BuiltinNodeType | UnionNodeType] = {}
def all_args(args):
    """Return the parameter lists of an ``ast.arguments`` node as a tuple.

    The tuple is ``(posonlyargs, args, kwonlyargs)`` on Python 3.8+.  On older
    versions ``posonlyargs`` does not exist and it is ``(args, kwonlyargs)``.
    The returned lists are the node's own lists, so deleting from them edits
    the node.
    """
    if not py38plus:
        return (args.args, args.kwonlyargs)
    return (args.posonlyargs, args.args, args.kwonlyargs)
def walk_until(node, stop):
    """Yield ``node`` and its descendants in pre-order, pruning at ``stop`` types.

    A node that is an instance of ``stop`` is neither yielded nor descended
    into.  ``node`` may also be a list; the list itself is yielded first and
    then each element is walked.
    """
    if isinstance(node, stop):
        return
    yield node
    children = node if isinstance(node, list) else ast.iter_child_nodes(node)
    for child in children:
        yield from walk_until(child, stop)
def walk_function_nodes(node):
    """Yield the nodes of ``node`` that belong to the current function scope.

    For a function or lambda node, the node itself and its body are skipped.
    Only the parts evaluated in the enclosing scope are walked: argument
    annotations, defaults and (for ``def``) decorators and the return
    annotation.  Any other node is yielded and its children are walked the
    same way.  ``node`` may also be a list of nodes; the list itself is
    yielded, then each element is walked.
    """
    if not isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.Lambda)):
        yield node
        children = node if isinstance(node, list) else ast.iter_child_nodes(node)
        for child in children:
            yield from walk_function_nodes(child)
        return

    for param in arguments(node):
        if param.annotation:
            yield from walk_function_nodes(param.annotation)

    for default_value in (*node.args.kw_defaults, *node.args.defaults):
        if default_value is not None:
            yield from walk_function_nodes(default_value)

    if isinstance(node, ast.Lambda):
        return

    for decorator in node.decorator_list:
        yield from walk_function_nodes(decorator)

    if node.returns is not None:
        yield from walk_function_nodes(node.returns)
def use():
    """Decide whether an optional restriction or repair should be applied.

    Always returns ``True`` here.  ``test_valid_source`` mocks this function
    to skip some of the decisions the algorithm usually makes.  The goal is
    to generate valid source code that would otherwise never be produced
    because the algorithm wrongly considers it invalid.
    """
    apply_decision = True
    return apply_decision
def equal_ast(lhs, rhs, dump_info=False, t="root"):
    """Structurally compare two ast trees, lists or leaf values.

    Types must match exactly.  Lists must have the same length and equal
    elements.  ast nodes must have equal values for every field in
    ``_fields``.  Other values are compared with ``==``.  When ``dump_info``
    is true, the path of the first mismatch (rooted at ``t``) is printed.
    """
    def report_mismatch():
        if dump_info:
            print(t, lhs, "!=", rhs)

    if type(lhs) is not type(rhs):
        report_mismatch()
        return False

    if isinstance(lhs, list):
        if len(lhs) != len(rhs):
            report_mismatch()
            return False
        return all(
            equal_ast(left, right, dump_info, f"{t}[{index}]")
            for index, (left, right) in enumerate(zip(lhs, rhs))
        )

    if isinstance(lhs, ast.AST):
        return all(
            equal_ast(getattr(lhs, name), getattr(rhs, name), dump_info, f"{t}.{name}")
            for name in lhs._fields
        )

    if dump_info and lhs != rhs:
        print(t, lhs, "!=", rhs)
    return lhs == rhs
def get_info(name):
    """Return the type description of ast type ``name``, caching it in ``type_infos``.

    ``identifier``, ``int``, ``string`` and ``constant`` are builtin leaf
    types (``BuiltinNodeType``).  All other descriptions are parsed from the
    docstring of ``ast.<name>``, which has one of these shapes:

    * ``Name``: a node without fields (``NodeType`` with no fields),
    * ``Name(type1 field1, type2? field2, type3* field3)``: a node with
      fields, where a trailing ``?`` marks an optional field and ``*`` a list
      field,
    * ``name = Option1(...) | Option2(...)``: an abstract type
      (``UnionNodeType``) listing its concrete options.

    Field types and union options are resolved recursively.

    Raises:
        AssertionError: if the docstring is missing or has an unknown shape.
    """
    if name in type_infos:
        return type_infos[name]
    elif name in ("identifier", "int", "string", "constant"):
        type_infos[name] = BuiltinNodeType(name)
    else:
        doc = inspect.getdoc(getattr(ast, name)) or ""
        doc = doc.replace("\n", " ")

        if not doc:
            # raise explicitly so the check is not stripped under `python -O`
            raise AssertionError("no doc for " + name)

        m = re.fullmatch(r"(\w*)", doc)
        if m:
            nt = NodeType(fields={}, ast_type=getattr(ast, name))
            name = m.group(1)
            type_infos[name] = nt
        else:
            m = re.fullmatch(r"(\w*)\((.*)\)", doc)
            if m:
                nt = NodeType(fields={}, ast_type=getattr(ast, name))
                name = m.group(1)
                # register before resolving fields: ast types refer to themselves
                type_infos[name] = nt
                # skip empty entries, so a "Name()" doc does not break the unpacking
                field_specs = [s for s in m.group(2).split(",") if s.strip()]
                for string_field in field_specs:
                    field_type, field_name = string_field.split()
                    quantity = ""
                    last = field_type[-1]
                    if last in "*?":
                        quantity = last
                        field_type = field_type[:-1]

                    nt.fields[field_name] = (field_type, quantity)
                    get_info(field_type)
            elif doc.startswith(f"{name} = "):
                doc = doc.split(" = ", 1)[1]
                nt = UnionNodeType(options=[])
                type_infos[name] = nt
                nt.options = [d.split("(")[0].strip() for d in doc.split(" | ")]
                for o in nt.options:
                    get_info(o)
            else:
                raise AssertionError("can not parse:" + doc)

    return type_infos[name]
162if sys.version_info < (3, 9):
163 from .static_type_info import type_infos # type: ignore
166import random
def only_firstone(l, condition):
    """Delete from ``l`` (in place) all but one element matching ``condition``.

    The list is scanned from the end, so the element that survives is the
    last matching one.  Elements that do not match are untouched.
    """
    kept_one = False
    for index in range(len(l) - 1, -1, -1):
        if not condition(l[index]):
            continue
        if kept_one:
            del l[index]
        kept_one = True
def unique_by(l, key):
    """Return the elements of ``l`` deduplicated by ``key(element)``.

    Each key keeps the position of its first occurrence, but the element
    stored there is the last one with that key.
    """
    by_key = {}
    for item in l:
        by_key[key(item)] = item
    return list(by_key.values())
class Invalid(Exception):
    """Signals that a node type may not be generated at a given position."""
def propability(parents, child_name):
    """Return the weight of ``child_name`` below ``parents``, or 0 if it is invalid.

    Wraps ``propability_try`` and turns its ``Invalid`` exception into a
    weight of 0.
    """
    try:
        weight = propability_try(parents, child_name)
    except Invalid:
        weight = 0
    return weight
def propability_try(parents, child_name):
    """Return the weight for generating ``child_name`` at the position ``parents``.

    Args:
        parents: path from the root to the insertion point as a list of
            ``(parent_type_name, field_name)`` tuples, innermost last.
        child_name: name of the ast class that would be inserted there.

    Returns:
        ``1`` for expression contexts and most nodes, ``30`` for ``Expr``
        statements.

    Raises:
        Invalid: if inserting ``child_name`` here is known to produce code
            that Python rejects.
    """
    parent_types = [p[0] for p in parents]

    def inside(types, not_types=()):
        """Check whether an ancestor matches ``types`` before one matches ``not_types``.

        An entry containing a dot (``"Type.field"``) matches a parent type
        together with the field the path went through.  A plain entry
        matches the parent type only.  Ancestors are checked from the
        innermost outwards.
        """
        if not isinstance(types, tuple):
            types = (types,)

        for parent, arg in reversed(parents):
            qual_parent = f"{parent}.{arg}"
            if any(qual_parent == t if "." in t else parent == t for t in types):
                return True
            if any(qual_parent == t if "." in t else parent == t for t in not_types):
                return False
        return False

    if child_name in ("Store", "Del", "Load"):
        return 1

    if child_name == "Slice" and not (
        parents[-1] == ("Subscript", "slice")
        or parents[-2:]
        == [
            ("Subscript", "slice"),
            ("Tuple", "elts"),
        ]
    ):
        raise Invalid

    if child_name == "ExtSlice" and parents[-1] == ("ExtSlice", "dims"):
        # SystemError('extended slice invalid in nested slice')
        raise Invalid

    # f-string
    if parents[-1] == ("JoinedStr", "values") and child_name not in (
        "Constant",
        "FormattedValue",
    ):
        raise Invalid

    if 0:
        if (
            not py312plus
            and parents[-1] == ("FormattedValue", "value")
            and child_name != "Constant"
        ):
            # TODO: WHY?
            raise Invalid

    if parents[-1] == ("FormattedValue", "format_spec") and child_name != "JoinedStr":
        raise Invalid

    if (
        child_name == "JoinedStr"
        and parents.count(("FormattedValue", "format_spec")) > f_string_format_limit
    ):
        raise Invalid

    if (
        child_name == "JoinedStr"
        and parents.count(("FormattedValue", "value")) > f_string_expr_limit
    ):
        raise Invalid

    if child_name == "FormattedValue" and parents[-1][0] != "JoinedStr":
        # TODO: doc says this should be valid, maybe a bug in the python doc
        # see https://github.com/python/cpython/issues/111257
        raise Invalid

    if inside(
        "Delete.targets", ("Subscript.value", "Subscript.slice", "Attribute.value")
    ) and child_name not in (
        "Name",
        "Attribute",
        "Subscript",
        "List",
        "Tuple",
    ):
        raise Invalid

    # function statements
    if child_name in (
        "Return",
        "Yield",
        "YieldFrom",
    ) and not inside(
        ("FunctionDef.body", "AsyncFunctionDef.body", "Lambda.body"), ("ClassDef.body",)
    ):
        raise Invalid
    # function statements
    if child_name in ("Nonlocal",) and not inside(
        ("FunctionDef.body", "AsyncFunctionDef.body", "Lambda.body", "ClassDef.body")
    ):
        raise Invalid

    if (
        not py38plus
        and child_name == "Continue"
        and inside(
            ("Try.finalbody", "TryStar.finalbody"),
            ("FunctionDef.body", "AsyncFunctionDef.body"),
        )
    ):
        raise Invalid

    if parents[-1] == ("MatchMapping", "keys") and child_name != "Constant":
        # TODO: find all allowed key types
        raise Invalid

    if child_name == "MatchStar" and parent_types[-1] != "MatchSequence":
        raise Invalid

    if child_name == "Starred" and parents[-1] not in (
        ("Tuple", "elts"),
        ("Call", "args"),
        ("List", "elts"),
        ("Set", "elts"),
        ("ClassDef", "bases"),
    ):
        raise Invalid

    assign_target = ("Subscript", "Attribute", "Name", "Starred", "List", "Tuple")

    # the innermost parent that is not a Tuple/List/Starred unpacking wrapper
    assign_context = [p for p in parents if p[0] not in ("Tuple", "List", "Starred")]

    if assign_context and assign_context[-1] in [
        ("For", "target"),
        ("AsyncFor", "target"),
        ("AnnAssign", "target"),
        ("AugAssign", "target"),
        ("Assign", "targets"),
        ("withitem", "optional_vars"),
        ("comprehension", "target"),
    ]:
        if child_name not in assign_target:
            raise Invalid

    if parents[-1] in [("AugAssign", "target"), ("AnnAssign", "target")]:
        if child_name in ("Starred", "List", "Tuple"):
            raise Invalid

    if inside(("AnnAssign.target",)) and child_name == "Starred":
        # TODO this might be a cpython bug
        raise Invalid

    if parents[-1] in [("AnnAssign", "target")]:
        if child_name not in ("Name", "Attribute", "Subscript"):
            raise Invalid

    if parents[-1] in [("NamedExpr", "target")] and child_name != "Name":
        raise Invalid

    in_async_code = inside(
        ("AsyncFunctionDef.body", "GeneratorExp.elt"),
        ("FunctionDef.body", "Lambda.body", "ClassDef.body"),
    )

    if child_name in ("AsyncFor", "Await", "AsyncWith") and not in_async_code:
        raise Invalid

    if child_name in ("YieldFrom",) and in_async_code:
        raise Invalid

    in_loop = inside(
        ("For.body", "While.body", "AsyncFor.body"),
        ("FunctionDef.body", "Lambda.body", "AsyncFunctionDef.body", "ClassDef.body"),
    )

    if child_name in ("Break", "Continue") and not in_loop:
        raise Invalid

    if inside("TryStar.handlers") and child_name in ("Break", "Continue", "Return"):
        # SyntaxError: 'break', 'continue' and 'return' cannot appear in an except* block
        raise Invalid

    if inside(("MatchValue",)) and child_name not in (
        "Attribute",
        "Name",
        "Constant",
        "UnaryOp",
        "USub",
    ):
        raise Invalid

    if (
        inside("MatchValue.value")
        and inside("Attribute.value")
        and child_name not in ("Attribute", "Name")
    ):
        raise Invalid

    if (
        inside(("MatchValue",))
        and inside(("UnaryOp",))
        and child_name in ("Name", "UnaryOp", "Attribute")
    ):
        raise Invalid

    if parents[-1] == ("MatchValue", "value") and child_name == "Name":
        raise Invalid

    if inside("MatchClass.cls"):
        if child_name not in ("Name", "Attribute"):
            raise Invalid

    if parents[-1] == ("comprehension", "iter") and child_name == "NamedExpr":
        raise Invalid

    if inside(("GeneratorExp", "ListComp", "SetComp", "DictComp")) and child_name in (
        "Yield",
        "YieldFrom",
    ):
        # SyntaxError: 'yield' inside list comprehension
        raise Invalid

    if (
        inside(("GeneratorExp", "ListComp", "SetComp", "DictComp"))
        # TODO restrict to comprehension inside ClassDef
        and inside(
            "ClassDef.body",
            ("FunctionDef.body", "AsyncFunctionDef.body", "Lambda.body"),
        )
        and child_name == "NamedExpr"
    ):
        # SyntaxError: assignment expression within a comprehension cannot be used in a class body
        raise Invalid

    if not py39plus and any(p[1] == "decorator_list" for p in parents):
        # restricted decorators
        # see https://peps.python.org/pep-0614/
        # path from the decorator expression (outermost first) down to this position
        deco_parents = list(
            itertools.takewhile(lambda a: a[1] != "decorator_list", reversed(parents))
        )[::-1]

        def valid_deco_parents(parents):
            # Call?,Attribute*  -- an optional leading Call.func, then Attribute.value
            parents = list(parents)
            if parents and parents[0] == ("Call", "func"):
                # drop the leading Call.func (the first entry, not the last one)
                parents.pop(0)
            return all(p == ("Attribute", "value") for p in parents)

        if valid_deco_parents(deco_parents) and child_name != "Name":
            raise Invalid

    # type alias
    if py312plus:
        if parents[-1] == ("TypeAlias", "name") and child_name != "Name":
            raise Invalid

        if (
            child_name == "Lambda"
            and inside("TypeAlias.value")
            and inside("ClassDef.body")
            and sys.version_info < (3, 13)
        ):
            # SyntaxError('Cannot use lambda in annotation scope within class scope')
            raise Invalid

        if child_name in (
            "Yield",
            "YieldFrom",
            "Await",
        ) and inside(
            (
                "ClassDef.bases",
                "ClassDef.keywords",
                "FunctionDef.returns",
                "AsyncFunctionDef.returns",
                "arg.annotation",
                "TypeAlias.value",
                "TypeVar.bound",
            )
        ):
            # todo this should only be invalid in type scopes (when the class/def has type parameters)
            # and only for async comprehensions
            raise Invalid

        if child_name in ("NamedExpr",) and inside(
            ("TypeAlias.value", "TypeVar.bound")
        ):
            # todo this should only be invalid in type scopes (when the class/def has type parameters)
            # and only for async comprehensions
            raise Invalid

    if child_name == "Await" and inside("AnnAssign.annotation"):
        raise Invalid

    if child_name == "Expr":
        return 30

    # the ast class is spelled "Nonlocal"; "NonLocal" never matched
    if child_name == "Nonlocal" and parents[-1] == ("Module", "body"):
        raise Invalid

    return 1
def fix(node, parents):
    """Repair ``node`` so that the generated tree becomes valid Python.

    Args:
        node: the ast node to repair.  Most repairs mutate it in place.
        parents: path from the root to ``node`` as a list of
            ``(parent_type_name, field_name)`` tuples, innermost last.

    Returns:
        The repaired node.  Some repairs return a different node
        (e.g. a ``MatchSingleton`` instead of a ``MatchValue``), so callers
        must use the return value.

    Most repairs are guarded by ``use()`` so tests that mock it can switch
    them off.
    """
    if isinstance(node, ast.ImportFrom):
        if use() and not py310plus and node.level is None:
            node.level = 0

        if use() and node.module is None and (node.level is None or node.level == 0):
            # `from import x` is invalid: a module-less import must be relative
            node.level = 1

    if isinstance(node, ast.ExceptHandler):
        if use() and node.type is None:
            # a bare `except:` can not bind a name
            node.name = None

    if (
        sys.version_info < (3, 11)
        and isinstance(node, ast.Tuple)
        and parents[-1] == ("Subscript", "slice")
    ):
        # a[(a:b,*c)] <- not valid
        # TODO check this
        found = False
        new_elts = []
        # allow only the first Slice or Starred
        for e in node.elts:
            if isinstance(e, (ast.Starred, ast.Slice)):
                if not found:
                    new_elts.append(e)
                    found = True
            else:
                new_elts.append(e)
        node.elts = new_elts

    if (
        use()
        and isinstance(node, ast.AnnAssign)
        and not isinstance(node.target, ast.Name)
    ):
        node.simple = 0

    if isinstance(node, ast.Constant):
        # Constant.kind can be "u" for unicode strings, otherwise it is None
        allowed_kind: list[str | None] = [None]
        if isinstance(node.value, str):
            allowed_kind.append("u")
        # checked for every constant (it used to be skipped for str values)
        if node.kind not in allowed_kind:
            # hash() of a str is randomized per process (PYTHONHASHSEED); use a
            # stable mapping so the same input is always repaired the same way
            node.kind = allowed_kind[len(str(node.kind)) % len(allowed_kind)]

        if (
            use()
            and parents
            and parents[-1] == ("JoinedStr", "values")
            and not isinstance(node.value, str)
        ):
            # TODO: better format string generation
            node.value = str(node.value)

    if isinstance(node, ast.FormattedValue):
        valid_conversion = (-1, 115, 114, 97)
        if use() and not py310plus and node.conversion is None:
            node.conversion = 5
        if use() and node.conversion not in valid_conversion:
            node.conversion = valid_conversion[node.conversion % 4]

    # the innermost parent that is not a Tuple/List/Starred unpacking wrapper
    assign_context = [p for p in parents if p[0] not in ("Tuple", "List", "Starred")]

    if hasattr(node, "ctx"):
        if use() and assign_context and assign_context[-1] == ("Delete", "targets"):
            node.ctx = ast.Del()
        elif (
            use()
            and assign_context
            and assign_context[-1]
            in (
                ("Assign", "targets"),
                ("AnnAssign", "target"),
                ("AugAssign", "target"),
                ("NamedExpr", "target"),
                ("TypeAlias", "name"),
                ("For", "target"),
                ("AsyncFor", "target"),
                ("withitem", "optional_vars"),
                ("comprehension", "target"),
            )
        ):
            node.ctx = ast.Store()
        else:
            node.ctx = ast.Load()

    if (
        use()
        and isinstance(node, (ast.List, ast.Tuple))
        and isinstance(node.ctx, ast.Store)
    ):
        # only one starred target is allowed in an unpacking assignment
        only_firstone(node.elts, lambda e: isinstance(e, ast.Starred))

    if use() and isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.Lambda)):
        # unique argument names
        seen = set()
        for args in all_args(node.args):
            for i, arg in reversed(list(enumerate(args))):
                if arg.arg in seen:
                    del args[i]
                    # keep no more defaults than there are positional parameters
                    if node.args.defaults:
                        del node.args.defaults[0]
                seen.add(arg.arg)

        for arg_name in ("kwarg", "vararg"):
            arg = getattr(node.args, arg_name)
            if arg:
                if arg.arg in seen:
                    setattr(node.args, arg_name, None)
                seen.add(arg.arg)

        fn_args = node.args
        # kwonlyargs and kw_defaults has to have the same size
        min_kw_size = min(len(fn_args.kwonlyargs), len(fn_args.kw_defaults))
        fn_args.kwonlyargs = fn_args.kwonlyargs[:min_kw_size]
        fn_args.kw_defaults = fn_args.kw_defaults[:min_kw_size]

    if use() and isinstance(node, ast.AsyncFunctionDef):
        # an async generator may only use a bare `return`
        if any(
            isinstance(n, (ast.Yield, ast.YieldFrom))
            for n in walk_function_nodes(node.body)
        ):
            for n in walk_function_nodes(node.body):
                if isinstance(n, ast.Return):
                    n.value = None

    if use() and isinstance(node, (ast.ClassDef, ast.Call)):
        # unique keyword argument names (`**x` keywords have no arg)
        seen = set()
        for i, kw in reversed(list(enumerate(node.keywords))):
            if kw.arg:
                if kw.arg in seen:
                    del node.keywords[i]
                seen.add(kw.arg)

    if use() and isinstance(node, ast.Try):
        # only the last handler may be a bare `except:`
        node.handlers[:-1] = [
            handler for handler in node.handlers[:-1] if handler.type is not None
        ]
        if use() and not node.handlers:
            node.orelse = []

    if use() and isinstance(
        node, (ast.GeneratorExp, ast.ListComp, ast.DictComp, ast.SetComp)
    ):
        # SyntaxError: assignment expression cannot rebind comprehension iteration variable 'name_3'
        bound_names = {
            n.id
            for c in node.generators
            for n in ast.walk(c.target)
            if isinstance(n, ast.Name)
        } | {
            n.id
            for c in node.generators
            for n in ast.walk(c.iter)
            if isinstance(n, ast.Name)
        }

        class Transformer(ast.NodeTransformer):
            def visit_NamedExpr(self, node: ast.NamedExpr):
                # replace `(name := value)` by `value` if name is an iteration variable
                if use() and node.target.id in bound_names:
                    return self.visit(node.value)
                return self.generic_visit(node)

        node = Transformer().visit(node)

    # pattern matching
    if sys.version_info >= (3, 10):

        def match_wildcard(node):
            # True for patterns treated as irrefutable (`_`, a capture, or an
            # or-pattern containing one); None for other pattern types
            if isinstance(node, ast.MatchAs):
                return (
                    node.pattern is None
                    or match_wildcard(node.pattern)
                    or node.name is None
                )
            if isinstance(node, ast.MatchOr):
                return any(match_wildcard(p) for p in node.patterns)

        if isinstance(node, ast.Match):
            # keep only the last unguarded irrefutable case and move it to the end
            found = False
            new_last = None
            for i, case_ in reversed(list(enumerate(node.cases))):
                p = case_.pattern
                if match_wildcard(p) and case_.guard is None:
                    if not found:
                        new_last = node.cases[i]
                        found = True
                    del node.cases[i]
            if new_last:
                node.cases.append(new_last)

        if (
            isinstance(node, ast.MatchValue)
            and isinstance(node.value, ast.UnaryOp)
            and isinstance(node.value.operand, ast.Constant)
            and type(node.value.operand.value) not in (int, float)
        ):
            node.value = node.value.operand

        if (
            isinstance(node, ast.MatchValue)
            and isinstance(node.value, ast.Constant)
            and any(node.value.value is v for v in (None, True, False))
        ):
            return ast.MatchSingleton(value=node.value.value)

        if isinstance(node, ast.MatchSingleton) and not any(
            node.value is v for v in (None, True, False)
        ):
            return ast.MatchValue(value=ast.Constant(value=node.value))

        def names(node):
            # names bound by a pattern; a named MatchAs/MatchStar/MatchMapping
            # reports only its own name, a MatchOr only the names of all alternatives
            if isinstance(node, ast.MatchAs) and node.name:
                yield node.name
            elif isinstance(node, ast.MatchStar) and node.name:
                yield node.name
            elif isinstance(node, ast.MatchMapping) and node.rest:
                yield node.rest
            elif isinstance(node, ast.MatchOr):
                yield from set.intersection(
                    *[set(names(pattern)) for pattern in node.patterns]
                )
            else:
                for child in ast.iter_child_nodes(node):
                    yield from names(child)

        class RemoveName(ast.NodeVisitor):
            # clears names of MatchAs/MatchMapping for which condition(name) is true;
            # does not descend into the children of the visited MatchAs/MatchMapping
            def __init__(self, condition):
                self.condition = condition

            def visit_MatchAs(self, node):
                if self.condition(node.name):
                    node.name = None

            def visit_MatchMapping(self, node):
                if self.condition(node.rest):
                    node.rest = None

        class RemoveNameCleanup(ast.NodeTransformer):
            # `pattern as <no name>` is invalid; keep only the inner pattern
            def visit_MatchAs(self, node):
                if node.name is None and node.pattern is not None:
                    return self.visit(node.pattern)
                return self.generic_visit(node)

        class FixPatternNames(ast.NodeTransformer):
            def __init__(self, used=None, allowed=None):
                # variables which are already used
                self.used = set() if used is None else used
                # variables which are allowed in a MatchOr
                self.allowed = allowed

            def is_allowed(self, name):
                return (
                    name is None
                    or name not in self.used
                    and (name in self.allowed if self.allowed is not None else True)
                )

            def visit_MatchAs(self, node):
                if not self.is_allowed(node.name):
                    return ast.MatchSingleton(value=None)
                elif node.name is not None:
                    self.used.add(node.name)
                return self.generic_visit(node)

            def visit_MatchStar(self, node):
                if not self.is_allowed(node.name):
                    return ast.MatchSingleton(value=None)
                elif node.name is not None:
                    self.used.add(node.name)
                return self.generic_visit(node)

            def visit_MatchMapping(self, node):
                if not self.is_allowed(node.rest):
                    return ast.MatchSingleton(value=None)
                elif node.rest is not None:
                    self.used.add(node.rest)
                return self.generic_visit(node)

            def visit_MatchOr(self, node: ast.MatchOr):
                # every alternative must bind the same names
                allowed = set.intersection(
                    *[set(names(pattern)) for pattern in node.patterns]
                )
                allowed -= self.used

                node.patterns = [
                    FixPatternNames(set(self.used), allowed).visit(child)
                    for child in node.patterns
                ]

                self.used |= allowed

                return node

        if isinstance(node, ast.match_case):
            node.pattern = FixPatternNames().visit(node.pattern)

        if isinstance(node, ast.MatchMapping):

            def can_literal_eval(node):
                # keys must be hashable literals; unhashable values raise TypeError
                try:
                    hash(ast.literal_eval(node))
                except (ValueError, TypeError):
                    return False
                return True

            node.keys = [k for k in node.keys if can_literal_eval(k)]

            node.keys = unique_by(node.keys, ast.literal_eval)
            del node.patterns[len(node.keys) :]

            seen = set()
            for pattern in node.patterns:
                RemoveName(lambda name: name in seen).visit(pattern)
                seen |= {*names(pattern)}

        if isinstance(node, ast.MatchOr):
            var_names = set.intersection(
                *[set(names(pattern)) for pattern in node.patterns]
            )

            RemoveName(lambda name: name not in var_names).visit(node)

            # a wildcard makes the remaining alternatives unreachable
            for i, pattern in enumerate(node.patterns):
                if match_wildcard(pattern):
                    node.patterns = node.patterns[: i + 1]
                    break

            if len(node.patterns) == 1:
                return node.patterns[0]

        if isinstance(node, ast.Match):
            for i, case_ in enumerate(node.cases):
                # default match `case _:` (or an or-pattern ending in `_`).
                # The guard check applies to both shapes: `case _ if cond:`
                # is refutable and may be followed by further cases.
                pattern = case_.pattern
                is_default = (
                    isinstance(pattern, ast.MatchAs) and pattern.name is None
                ) or (
                    isinstance(pattern, ast.MatchOr)
                    and isinstance(pattern.patterns[-1], ast.MatchAs)
                    and pattern.patterns[-1].name is None
                )
                if is_default and case_.guard is None:
                    node.cases = node.cases[: i + 1]
                    break

        if isinstance(node, ast.MatchSequence):
            only_firstone(node.patterns, lambda e: isinstance(e, ast.MatchStar))

            seen = set()
            for pattern in node.patterns:
                RemoveName(lambda name: name in seen).visit(pattern)
                seen |= {*names(pattern)}

        if isinstance(node, ast.MatchClass):
            node.kwd_attrs = unique_by(node.kwd_attrs, lambda e: e)
            del node.kwd_patterns[len(node.kwd_attrs) :]

            seen = set()
            for pattern in [*node.patterns, *node.kwd_patterns]:
                RemoveName(lambda name: name in seen).visit(pattern)
                seen |= {*names(pattern)}

        if isinstance(node, ast.Match):
            node = RemoveNameCleanup().visit(node)

    # async nodes

    in_async_code = False
    for parent, attr in reversed(parents):
        if parent == "AsyncFunctionDef" and attr == "body":
            in_async_code = True
            break
        if parent in ("FunctionDef", "Lambda", "ClassDef", "TypeAlias"):
            break

        if not py311plus and parent in (
            "ListComp",
            "DictComp",
            "SetComp",
            "GeneratorExp",
        ):
            break

    if isinstance(node, (ast.ListComp, ast.SetComp, ast.DictComp)):
        if use() and not in_async_code:
            for comp in node.generators:
                comp.is_async = 0

    if isinstance(node, ast.Raise):
        # `raise from x` requires an exception
        if use() and not node.exc:
            node.cause = None

    if use() and isinstance(node, ast.Lambda):
        # no annotation for lambda arguments
        for args in all_args(node.args):
            for arg in args:
                arg.annotation = None

        if use() and node.args.vararg:
            node.args.vararg.annotation = None

        if use() and node.args.kwarg:
            node.args.kwarg.annotation = None

    if sys.version_info >= (3, 12):
        # type scopes
        if use() and hasattr(node, "type_params"):
            node.type_params = unique_by(node.type_params, lambda p: p.name)

        def cleanup_annotation(annotation):
            # remove `:=`, `yield` and `yield from`, which are not allowed in
            # annotation scopes, keeping their value expressions
            class Transformer(ast.NodeTransformer):
                def visit_NamedExpr(self, node: ast.NamedExpr):
                    if not use():
                        return self.generic_visit(node)
                    return self.visit(node.value)

                def visit_Yield(self, node: ast.Yield) -> Any:
                    if not use():
                        return self.generic_visit(node)
                    if node.value is None:
                        return ast.Constant(value=None)
                    return self.visit(node.value)

                def visit_YieldFrom(self, node: ast.YieldFrom) -> Any:
                    if not use():
                        return self.generic_visit(node)
                    return self.visit(node.value)

            return Transformer().visit(annotation)

        if (
            isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.type_params
        ):
            for arg in [
                *node.args.posonlyargs,
                *node.args.args,
                *node.args.kwonlyargs,
                node.args.vararg,
                node.args.kwarg,
            ]:
                if use() and arg is not None and arg.annotation:
                    arg.annotation = cleanup_annotation(arg.annotation)

            if use() and node.returns is not None:
                node.returns = cleanup_annotation(node.returns)

        if isinstance(node, ast.ClassDef) and node.type_params:
            node.bases = [cleanup_annotation(b) for b in node.bases]
            for kw in node.keywords:
                if use():
                    kw.value = cleanup_annotation(kw.value)

        for n in ast.walk(node):
            if use() and isinstance(n, ast.TypeAlias):
                n.value = cleanup_annotation(n.value)

        if isinstance(node, ast.ClassDef):
            for n in ast.walk(node):
                if use() and isinstance(n, ast.TypeVar) and n.bound is not None:
                    n.bound = cleanup_annotation(n.bound)

        if use() and isinstance(node, ast.AnnAssign):
            node.annotation = cleanup_annotation(node.annotation)

    if sys.version_info >= (3, 13):
        if hasattr(node, "type_params"):
            # non-default type parameter 'name_1' follows default type parameter
            # Scan from the end: once a parameter without a default was seen,
            # every earlier parameter must not have a default either.
            seen_non_default = False
            for child in reversed(node.type_params):
                if child.default_value is None:
                    seen_non_default = True
                elif use() and seen_non_default:
                    child.default_value = None

    return node
def fix_result(node):
    """Apply the whole-tree repairs to a finished tree and return the result.

    Delegates to ``fix_nonlocal`` (defined elsewhere in this module).
    """
    repaired = fix_nonlocal(node)
    return repaired
993def is_valid_ast(tree) -> bool:
994 def is_valid(node: ast.AST, parents):
995 type_name = node.__class__.__name__
996 if (
997 isinstance(node, (ast.AST))
998 and parents
999 and propability(
1000 parents,
1001 type_name,
1002 )
1003 == 0
1004 ):
1005 print("invalid node with:")
1006 print("parents:", parents)
1007 print("node:", node)
1009 try:
1010 propability_try(
1011 parents,
1012 node.__class__.__name__,
1013 )
1014 except Invalid:
1015 frame = traceback.extract_tb(sys.exc_info()[2])[1]
1016 print("file:", f"{frame.filename}:{frame.lineno}")
1018 return False
1020 if type_name in same_length:
1021 attrs = same_length[type_name]
1022 if len({len(v) for k, v in ast.iter_fields(node) if k in attrs}) != 1:
1023 return False
1025 if isinstance(node, (ast.AST)):
1026 info = get_info(type_name)
1027 assert isinstance(info, NodeType)
1029 for attr_name, value in ast.iter_fields(node):
1030 attr_info = info.fields[attr_name]
1031 if attr_info[1] == "":
1032 value_info = get_info(attr_info[0])
1033 if isinstance(value_info, UnionNodeType):
1034 if type(value).__name__ not in value_info.options:
1035 print(
1036 f"{type(node).__name__}.{attr_name} {value} is not one type of {value_info.options}"
1037 )
1038 print("parents are:", parents)
1039 return False
1041 if isinstance(value, list) and len(value) < min_attr_length(
1042 type_name, attr_name
1043 ):
1044 print("invalid arg length", type_name, attr_name)
1045 return False
1047 if isinstance(value, list) != (info.fields[attr_name][1] == "*"): 1047 ↛ 1048line 1047 didn't jump to line 1048, because the condition on line 1047 was never true
1048 print("no list", value)
1049 return False
1050 if value is None:
1051 if not ( 1051 ↛ 1058line 1051 didn't jump to line 1058, because the condition on line 1051 was never true
1052 (
1053 info.fields[attr_name][1] == "?"
1054 and none_allowed(parents + [(type_name, attr_name)])
1055 )
1056 or info.fields[attr_name][0] == "constant"
1057 ):
1058 print("none not allowed", parents, type_name, attr_name)
1059 return False
1061 for field in node._fields:
1062 value = getattr(node, field)
1063 if isinstance(value, list):
1064 if not all(
1065 is_valid(e, parents + [(type_name, field)]) for e in value
1066 ):
1067 return False
1068 else:
1069 if not is_valid(value, parents + [(type_name, field)]):
1070 return False
1071 return True
1073 if not is_valid(tree, []):
1074 return False
1076 tree_copy = deepcopy(tree)
1078 def fix_tree(node: ast.AST, parents):
1079 for field in node._fields:
1080 value = getattr(node, field)
1081 if isinstance(value, ast.AST):
1082 setattr(
1083 node,
1084 field,
1085 fix_tree(value, parents + [(node.__class__.__name__, field)]),
1086 )
1087 if isinstance(value, list):
1088 setattr(
1089 node,
1090 field,
1091 [
1092 (
1093 fix_tree(v, parents + [(node.__class__.__name__, field)])
1094 if isinstance(v, ast.AST)
1095 else v
1096 )
1097 for v in value
1098 ],
1099 )
1101 return fix(node, parents)
1103 tree_copy = fix_tree(tree_copy, [])
1104 tree_copy = fix_result(tree_copy)
1106 result = equal_ast(tree_copy, tree, dump_info=True)
1108 if 1:
1109 if sys.version_info >= (3, 9) and not result:
1110 dump_copy = ast_dump(tree_copy).splitlines()
1111 dump = ast_dump(tree).splitlines()
1112 import difflib
1114 print("ast was changed by during fixing:")
1116 print("\n".join(difflib.unified_diff(dump, dump_copy, "original", "fixed")))
1118 return result
1121def arguments(
1122 node: ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda,
1123) -> list[ast.arg]:
1124 args = node.args
1125 l = [
1126 *args.args,
1127 args.vararg,
1128 *args.kwonlyargs,
1129 args.kwarg,
1130 ]
1132 l += args.posonlyargs
1134 return [arg for arg in l if arg is not None]
1137def fix_nonlocal(node):
1138 class NonLocalFixer(ast.NodeTransformer):
1139 """
1140 removes invalid Nonlocals from the class/function
1141 """
1143 def __init__(self, locals, nonlocals, globals, type_params, parent_globals):
1144 self.locals = set(locals)
1145 self.used_names = set(locals)
1146 self.type_params = set(type_params)
1148 # nonlocals from the parent scope
1149 self.nonlocals = set(nonlocals)
1150 self.used_nonlocals = set()
1152 # globals from the global scope
1153 self.globals = set(globals)
1154 self.used_globals = set()
1155 self.parent_globals = parent_globals
1157 def name_assigned(self, name):
1158 self.locals.add(name)
1159 self.used_names.add(name)
1161 def visit_Name(self, node: ast.Name) -> Any:
1162 if isinstance(node.ctx, (ast.Store, ast.Del)):
1163 self.name_assigned(node.id)
1164 else:
1165 self.used_names.add(node.id)
1166 return node
1168 if sys.version_info >= (3, 10):
1170 def visit_MatchAs(self, node: ast.MatchAs) -> Any:
1171 if node.pattern:
1172 self.visit(node.pattern)
1173 self.name_assigned(node.name)
1174 return node
1176 def search_walrus(self, node):
1177 for n in ast.walk(node):
1178 if isinstance(n, ast.NamedExpr):
1179 self.visit(n.target)
1181 def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
1182 self.visit(node.generators[0].iter)
1183 self.search_walrus(node)
1184 return node
1186 def visit_ListComp(self, node: ast.ListComp) -> Any:
1187 self.visit(node.generators[0].iter)
1188 self.search_walrus(node)
1189 return node
1191 def visit_DictComp(self, node: ast.DictComp) -> Any:
1192 self.visit(node.generators[0].iter)
1193 self.search_walrus(node)
1194 return node
1196 def visit_SetComp(self, node: ast.SetComp) -> Any:
1197 self.visit(node.generators[0].iter)
1198 self.search_walrus(node)
1199 return node
1201 def visit_Nonlocal(self, node: ast.Nonlocal) -> Any:
1202 # TODO: research __class__ seems to be defined in the class scope
1203 # but it is also not
1204 # class A:
1205 # print(locals()) # no __class__
1206 # def f():
1207 # nonlocal __class__ # is A
1208 node.names = [
1209 name
1210 for name in node.names
1211 if name not in self.locals
1212 and name in self.nonlocals
1213 and name not in self.used_names
1214 and name not in self.type_params
1215 and name not in self.parent_globals
1216 and name not in self.used_globals
1217 or name in ("__class__",)
1218 ]
1219 self.used_nonlocals |= set(node.names)
1221 if not node.names:
1222 return ast.Pass()
1224 return node
1226 def visit_Global(self, node: ast.Global) -> Any:
1227 node.names = [
1228 name
1229 for name in node.names
1230 if name not in self.locals
1231 and name not in self.used_names
1232 and name not in self.used_nonlocals
1233 ]
1234 self.used_globals |= set(node.names)
1236 if not node.names:
1237 return ast.Pass()
1239 return node
1241 def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
1242 if isinstance(node.target, ast.Name) and (
1243 node.target.id in self.used_globals
1244 or node.target.id in self.used_nonlocals
1245 ):
1246 if node.value:
1247 return self.generic_visit(
1248 ast.Assign(
1249 targets=[node.target], value=node.value, type_comment=None
1250 )
1251 )
1252 else:
1253 return ast.Pass()
1254 return self.generic_visit(node)
1256 def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
1257 self.name_assigned(node.name)
1259 all_nodes = [
1260 *node.args.defaults,
1261 *node.args.kw_defaults,
1262 *node.decorator_list,
1263 node.returns,
1264 ]
1266 all_nodes += [arg.annotation for arg in arguments(node)]
1268 for default in all_nodes:
1269 if default is not None:
1270 self.visit(default)
1272 return node
1274 def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
1275 self.name_assigned(node.name)
1277 all_nodes = [
1278 *node.args.defaults,
1279 *node.args.kw_defaults,
1280 *node.decorator_list,
1281 node.returns,
1282 ]
1284 all_nodes += [arg.annotation for arg in arguments(node)]
1286 for default in all_nodes:
1287 if default is not None:
1288 self.visit(default)
1289 return node
1291 def visit_ClassDef(self, node: ast.ClassDef) -> Any:
1292 for expr in [
1293 *[k.value for k in node.keywords],
1294 *node.bases,
1295 *node.decorator_list,
1296 ]:
1297 if expr is not None: 1297 ↛ 1292line 1297 didn't jump to line 1292
1298 self.visit(expr)
1300 self.name_assigned(node.name)
1302 return node
1304 # pattern matching
1305 if sys.version_info >= (3, 10):
1307 def visit_MatchMapping(self, node: ast.MatchMapping) -> Any:
1308 if node.rest is not None:
1309 self.name_assigned(node.rest)
1310 return self.generic_visit(node)
1312 if sys.version_info >= (3, 13):
1314 def visit_MatchStar(self, node: ast.MatchStar) -> Any:
1315 self.name_assigned(node.name)
1316 return self.generic_visit(node)
1318 def visit_ExceptHandler(self, handler):
1319 if handler.name: 1319 ↛ 1320line 1319 didn't jump to line 1320, because the condition on line 1319 was never true
1320 self.name_assigned(handler.name)
1321 return self.generic_visit(handler)
1323 def visit_Lambda(self, node: ast.Lambda) -> Any:
1324 for default in [*node.args.defaults, *node.args.kw_defaults]:
1325 if default is not None: 1325 ↛ 1324line 1325 didn't jump to line 1324, because the condition on line 1325 was always true
1326 self.visit(default)
1327 return node
1329 if sys.version_info < (3, 13):
1331 def visit_Try(self, node: ast.Try) -> Any:
1332 # work around for https://github.com/python/cpython/issues/111123
1333 args = {}
1334 for k in ("body", "orelse", "handlers", "finalbody"):
1335 args[k] = [self.visit(x) for x in getattr(node, k)]
1337 return ast.Try(**args)
1339 if sys.version_info >= (3, 11):
1341 def visit_TryStar(self, node: ast.TryStar) -> Any:
1342 # work around for https://github.com/python/cpython/issues/111123
1343 args = {}
1344 for k in ("body", "orelse", "handlers", "finalbody"):
1345 args[k] = [self.visit(x) for x in getattr(node, k)]
1347 return ast.TryStar(**args)
1349 class FunctionTransformer(ast.NodeTransformer):
1350 """
1351 - transformes a class/function
1352 """
1354 def __init__(self, nonlocals, globals, type_params, parent_globals):
1355 self.nonlocals = set(nonlocals)
1356 self.globals = set(globals)
1357 self.type_params = type_params
1358 self.parent_globals = parent_globals
1360 def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
1361 return self.handle_function(node)
1363 def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
1364 return self.handle_function(node)
1366 def visit_Lambda(self, node: ast.Lambda) -> Any:
1367 # there are no globals/nonlocals/functiondefs in lambdas
1368 return node
1370 def visit_ClassDef(self, node: ast.ClassDef) -> Any:
1371 type_params = set(self.type_params)
1372 if sys.version_info >= (3, 12):
1373 type_params |= {typ.name for typ in node.type_params} # type: ignore 1373 ↛ exitline 1373 didn't run the set comprehension on line 1373
1375 fixer = NonLocalFixer(
1376 [], self.nonlocals, self.globals, type_params, self.parent_globals
1377 )
1378 node.body = [fixer.visit(stmt) for stmt in node.body]
1380 ft = FunctionTransformer(
1381 self.nonlocals, self.globals, type_params, self.parent_globals
1382 )
1383 node.body = [ft.visit(stmt) for stmt in node.body]
1385 return node
1387 def handle_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> Any:
1388 names = {arg.arg for arg in arguments(node)}
1390 type_params = set(self.type_params)
1391 if sys.version_info >= (3, 12):
1392 type_params |= {typ.name for typ in node.type_params} # type: ignore 1392 ↛ exitline 1392 didn't run the set comprehension on line 1392
1394 fixer = NonLocalFixer(
1395 names, self.nonlocals, self.globals, type_params, self.parent_globals
1396 )
1397 node.body = [fixer.visit(stmt) for stmt in node.body]
1399 ft = FunctionTransformer(
1400 fixer.locals | self.nonlocals,
1401 self.globals,
1402 type_params,
1403 fixer.used_globals,
1404 )
1405 node.body = [ft.visit(stmt) for stmt in node.body]
1407 return node
1409 fixer = NonLocalFixer([], [], [], [], [])
1410 node = fixer.visit(node)
1412 node = FunctionTransformer([], [], [], []).visit(node)
1413 return node
1416def min_attr_length(node_type, attr_name):
1417 attr = f"{node_type}.{attr_name}"
1418 if node_type == "Module" and attr_name == "body":
1419 return 0
1420 if attr_name == "body":
1421 return 1
1422 if node_type == "MatchOr" and attr_name == "patterns":
1423 return 2
1424 if node_type == "BoolOp" and attr_name == "values":
1425 return 2
1426 if node_type == "BinOp" and attr_name == "values": 1426 ↛ 1427line 1426 didn't jump to line 1427, because the condition on line 1426 was never true
1427 return 1
1428 if node_type == "Import" and attr_name == "names":
1429 return 1
1430 if node_type == "ImportFrom" and attr_name == "names":
1431 return 1
1432 if node_type in ("With", "AsyncWith") and attr_name == "items":
1433 return 1
1434 if node_type in ("Try", "TryStar") and attr_name == "handlers":
1435 return 1
1436 if node_type == "Delete" and attr_name == "targets":
1437 return 1
1438 if node_type == "Match" and attr_name == "cases":
1439 return 1
1440 if node_type == "ExtSlice" and attr_name == "dims":
1441 return 1
1442 if sys.version_info < (3, 9) and node_type == "Set" and attr_name == "elts":
1443 return 1
1444 if node_type == "Compare" and attr_name in ("ops", "comparators"):
1445 return 1
1446 if attr_name == "generators":
1447 return 1
1449 if attr == "Assign.targets":
1450 return 1
1452 return 0
1455def none_allowed(parents):
1456 if parents[-2:] == [("TryStar", "handlers"), ("ExceptHandler", "type")]: 1456 ↛ 1457line 1456 didn't jump to line 1457, because the condition on line 1456 was never true
1457 return False
1458 return True
1461same_length = {
1462 "MatchClass": ["kwd_attrs", "kwd_patterns"],
1463 "MatchMapping": ["patterns", "keys"],
1464 "arguments": ["kw_defaults", "kwonlyargs"],
1465 "Compare": ["ops", "comparators"],
1466 "Dict": ["keys", "values"],
1467}
1470class AstGenerator:
1471 def __init__(self, seed, node_limit, depth_limit):
1472 self.rand = random.Random(seed)
1473 self.nodes = 0
1474 self.node_limit = node_limit
1475 self.depth_limit = depth_limit
1477 def cnd(self):
1478 return self.rand.choice([True, False])
1480 def generate(self, name: str, parents=(), depth=0):
1481 result = self.generate_impl(name, parents, depth)
1482 result = fix_result(result)
1483 return result
1485 def generate_impl(self, name: str, parents=(), depth=0):
1486 depth += 1
1487 self.nodes += 1
1489 if depth > 100:
1490 exit()
1492 stop = depth > self.depth_limit or self.nodes > self.node_limit
1494 info = get_info(name)
1496 if isinstance(info, NodeType):
1497 ranges = {}
1499 def attr_length(child, attr_name):
1500 if name == "Module":
1501 return 20
1503 if name in same_length:
1504 attrs = same_length[name]
1505 if attr_name in attrs[1:]:
1506 return attr_length(child, attrs[0])
1508 if child == "arguments" and attr_name == "defaults":
1509 min = 0
1510 max = attr_length(child, "posonlyargs") + attr_length(child, "args")
1511 ranges[attr_name] = self.rand.randint(min, max)
1513 elif attr_name not in ranges:
1514 min = min_attr_length(child, attr_name)
1516 max = min if stop else min + 1 if depth > 10 else min + 5
1517 ranges[attr_name] = self.rand.randint(min, max)
1519 return ranges[attr_name]
1521 def child_node(n, t, q, parents):
1522 if q == "":
1523 return self.generate_impl(t, parents, depth)
1524 elif q == "*":
1525 return [
1526 self.generate_impl(t, parents, depth)
1527 for _ in range(attr_length(parents[-1][0], n))
1528 ]
1529 elif q == "?":
1530 return (
1531 self.generate_impl(t, parents, depth)
1532 if not none_allowed(parents) or self.cnd()
1533 else None
1534 )
1535 else:
1536 assert False
1538 attributes = {
1539 n: child_node(n, t, q, [*parents, (name, n)])
1540 for n, (t, q) in info.fields.items()
1541 }
1543 result = info.ast_type(**attributes)
1544 result = fix(result, parents)
1545 return result
1547 if isinstance(info, UnionNodeType):
1548 options_list = [
1549 (option, propability(parents, option)) for option in info.options
1550 ]
1552 invalid_option = [
1553 option for (option, prop) in options_list if prop == 0 and not use()
1554 ]
1556 assert len(invalid_option) in (0, 1), invalid_option
1558 if len(invalid_option) == 1:
1559 return self.generate_impl(invalid_option[0])
1561 options = dict(options_list)
1562 if stop:
1563 for final in ("Name", "MatchValue", "Pass"):
1564 if options.get(final, 0) != 0:
1565 options = {final: 1}
1566 break
1568 if sum(options.values()) == 0:
1569 # TODO: better handling of `type?`
1570 return None
1572 return self.generate_impl(
1573 self.rand.choices(*zip(*options.items()))[0], parents, depth
1574 )
1575 if isinstance(info, BuiltinNodeType):
1576 if info.kind == "identifier":
1577 return f"name_{self.rand.randint(0,5)}"
1578 elif info.kind == "int":
1579 return self.rand.randint(0, 5)
1580 elif info.kind == "string":
1581 return self.rand.choice(["some text", ""])
1582 elif info.kind == "constant":
1583 return self.rand.choice(
1584 [
1585 None,
1586 b"some bytes",
1587 "some const text",
1588 b"",
1589 "",
1590 "'\"'''\"\"\"{}\\",
1591 b"'\"'''\"\"\"{}\\",
1592 self.rand.randint(0, 20),
1593 self.rand.uniform(0, 20),
1594 True,
1595 False,
1596 ]
1597 )
1599 else:
1600 assert False, "unknown kind: " + info.kind
1602 assert False
1605import warnings
1608def check(tree):
1609 for node in ast.walk(tree):
1610 if isinstance(node, ast.arguments):
1611 assert len(node.posonlyargs) + len(node.args) >= len(
1612 node.defaults
1613 ), ast_dump(node)
1614 assert len(node.kwonlyargs) == len(node.kw_defaults)
1617def generate_ast(
1618 seed: int,
1619 *,
1620 node_limit: int = 10000000,
1621 depth_limit: int = 8,
1622 root_node: str = "Module",
1623) -> ast.AST:
1624 generator = AstGenerator(seed, depth_limit=depth_limit, node_limit=node_limit)
1626 with warnings.catch_warnings():
1627 warnings.simplefilter("ignore", SyntaxWarning)
1628 tree = generator.generate(root_node)
1629 check(tree)
1631 ast.fix_missing_locations(tree)
1632 return tree
1635def generate(
1636 seed: int,
1637 *,
1638 node_limit: int = 10000000,
1639 depth_limit: int = 8,
1640 root_node: str = "Module",
1641) -> str:
1642 tree = generate_ast(
1643 seed, node_limit=node_limit, depth_limit=depth_limit, root_node=root_node
1644 )
1645 return unparse(tree)
1648# next algo
1650# design targets:
1651# * enumerate "all" possible ast-node combinations
1652# * check if propability 0 would produce incorrect code
1653# * the algo should be able to generate every possible syntax combination for every python version.
1654# * hypothesis integration
1655# * do not use compile() in the implementation
1656# * generation should be customizable (custom propabilities and random values)
1658# features:
1659# * node-context: function-scope async-scope type-scope class-scope ...
1660# * names: nonlocal global
1662from dataclasses import dataclass
1665@dataclass
1666class ParentRef:
1667 node: PartialNode
1668 attr_name: str
1669 index: int
1670 _context: dict
1672 def __getattr__(self, name):
1673 if name.startswith("ctx_"):
1674 return getattr(node, name)
1675 raise AttributeError
1678# (d:=[n] | q_parent("Delete.targets")) and len(d.targets)==1
1681@dataclass
1682class PartialValue:
1683 value: int | str | bool
1686@dataclass
1687class PartialNode:
1688 _node_type_name: str
1689 parent_ref: ParentRef | None
1690 _defined_attrs: dict
1691 _context: dict
1693 def inside(self, spec) -> PartialNode | None: ... 1693 ↛ exitline 1693 didn't jump to line 1693, because
1695 @property
1696 def parent(self):
1697 return self.parent_ref.node
1699 def __getattr__(self, name):
1700 if name.startswith("ctx_"):
1701 return getattr(node, name)
1703 if name not in self._defined_attrs:
1704 raise RuntimeError(f"{self._node_type_name}.{name} is not defined jet")
1706 return self._defined_attrs[name]
1709def gen(node: PartialNode):
1710 # parents [(node,attr_name)]
1711 pass