Coverage for src/incant/__init__.py: 99%
192 statements
« prev ^ index » next coverage.py v7.0.1, created at 2023-06-27 02:31 +0200
« prev ^ index » next coverage.py v7.0.1, created at 2023-06-27 02:31 +0200
1from functools import lru_cache
2from inspect import Parameter, Signature, iscoroutinefunction
3from typing import (
4 Any,
5 Awaitable,
6 Callable,
7 Dict,
8 List,
9 Optional,
10 Sequence,
11 Set,
12 Tuple,
13 Type,
14 TypeVar,
15 Union,
16)
18from attr import Factory, define, field, frozen
20from ._codegen import (
21 CtxManagerKind,
22 Invocation,
23 ParameterDep,
24 compile_incant_wrapper,
25 compile_invoke,
26)
27from ._compat import NO_OVERRIDE, Override, get_annotated_override, signature
# Names exported by `from incant import *`.
__all__ = [
    "NO_OVERRIDE",
    "Override",
    "Hook",
    "Incanter",
    "IncantError",
]

# Alias for the builtin `type`: several functions in this module take a
# parameter called `type`, which shadows the builtin inside their bodies.
_type = type

# Return type of the callables being prepared/invoked.
R = TypeVar("R")
@frozen
class FactoryDep:
    """A dependency that is satisfied by calling a factory."""

    # The callable invoked to produce the dependency.
    factory: Callable
    # The name of the parameter this dependency is fulfilling.
    arg_name: str
    # The kind of context manager the factory's result is, or None if the
    # result is not a context manager.
    is_ctx_manager: Optional[CtxManagerKind] = None
# A single dependency of a node in the dependency tree: either the result of
# a factory (`FactoryDep`) or a parameter dependency (`ParameterDep`).
46Dep = Union[FactoryDep, ParameterDep]
# A function that inspects a parameter and says whether a hook applies to it.
49PredicateFn = Callable[[Parameter], bool]
def is_subclass(type, superclass) -> bool:
    """Return whether `type` is a subclass of `superclass`, without raising.

    Any exception raised by `issubclass` (for example because one of the
    arguments is not a class) is treated as "not a subclass".
    """
    try:
        result = issubclass(type, superclass)
    except Exception:
        result = False
    return result
@frozen
class Hook:
    """A rule pairing a parameter predicate with an optional factory."""

    # Decides whether this hook applies to a given parameter.
    predicate: PredicateFn
    # Either None, or a pair of (function returning the factory to use for a
    # given parameter, context manager kind of that factory's result).
    factory: Optional[Tuple[Callable[[Parameter], Callable], Optional[CtxManagerKind]]]

    @classmethod
    def for_name(cls, name: str, hook: Optional[Callable]):
        """Create a hook matching parameters whose name equals `name`."""

        def name_matches(param):
            return param.name == name

        if hook is None:
            return cls(name_matches, None)
        return cls(name_matches, (lambda _: hook, None))  # type: ignore

    @classmethod
    def for_type(cls, type: Any, hook: Optional[Callable]):
        """Register by exact type (subclasses won't match)."""

        def annotation_matches(param):
            return param.annotation == type

        if hook is None:
            return cls(annotation_matches, None)
        return cls(annotation_matches, (lambda _: hook, None))  # type: ignore
