Coverage for kye/parser/kye_transformer.py: 8%
99 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-01-04 16:32 -0700
« prev ^ index » next coverage.py v7.3.2, created at 2024-01-04 16:32 -0700
1from typing import Any
2from kye.parser.kye_ast import *
3from lark import Token, Tree
# Maps expression grammar rule names to the operator symbol that is stored as
# ``op`` on the Operation node built for that rule in tokens_to_ast.
OPERATORS_MAP: dict[str, str] = {
    'or_exp': '|',
    'xor_exp': '^',
    'and_exp': '&',
    'dot_exp': '.',
    'filter_exp': '[]',
    'is_exp': 'is',
}
def tokens_to_ast(token: Union[Tree, Token], script: str):
    """Recursively convert a Lark parse result into kye AST objects.

    Args:
        token: A Lark ``Tree`` (a matched grammar rule) or ``Token`` (a
            terminal). Children of a ``Tree`` are converted first.
        script: The text that was parsed. It is sliced with the node's
            ``start_pos``/``end_pos`` to record the node's source text.

    Returns:
        Depends on the rule/terminal name:
        ``float`` for ``SIGNED_NUMBER``; the string without its first and
        last character for ``ESCAPED_STRING``; the raw token value for any
        other terminal; the list of converted children for ``index``; and
        an ``Identifier``, ``LiteralExpression``, ``Operation``,
        ``AliasDefinition``, ``EdgeDefinition``, ``ModelDefinition`` or
        ``ModuleDefinitions`` node for the corresponding rules.

    Raises:
        TypeError: If ``token`` is neither a ``Token`` nor a ``Tree``.
        ValueError: If an ``edge_def`` rule does not have 2 or 3 children.
        Exception: If a ``Tree`` has a rule name this function does not know.
    """
    if isinstance(token, Token):
        kind = token.type
        # The token itself is used as the position source below.
        meta = token
        value = token.value
        children = [value]
    elif isinstance(token, Tree):
        kind = token.data
        # NOTE(review): reading line/column/start_pos from Tree.meta presumably
        # requires the parser to be built with propagate_positions=True - confirm.
        meta = token.meta
        # Only terminals carry a value; rules are described by their children.
        value = None
        children = [tokens_to_ast(child, script) for child in token.children]
    else:
        raise TypeError(
            f'Expected a lark Token or Tree, got {type(token).__name__}'
        )

    meta = TokenPosition(
        line=meta.line,
        column=meta.column,
        end_line=meta.end_line,
        end_column=meta.end_column,
        start_pos=meta.start_pos,
        end_pos=meta.end_pos,
        text=script[meta.start_pos:meta.end_pos],
    )

    # Lark prefixes imported rules with '<module_name>__'
    # we will just make sure that we don't have any name conflicts
    # across grammar files and remove the prefixes so that we can
    # use the same transformer independently of how the grammar
    # was imported
    if '__' in kind:
        kind = kind.split('__')[-1]
        assert kind != '', 'Did not expect rule name to end with a double underscore'

    if kind == 'SIGNED_NUMBER':
        return float(value)
    if kind == 'ESCAPED_STRING':
        # Drop the first and last character (the surrounding quotes).
        return value[1:-1]
    if kind == 'identifier':
        return Identifier(name=children[0], meta=meta)
    if kind == 'literal':
        return LiteralExpression(value=children[0], meta=meta)
    if kind in ('comp_exp', 'mult_exp', 'add_exp'):
        # These rules keep the operator as their middle child.
        return Operation(op=children[1], children=[children[0], children[2]], meta=meta)
    if kind in OPERATORS_MAP:
        return Operation(op=OPERATORS_MAP[kind], children=children, meta=meta)
    if kind == 'unary_expression':
        # The node must be returned; without the return this branch fell
        # through to the 'Unknown rule' error at the end of the function.
        return Operation(op=children[0], children=[children[1]], meta=meta)

    if kind == 'alias_def':
        name, typ = children
        return AliasDefinition(name=name, type=typ, meta=meta)

    if kind == 'edge_def':
        if len(children) == 3:
            name, cardinality, typ = children
        elif len(children) == 2:
            name, typ = children
            cardinality = None
        else:
            raise ValueError('Invalid edge definition')

        return EdgeDefinition(name=name, type=typ, cardinality=cardinality, meta=meta)

    if kind == 'index':
        # Returned as a plain list; model_def recognises indexes by that type.
        return children

    if kind == 'model_def':
        indexes = []
        edges = []
        subtypes = []
        # children[0] is the model name; the rest are sorted by their type.
        for child in children[1:]:
            if isinstance(child, EdgeDefinition):
                edges.append(child)
            elif isinstance(child, TypeDefinition):
                subtypes.append(child)
            else:
                assert type(child) is list
                indexes.append(child)

        return ModelDefinition(
            name=children[0],
            indexes=indexes,
            edges=edges,
            subtypes=subtypes,
            meta=meta
        )

    if kind == 'definitions':
        return ModuleDefinitions(children=children, meta=meta)

    # Terminals without special handling are passed through as their raw value.
    if isinstance(token, Token):
        return value
    else:
        raise Exception(f'Unknown rule: {kind}')
