Coverage for src/incant/_codegen.py: 99%
144 statements
« prev ^ index » next coverage.py v7.0.1, created at 2023-06-27 02:25 +0200
« prev ^ index » next coverage.py v7.0.1, created at 2023-06-27 02:25 +0200
1import linecache
3from inspect import Signature, iscoroutinefunction
4from typing import Any, Callable, Counter, Dict, List, Literal, Optional, Union
6from attr import define
8from ._compat import signature
@define
class ParameterDep:
    """A parameter of a generated invoker function.

    In `compile_invoke`, each entry of `outer_args` is rendered as one parameter
    of the generated function. A `ParameterDep` inside `Invocation.args` passes
    that parameter, by name, to the invoked factory.
    """

    # Name of the parameter in the generated function.
    arg_name: str
    # Annotation for the parameter; `Signature.empty` means unannotated.
    type: Any
    # Default value; `Signature.empty` means the parameter is required.
    default: Any = Signature.empty
# How a context-manager invocation is entered: "sync" -> `with`,
# "async" -> `async with`.
CtxManagerKind = Literal["sync", "async"]


@define
class Invocation:
    """Produce an invocation (and possibly a local var) in a generated function.

    `compile_invoke` emits invocations in list order. A single-use result may be
    inlined at its use site instead of being bound to a local variable.
    """

    # The callable to invoke.
    factory: Callable
    # Positional arguments for `factory`, in order. A `ParameterDep` passes the
    # generated function's parameter of that name. A callable passes the result
    # of the (earlier) invocation whose `factory` it is.
    args: List[Union[Callable, ParameterDep]]
    # If True, the result is not bound to a local variable. The call (or `with`
    # statement) is emitted only for its effect.
    is_forced: bool = False
    # None for a plain call. Otherwise the factory's result is entered as a
    # context manager, and the rest of the generated body is nested inside it.
    is_ctx_manager: Optional[CtxManagerKind] = None
def _render_outer_args(
    outer_args: List[ParameterDep], globs: Dict[str, Any]
) -> List[str]:
    """Render `outer_args` as the parameter list of a generated function.

    Annotations and defaults are emitted as references to names registered in
    `globs` (the generated code's globals), so any object can be used.
    """
    rendered = []
    for dep in outer_args:
        snippet = ""
        if dep.type is not Signature.empty:
            # Some types, like new unions (`int|str`), do not have (usable)
            # names, and a name already bound to another object can't be reused.
            type_name = getattr(dep.type, "__name__", None)
            if (
                not isinstance(type_name, str)
                or not type_name.isidentifier()
                or (type_name in globs and globs[type_name] is not dep.type)
            ):
                type_name = f"_incant_arg_{dep.arg_name}"
            globs[type_name] = dep.type
            snippet = f": {type_name}"
        if dep.default is not Signature.empty:
            default_name = f"_incant_default_{dep.arg_name}"
            globs[default_name] = dep.default
            snippet = f"{snippet} = {default_name}"
        rendered.append(f"{dep.arg_name}{snippet}")
    return rendered


def _render_return_annotation(annotation: Any, globs: Dict[str, Any]) -> str:
    """Render the ` -> ...` suffix of a generated function's signature."""
    if annotation is Signature.empty:
        return ""
    if annotation is None:
        return " -> None"
    name = getattr(annotation, "__name__", None)
    if (
        not isinstance(name, str)
        or not name.isidentifier()
        or (name in globs and globs[name] is not annotation)
    ):
        # Nameless annotations (e.g. `int | str`) still need a global name so
        # the wrapper reports the real return type.
        name = "_incant_return_type"
    globs[name] = annotation
    return f" -> {name}"


def _find_inlineable_factories(
    fn_args: List[Callable], invocations: List[Invocation]
) -> set:
    """Return the factories whose result is used exactly once.

    Uses are counted across `fn_args` and the arguments of every invocation,
    context managers included. A factory whose result also feeds a context
    manager is used more than once, so it must be bound to a local; otherwise
    it would be called once per use.
    """
    uses = Counter(fn_args)
    for invoc in invocations:
        uses.update(arg for arg in invoc.args if not isinstance(arg, ParameterDep))
    return {factory for factory, count in uses.items() if count == 1}