@define
class Incanter:
    """A registry of dependency hooks plus cached wrapper generation.

    Hooks registered here are matched against the parameters of functions
    passed to `prepare`/`invoke`/`ainvoke`; matching parameters are satisfied
    by calling the hook's factory. Generated wrappers are cached, and the
    caches are cleared whenever a new hook is registered.
    """

    # Registered hooks. Scanned in order; the first matching hook wins.
    hook_factory_registry: List[Hook] = Factory(list)
    # Per-instance cache around `_gen_invoke`.
    _invoke_cache: Callable = field(
        init=False,
        default=Factory(
            lambda self: lru_cache(None)(self._gen_invoke), takes_self=True
        ),
    )
    # Per-instance cache around `_gen_incant`.
    _incant_cache: Callable = field(
        init=False,
        default=Factory(
            lambda self: lru_cache(None)(self._gen_incant), takes_self=True
        ),
    )

    def prepare(
        self,
        fn: Callable[..., R],
        hooks: Sequence[Hook] = (),
        is_async: Optional[bool] = None,
        forced_deps: Sequence[Union[Callable, Tuple[Callable, CtxManagerKind]]] = (),
    ) -> Callable[..., R]:
        """Prepare a new function, encapsulating satisfied dependencies.

        hooks: Additional hooks, consulted before the registered ones.
        is_async: Whether the result should be async; `None` means autodetect.
        forced_deps: A sequence of dependencies that will be called even if
            `fn` doesn't require them explicitly. Each item is either a
            callable or a `(callable, context manager kind)` tuple.
        """
        # Everything is converted to tuples so the arguments are hashable
        # cache keys.
        return self._invoke_cache(
            fn,
            tuple(hooks),
            is_async,
            tuple(f if isinstance(f, tuple) else (f, None) for f in forced_deps),
        )

    def invoke(self, fn: Callable[..., R], *args, **kwargs) -> R:
        """Prepare `fn` as a sync function and call it with the given arguments."""
        return self.prepare(fn, is_async=False)(*args, **kwargs)

    async def ainvoke(self, fn: Callable[..., Awaitable[R]], *args, **kwargs) -> R:
        """Prepare `fn` as an async function and await it with the given arguments."""
        return await self.prepare(fn, is_async=True)(*args, **kwargs)

    def incant(self, fn: Callable[..., R], *args, **kwargs) -> R:
        """Invoke `fn` the best way we can."""
        return self._incant(fn, args, kwargs)

    async def aincant(self, fn: Callable[..., Awaitable[R]], *args, **kwargs) -> R:
        """Invoke async `fn` the best way we can."""
        return await self._incant(fn, args, kwargs, is_async=True)

    def register_by_name(
        self,
        fn: Optional[Callable] = None,
        *,
        name: Optional[str] = None,
        is_ctx_manager: Optional[CtxManagerKind] = None,
    ):
        """
        Register a factory to be injected by name. Can also be used as a decorator.

        If the name is not provided, the name of the factory will be used.
        """
        if fn is None:
            # Decorator
            return lambda fn: self.register_by_name(
                fn, name=name, is_ctx_manager=is_ctx_manager
            )

        if name is None:
            name = fn.__name__
        self.register_hook(lambda p: p.name == name, fn, is_ctx_manager=is_ctx_manager)
        return fn

    def register_by_type(
        self,
        fn: Union[Callable, Type],
        type: Optional[Type] = None,
        is_ctx_manager: Optional[CtxManagerKind] = None,
    ):
        """
        Register a factory to be injected by type. Can also be used as a decorator.

        If the type is not provided, the return annotation from the
        factory will be used. If `fn` is itself a class, that class is used.

        Raises `IncantError` if no type is given and none can be found.
        """
        if type is None:
            if isinstance(fn, _type):
                type_to_reg = fn
            else:
                sig = signature(fn)
                type_to_reg = sig.return_annotation
                if type_to_reg is Signature.empty:
                    raise IncantError("No return type found, provide a type.")
        else:
            type_to_reg = type
        self.register_hook(
            lambda p: is_subclass(p.annotation, type_to_reg), fn, is_ctx_manager
        )
        return fn

    def register_hook(
        self,
        predicate: PredicateFn,
        factory: Callable,
        is_ctx_manager: Optional[CtxManagerKind] = None,
    ) -> None:
        """Register `factory` for every parameter accepted by `predicate`."""
        self.register_hook_factory(predicate, lambda _: factory, is_ctx_manager)

    def register_hook_factory(
        self,
        predicate: PredicateFn,
        hook_factory: Callable[[Parameter], Callable],
        is_ctx_manager: Optional[CtxManagerKind] = None,
    ) -> None:
        """Register a function producing the factory for a matching parameter.

        The new hook is put at the front of the registry, so it is consulted
        before previously registered hooks. Both caches are cleared since
        already generated wrappers may no longer be valid.
        """
        self.hook_factory_registry.insert(
            0, Hook(predicate, (hook_factory, is_ctx_manager))
        )
        self._invoke_cache.cache_clear()  # type: ignore
        self._incant_cache.cache_clear()  # type: ignore

    def _incant(
        self,
        fn: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        is_async: bool = False,
    ):
        """The shared entrypoint for ``incant`` and ``aincant``."""
        # Wrappers are cached per call shape: the classes of the positional
        # arguments and the (name, class) pairs of the keyword arguments.
        pos_args_types = tuple([a.__class__ for a in args])
        kwargs_by_name_and_type = frozenset(
            [(k, v.__class__) for k, v in kwargs.items()]
        )
        wrapper = self._incant_cache(
            fn, pos_args_types, kwargs_by_name_and_type, is_async
        )

        return wrapper(*args, **kwargs)

    def _gen_incant_plan(
        self, fn, pos_args_types: Tuple[Any, ...], kwargs: Set[Tuple[str, Any]]
    ) -> List[Union[int, str]]:
        """Generate a plan to invoke `fn`, potentially using `args` and `kwargs`.

        The plan has one entry per fulfilled parameter of `fn`: an `int` is an
        index into the positional arguments, a `str` is a keyword argument
        name. For each parameter the following are tried, in order:

        1. a keyword argument with the same name and exactly the annotated type,
        2. the first positional argument whose type is a subclass of the annotation,
        3. a keyword argument with the same name, regardless of type.

        Parameters that cannot be fulfilled are skipped if they have a
        default; otherwise `TypeError` is raised.
        """
        pos_arg_plan: List[Union[int, str]] = []
        kwarg_names = {kw[0] for kw in kwargs}
        sig = signature(fn)
        for arg_name, arg in sig.parameters.items():
            annotation = arg.annotation
            if annotation is not Signature.empty:
                if (arg_name, annotation) in kwargs:
                    pos_arg_plan.append(arg_name)
                    continue

                matching_ix = next(
                    (
                        ix
                        for ix, a in enumerate(pos_args_types)
                        if is_subclass(a, annotation)
                    ),
                    None,
                )
                if matching_ix is not None:
                    pos_arg_plan.append(matching_ix)
                    continue

            if arg_name in kwarg_names:
                pos_arg_plan.append(arg_name)
            elif arg.default is not Signature.empty:
                # An argument with a default we cannot fulfil is ok.
                continue
            else:
                raise TypeError(f"Cannot fulfil argument {arg_name}")
        return pos_arg_plan

    def _gen_incant(
        self,
        fn: Callable,
        pos_args_types: Tuple,
        kwargs_by_name_and_type: Set,
        is_async: Optional[bool] = False,
    ) -> Callable:
        """Build the wrapper used by `_incant` for one particular call shape.

        `is_async` is not used here directly; it is part of the signature
        because it is part of the cache key.
        """
        plan = self._gen_incant_plan(fn, pos_args_types, kwargs_by_name_and_type)
        return compile_incant_wrapper(
            fn, plan, len(pos_args_types), len(kwargs_by_name_and_type)
        )

    @staticmethod
    def _order_dep_nodes(
        nodes: List[Tuple[Callable, Optional[CtxManagerKind], List[Dep]]]
    ) -> List[Tuple[Callable, Optional[CtxManagerKind], List[Dep]]]:
        """Return `nodes` ordered so each factory follows the factories it needs.

        The ordering is stable: at every step the first node (in incoming
        order) whose factory dependencies have all been emitted is taken, so
        an already valid order is returned unchanged.
        """
        known = {node[0] for node in nodes}
        pending = list(nodes)
        emitted = set()
        ordered = []
        while pending:
            for ix, (_, _, deps) in enumerate(pending):
                if all(
                    dep.factory in emitted or dep.factory not in known
                    for dep in deps
                    if isinstance(dep, FactoryDep)
                ):
                    break
            else:
                # No node is ready (circular dependencies); fall back to the
                # incoming order instead of looping forever.
                ix = 0
            node = pending.pop(ix)
            emitted.add(node[0])
            ordered.append(node)
        return ordered

    def _gen_dep_tree(
        self,
        fn: Callable,
        additional_hooks: Sequence[Hook],
        forced_deps: Sequence[Tuple[Callable, Optional[CtxManagerKind]]] = (),
    ) -> List[Tuple[Callable, Optional[CtxManagerKind], List[Dep]]]:
        """Generate the dependency tree for `fn`.

        The dependency tree is a list of factories and their dependencies.
        Each item is a `(callable, context manager kind, dependencies)` tuple,
        and every factory is placed after the factories it depends on.

        The actual function is the last item.
        """
        to_process = [(fn, None), *forced_deps]
        final_nodes: List[Tuple[Callable, Optional[CtxManagerKind], List[Dep]]] = []
        # Additional hooks come first, so they take precedence.
        hooks = list(additional_hooks) + self.hook_factory_registry
        already_processed_hooks = set()
        while to_process:
            _nodes = to_process
            to_process = []
            for node, ctx_mgr_kind in _nodes:
                sig = _signature(node)
                dependents: List[Union[ParameterDep, FactoryDep]] = []
                for name, param in sig.parameters.items():
                    if (
                        node is not fn
                        and param.default is not Signature.empty
                        and param.kind is Parameter.KEYWORD_ONLY
                    ):
                        # Do not expose optional kw-only params of dependencies.
                        continue
                    param_type = param.annotation
                    for hook in hooks:
                        if hook.predicate(param):
                            # Match!
                            if hook.factory is None:
                                # A hook without a factory keeps the parameter
                                # as a plain parameter dependency.
                                dependents.append(
                                    ParameterDep(name, param_type, param.default)
                                )
                            else:
                                factory = hook.factory[0](param)
                                if factory == node:
                                    # A hook cannot satisfy itself.
                                    continue
                                if factory not in already_processed_hooks:
                                    to_process.append((factory, hook.factory[1]))
                                    already_processed_hooks.add(factory)
                                dependents.append(
                                    FactoryDep(factory, name, hook.factory[1])
                                )

                            break
                    else:
                        # No hook matched.
                        dependents.append(ParameterDep(name, param_type, param.default))
                # Nodes are prepended, so `fn` (processed first) stays last.
                final_nodes.insert(0, (node, ctx_mgr_kind, dependents))

        # We need to sort the nodes to ensure no unbound local vars.
        dep_nodes = final_nodes[:-1]
        # Sorting by dependency count alone is only a heuristic (the count
        # includes plain parameters), so it is followed by a stable
        # topological pass that only reorders nodes when the heuristic would
        # place a factory before one of its own dependencies.
        dep_nodes.sort(key=lambda n: len(n[2]))
        dep_nodes = self._order_dep_nodes(dep_nodes)
        dep_nodes.append(final_nodes[-1])
        return dep_nodes

    def _gen_invoke(
        self,
        fn: Callable,
        hooks: Tuple[Hook, ...] = (),
        is_async: Optional[bool] = False,
        forced_deps: Tuple[Tuple[Callable, CtxManagerKind], ...] = (),
    ):
        """Generate the function returned by `prepare`.

        Returns `fn` itself when it has no dependencies to satisfy and its
        sync/async nature already matches what was asked for.

        Raises `TypeError` if a sync function was requested but a dependency
        is a coroutine function or an async context manager, and
        `IncantError` if a parameter shared by several functions has
        irreconcilable types.
        """
        dep_tree = self._gen_dep_tree(fn, hooks, forced_deps)
        if len(dep_tree) == 1 and (
            is_async is None or (is_async is iscoroutinefunction(fn))
        ):
            # Nothing we can do for this function.
            return fn

        # is_async = None means autodetect
        if is_async is None:
            is_async = any(
                iscoroutinefunction(factory) or ctx_mgr_kind == "async"
                for factory, ctx_mgr_kind, _ in dep_tree
            )

        invocs: List[Invocation] = []
        # All non-parameter deps become invocations.
        for ix, (factory, ctx_mgr_kind, deps) in enumerate(dep_tree[:-1]):
            if not is_async and (
                iscoroutinefunction(factory) or ctx_mgr_kind == "async"
            ):
                raise TypeError(
                    f"The function would be a coroutine because of {factory}, use `ainvoke` instead"
                )

            # It's possible this is a forced dependency, and nothing downstream actually needs it.
            # In that case, we mark it as forced so it doesn't get its own local var in the generated function.
            is_needed = False
            for _, _, downstream_deps in dep_tree[ix + 1 :]:
                if any(
                    isinstance(d, FactoryDep) and d.factory == factory
                    for d in downstream_deps
                ):
                    is_needed = True
                    break

            invocs.append(
                Invocation(
                    factory,
                    [
                        dep.factory if isinstance(dep, FactoryDep) else dep
                        for dep in deps
                    ],
                    not is_needed,
                    ctx_mgr_kind,
                )
            )

        outer_args = [
            dep for node in dep_tree for dep in node[2] if isinstance(dep, ParameterDep)
        ]
        # We need to do a pass over the outer args to consolidate duplicates.
        per_outer_arg: Dict[str, List[ParameterDep]] = {}
        for arg in outer_args:
            per_outer_arg.setdefault(arg.arg_name, []).append(arg)

        outer_args.clear()
        for arg_name, args in per_outer_arg.items():
            if len(args) == 1:
                arg_type = args[0].type
                arg_default = args[0].default
            else:
                # If there are multiple competing argument defs,
                # we need to pick a winning type.
                arg_type = Signature.empty
                arg_default = Signature.empty
                for arg in args:
                    try:
                        arg_type = _reconcile_types(arg_type, arg.type)
                    except Exception as exc:
                        raise IncantError(
                            f"Unable to reconcile types {arg_type} and {arg.type} for argument {arg_name}"
                        ) from exc
                    if arg.default is not Signature.empty:
                        # The last default seen wins.
                        arg_default = arg.default
            outer_args.append(ParameterDep(arg_name, arg_type, arg_default))

        # outer_args need to be sorted by the presence of a default value
        outer_args.sort(key=lambda a: a.default is not Signature.empty)

        # The parameters of `fn` itself that are satisfied by factories.
        fn_factory_args = []
        fn_factories = []
        for dep in dep_tree[-1][2]:
            if isinstance(dep, FactoryDep):
                fn_factory_args.append(dep.arg_name)
                fn_factories.append(dep.factory)

        return compile_invoke(
            fn,
            fn_factories,
            fn_factory_args,
            outer_args,
            invocs,
            is_async=is_async,
        )
def _reconcile_types(type_a, type_b):
    """Pick the single type two annotations of the same argument agree on.

    A missing annotation (`Signature.empty`) yields to the other one. Two
    present annotations must be the very same object, otherwise an
    `Exception` is raised.
    """
    missing = Signature.empty
    if type_a is missing:
        return type_b
    if type_b is missing or type_a is type_b:
        return type_a
    raise Exception(f"Unable to reconcile types {type_a!r} and {type_b!r}")
def _signature(f: Callable) -> Signature:
    """Return the signature of f, with potential overrides applied.

    Every parameter of the signature is passed through
    `get_annotated_override`, and the results replace the original parameters.
    """
    original = signature(f)
    overridden = list(map(get_annotated_override, original.parameters.values()))
    return original.replace(parameters=overridden)
class IncantError(Exception):
    """An error raised by Incant itself.

    Raised in this module when no type can be determined for a by-type
    registration, and when competing argument types cannot be reconciled.
    """