def globalize_names(ast: AST, path=(), type_name_map=None):
    """Rewrite definition names in ``ast`` to dotted, fully qualified names.

    The tree is modified in place and nothing is returned.

    Args:
        ast: The node to process; its contained definitions are visited
            recursively.
        path: Tuple of the enclosing definition names, outermost first.
        type_name_map: Mapping from the local type names visible at this
            level to their dotted names. ``None`` (the default) means no
            names are visible yet.
    """
    # A None sentinel avoids sharing one default dict between all calls.
    if type_name_map is None:
        type_name_map = {}

    def rename_identifiers(node: Expression, names: dict[str, str]):
        # NOTE(review): nodes are built with ``op='.'`` for dot expressions;
        # this check relies on Operation exposing a ``name`` equal to 'dot' -
        # confirm against kye_ast.
        if isinstance(node, Operation) and node.name == 'dot':
            # Only rename the first child for the dot operator,
            # because the following children have a different context
            rename_identifiers(node.children[0], names)
            return
        if isinstance(node, Identifier) and node.name in names:
            node.name = names[node.name]

        for child in node.children:
            rename_identifiers(child, names)

    if isinstance(ast, Definition):
        # Extend the path with the still-local name before it is overwritten.
        path += (ast.name,)
        if isinstance(ast, TypeDefinition):
            ast.name = '.'.join(path)
        elif isinstance(ast, EdgeDefinition):
            ast._ref = '.'.join(path)
            rename_identifiers(ast, type_name_map)

    if isinstance(ast, ContainedDefinitions):
        # Copy so names declared here do not leak into sibling scopes.
        local = {**type_name_map}
        # Register every type declared at this level before descending, so
        # each child can refer to any of them regardless of declaration order.
        for child in ast.children:
            if isinstance(child, TypeDefinition):
                local[child.name] = '.'.join(path + (child.name,))
        for child in ast.children:
            globalize_names(child, path, local)
def unnest_subtypes(ast: AST):
    """Flatten nested type definitions into the module's child list.

    When ``ast`` is a ``ModuleDefinitions`` node, its ``children`` are
    replaced in place by every ``TypeDefinition`` reachable through
    ``ContainedDefinitions`` nodes, in depth-first, parent-before-child
    order. Any other node is left untouched. Nothing is returned.
    """
    if not isinstance(ast, ModuleDefinitions):
        return

    flattened = []

    def gather(node: AST):
        # Record the type itself first, then descend into what it contains.
        if isinstance(node, TypeDefinition):
            flattened.append(node)
        if isinstance(node, ContainedDefinitions):
            for nested in node.children:
                gather(nested)

    gather(ast)
    ast.children = flattened
def transform(token: Union[Tree, Token], script: str):
    """Build the AST for a parse result and run the in-place passes on it.

    Args:
        token: The Lark ``Tree`` or ``Token`` to convert.
        script: The text that was parsed.

    Returns:
        The value produced by ``tokens_to_ast`` after ``globalize_names``
        and ``unnest_subtypes`` have been applied to it, in that order.
    """
    tree = tokens_to_ast(token, script)
    # Both passes mutate the tree and return nothing.
    for in_place_pass in (globalize_names, unnest_subtypes):
        in_place_pass(tree)
    return tree