def compile_invoke(
    fn: Callable,
    fn_args: List[Callable],
    fn_factory_args: List[str],
    outer_args: List[ParameterDep],
    invocations: List[Invocation],
    is_async: bool = False,
) -> Callable:
    """Generate the invocation wrapper for `fn`.

    The generated function has `outer_args` as its parameters. It emits
    `invocations` in order: single-use results are inlined at their use site,
    other results are bound to locals, and context managers wrap the rest of
    the body in `with`/`async with`. Finally it calls (and returns) `fn`.

    :param fn: The function to wrap.
    :param fn_args: Factories supplying `fn`'s factory-provided parameters,
        consumed in parameter order.
    :param fn_factory_args: Names of `fn` parameters to take from `fn_args`
        even when an outer argument with the same name exists.
    :param outer_args: Arguments that the generated function needs to retain;
        they become its parameters.
    :param invocations: Factory calls to emit, in execution order. A factory
        is expected to appear here before any invocation (or `fn_args`) uses it.
    :param is_async: Generate an `async def`. This is needed whenever a
        coroutine factory, `fn` itself, or an async context manager is awaited.
    :return: The generated function.
    """
    from inspect import Parameter

    sig = signature(fn)
    fn_name = f"invoke_{fn.__name__}" if fn.__name__.isidentifier() else "invoke_lambda"
    globs: Dict[str, Any] = {}
    outer_arg_names = {dep.arg_name for dep in outer_args}

    arg_lines = _render_outer_args(outer_args, globs)
    ret_type = _render_return_annotation(sig.return_annotation, globs)
    def_kw = "async def" if is_async else "def"
    lines = [f"{def_kw} {fn_name}({', '.join(arg_lines)}){ret_type}:"]

    # Locals are named after the index of the invocation that produced them.
    local_vars_ix_by_factory = {
        invoc.factory: ix for ix, invoc in enumerate(invocations)
    }
    inlineable = _find_inlineable_factories(fn_args, invocations)
    inline_exprs_by_factory: Dict[Callable, str] = {}

    def render_arg(arg: Union[Callable, ParameterDep]) -> str:
        """Return the expression that yields `arg` inside the generated body."""
        if isinstance(arg, ParameterDep):
            return arg.arg_name
        if arg in inline_exprs_by_factory:
            return inline_exprs_by_factory[arg]
        return f"_incant_local_{local_vars_ix_by_factory[arg]}"

    ind = 0  # Extra indentation for statements nested in emitted `with` blocks.
    for i, invoc in enumerate(invocations):
        # Reference the factory by its own name when that is a usable global:
        # it must not be shadowed by a parameter or by the generated function.
        factory_name = getattr(invoc.factory, "__name__", None)
        if (
            isinstance(factory_name, str)
            and factory_name.isidentifier()
            and factory_name not in outer_arg_names
            and factory_name not in globs
            and factory_name != fn_name
        ):
            global_fn_name = factory_name
        else:
            global_fn_name = f"_incant_local_factory_{i}"
        globs[global_fn_name] = invoc.factory

        call = f"{global_fn_name}({', '.join(render_arg(a) for a in invoc.args)})"

        if invoc.factory in inlineable and not invoc.is_ctx_manager:
            # Single-use result: evaluated at its use site, no local needed.
            aw = "await " if iscoroutinefunction(invoc.factory) else ""
            inline_exprs_by_factory[invoc.factory] = f"{aw}{call}"
            continue

        pad = " " + " " * ind
        local_name = f"_incant_local_{i}"
        if invoc.is_ctx_manager is not None:
            aw = "async " if invoc.is_ctx_manager == "async" else ""
            binding = "" if invoc.is_forced else f" as {local_name}"
            lines.append(f"{pad}{aw}with {call}{binding}:")
            ind += 2
        else:
            aw = "await " if iscoroutinefunction(invoc.factory) else ""
            binding = "" if invoc.is_forced else f"{local_name} = "
            lines.append(f"{pad}{binding}{aw}{call}")

    fn_call_args = []
    factory_ix = 0
    for name, param in sig.parameters.items():
        if name not in fn_factory_args and name in outer_arg_names:
            expr = name
        else:
            # Factory-supplied parameters consume `fn_args` in order.
            expr = render_arg(fn_args[factory_ix])
            factory_ix += 1
        if param.kind is Parameter.KEYWORD_ONLY:
            # Keyword-only parameters cannot be passed positionally.
            expr = f"{name}={expr}"
        fn_call_args.append(expr)

    inner_name = fn.__name__
    if (
        not inner_name.isidentifier()
        or inner_name in globs
        or inner_name in outer_arg_names
    ):
        inner_name = "_incant_inner_fn"
    globs[inner_name] = fn
    aw = "await " if iscoroutinefunction(fn) else ""
    lines.append(f" {' ' * ind}return {aw}{inner_name}({', '.join(fn_call_args)})")

    fname = _generate_unique_filename(fn.__name__, "invoke", lines)
    exec(compile("\n".join(lines), fname, "exec"), globs)
    return globs[fn_name]
