Coverage for src/lazy_imports_lite/_transformer.py: 100%
93 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-11 15:02 +0100
1import ast
2import typing
3from typing import Any
# Source text spliced into every transformed module (see
# TransformModuleImports.visit_Module: it goes after any leading string
# statements and ``from __future__`` imports).  It binds the runtime hooks
# module to ``__lazy_imports_lite__`` -- the name that the generated
# ``Import``/``ImportAs``/``ImportFrom`` calls reference -- and rebinds
# ``globals`` to the result of ``make_globals``.
header = """
import lazy_imports_lite._hooks as __lazy_imports_lite__
globals=__lazy_imports_lite__.make_globals(lambda g=globals:g())
"""

# Statement nodes of ``header``, parsed once at import time so that each
# transformation only has to splice them into the target module's body.
header_ast = ast.parse(header, mode="exec").body
12class TransformModuleImports(ast.NodeTransformer):
13 def __init__(self):
14 self.transformed_imports = []
15 self.functions = []
16 self.context = []
18 self.globals = set()
19 self.locals = set()
20 self.in_function = False
22 def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
23 if self.context[-1] != "Module":
24 return node
26 if node.module == "__future__":
27 return node
29 new_nodes = []
30 for alias in node.names:
31 name = alias.asname or alias.name
33 module = "." * (node.level) + (node.module or "")
34 new_nodes.append(
35 ast.Assign(
36 targets=[ast.Name(id=name, ctx=ast.Store())],
37 value=ast.Call(
38 func=ast.Attribute(
39 value=ast.Name(id="__lazy_imports_lite__", ctx=ast.Load()),
40 attr="ImportFrom",
41 ctx=ast.Load(),
42 ),
43 args=[
44 ast.Name(id="__package__", ctx=ast.Load()),
45 ast.Constant(value=module, kind=None),
46 ast.Constant(alias.name, kind=None),
47 ],
48 keywords=[],
49 ),
50 )
51 )
52 self.transformed_imports.append(name)
53 return new_nodes
55 def visit_Import(self, node: ast.Import) -> Any:
56 if len(self.context) > 1:
57 return node
59 new_nodes = []
60 for alias in node.names:
61 if alias.asname:
62 name = alias.asname
63 new_nodes.append(
64 ast.Assign(
65 targets=[ast.Name(id=name, ctx=ast.Store())],
66 value=ast.Call(
67 func=ast.Attribute(
68 value=ast.Name(
69 id="__lazy_imports_lite__", ctx=ast.Load()
70 ),
71 attr="ImportAs",
72 ctx=ast.Load(),
73 ),
74 args=[ast.Constant(value=alias.name, kind=None)],
75 keywords=[],
76 ),
77 )
78 )
79 self.transformed_imports.append(name)
80 else:
81 name = alias.name.split(".")[0]
82 new_nodes.append(
83 ast.Assign(
84 targets=[ast.Name(id=name, ctx=ast.Store())],
85 value=ast.Call(
86 func=ast.Attribute(
87 value=ast.Name(
88 id="__lazy_imports_lite__", ctx=ast.Load()
89 ),
90 attr="Import",
91 ctx=ast.Load(),
92 ),
93 args=[ast.Constant(value=alias.name, kind=None)],
94 keywords=[],
95 ),
96 )
97 )
98 self.transformed_imports.append(name)
100 return new_nodes
102 def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
103 return self.handle_function(node)
105 def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
106 return self.handle_function(node)
108 def visit_Lambda(self, node: ast.Lambda) -> Any:
109 return self.handle_function(node)
111 def handle_function(self, function):
112 for field, value in ast.iter_fields(function):
113 if field != "body":
114 if isinstance(value, list):
115 setattr(function, field, [self.visit(item) for item in value])
116 elif isinstance(value, ast.AST):
117 setattr(function, field, self.visit(value))
118 self.functions.append(function)
120 return function
122 def handle_function_body(self, function: ast.FunctionDef):
123 args = [
124 *function.args.posonlyargs,
125 *function.args.args,
126 function.args.vararg,
127 *function.args.kwonlyargs,
128 function.args.kwarg,
129 ]
131 self.locals = {arg.arg for arg in args if arg is not None}
133 self.globals = set()
135 self.in_function = True
137 if isinstance(function.body, list):
138 function.body = [self.visit(b) for b in function.body]
139 else:
140 function.body = self.visit(function.body)
142 def visit_Global(self, node: ast.Global) -> Any:
143 self.globals.update(node.names)
144 return self.generic_visit(node)
146 def visit_Name(self, node: ast.Name) -> Any:
147 if isinstance(node.ctx, ast.Store) and (
148 node.id not in self.globals or not self.in_function
149 ):
150 self.locals.add(node.id)
152 if node.id in self.transformed_imports and node.id not in self.locals:
153 old_ctx = node.ctx
154 node.ctx = ast.Load()
155 return ast.Attribute(value=node, attr="_lazy_value", ctx=old_ctx)
156 else:
157 return node
159 def visit_Module(self, module: ast.Module) -> Any:
160 module = typing.cast(ast.Module, self.generic_visit(module))
161 assert len(self.context) == 0
163 pos = 0
165 def is_import_from_future(node):
166 return (
167 isinstance(node, ast.Expr)
168 and isinstance(node.value, ast.Constant)
169 and isinstance(node.value.value, str)
170 or isinstance(node, ast.ImportFrom)
171 and node.module == "__future__"
172 )
174 if module.body:
175 while is_import_from_future(module.body[pos]):
176 pos += 1
177 module.body[pos:pos] = header_ast
179 self.context = ["FunctionBody"]
180 while self.functions:
181 f = self.functions.pop()
182 self.handle_function_body(f)
184 return module
186 def generic_visit(self, node: ast.AST) -> ast.AST:
187 ctx_len = len(self.context)
188 self.context.append(type(node).__name__)
189 result = super().generic_visit(node)
190 self.context = self.context[:ctx_len]
191 return result