def compile_incant_wrapper(
    fn: Callable, incant_plan: List[Union[int, str]], num_pos_args: int, num_kwargs: int
):
    """Generate a wrapper that rearranges call-site arguments into a call to `fn`.

    :param fn: The function to call.
    :param incant_plan: One entry per argument passed to `fn`, in order. An int
        is an index into the wrapper's positional arguments. A str names a
        keyword argument of the wrapper, which is passed through.
    :param num_pos_args: Number of positional arguments the wrapper receives.
        When non-zero, the wrapper accepts them as var-positional.
    :param num_kwargs: Number of keyword arguments the wrapper receives. When
        it exceeds the distinct names in `incant_plan`, the wrapper also accepts
        (and ignores) the extra keywords.
    :return: The generated wrapper.
    """
    fn_name = f"incant_{fn.__name__}" if fn.__name__.isidentifier() else "incant_lambda"
    globs = {"_incant_inner_fn": fn}

    # Each keyword becomes one parameter. A repeated name would make the
    # generated `def` a SyntaxError, so deduplicate (keeping first-use order).
    kwargs = list(dict.fromkeys(arg for arg in incant_plan if isinstance(arg, str)))

    # The var-positional/var-keyword parameters use reserved names so they
    # cannot clash with a keyword argument literally named `args` or `kwargs`.
    arg_lines = []
    if num_pos_args:
        arg_lines.append("*_incant_args")
    arg_lines.extend(kwargs)
    if num_kwargs > len(kwargs):
        arg_lines.append("**_incant_kwargs")

    lines = [f"def {fn_name}({', '.join(arg_lines)}):", "    return _incant_inner_fn("]
    for arg in incant_plan:
        if isinstance(arg, int):
            lines.append(f"        _incant_args[{arg}],")
        else:
            lines.append(f"        {arg},")
    lines.append("    )")

    fname = _generate_unique_filename(fn.__name__, "incant", lines)
    exec(compile("\n".join(lines), fname, "exec"), globs)
    return globs[fn_name]
241def _generate_unique_filename(func_name: str, func_type: str, source: List[str]) -> str:
242 """
243 Create a "filename" suitable for a function being generated.
244 """
245 extra = ""
246 count = 1
248 while True:
249 unique_filename = f"<incant generated {func_type} of {func_name}{extra}>"
250 # To handle concurrency we essentially "reserve" our spot in
251 # the linecache with a dummy line. The caller can then
252 # set this value correctly.
253 cache_line = (len(source), None, source, unique_filename)
254 if linecache.cache.setdefault(unique_filename, cache_line) == cache_line:
255 return unique_filename
257 # Looks like this spot is taken. Try again.
258 count += 1
259 extra = f"-{count}"