Coverage for onionizer/onionizer/onionizer.py: 99%
108 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-04-06 11:39 +0200
« prev ^ index » next coverage.py v7.2.2, created at 2023-04-06 11:39 +0200
import functools
import inspect
from abc import ABC, abstractmethod
from contextlib import ExitStack
from typing import Callable, Any, Iterable, Sequence, TypeVar, Generator
T = TypeVar("T")

# Type of a middleware generator: Generator[yield type, send type, return type].
# The value sent back in and the value returned share the type T.
OnionGenerator = Generator[Any, T, T]


class _UnchangedType:
    """Type of the ``UNCHANGED`` sentinel; only one instance is ever created."""

    __slots__ = ()

    def __repr__(self) -> str:
        return "UNCHANGED"


# Sentinel a middleware yields to mean "keep the current arguments".
# It is compared by identity (``is``), so it must be a dedicated object:
# identity of plain ints is an interpreter implementation detail, and an
# ordinary value such as 123 could be yielded by mistake and be silently
# treated as the sentinel.
UNCHANGED = _UnchangedType()
# Public API of the module: names exported by ``from onionizer import *``.
__all__ = [
    "wrap",
    "decorate",
    "OnionGenerator",
    "UNCHANGED",
    "PositionalArgs",
    "MixedArgs",
    "KeywordArgs",
    "postprocessor",
    "preprocessor",
    "as_decorator",
]
def _capture_last_message(coroutine, value_to_send: Any) -> Any:
    """Resume a suspended middleware generator and return its final value.

    The generator is expected to finish right after receiving
    ``value_to_send``; its ``return`` value (carried by ``StopIteration``)
    is handed back to the caller.

    :raises RuntimeError: if the generator yields again instead of finishing.
    """
    try:
        coroutine.send(value_to_send)
    except StopIteration as finished:
        # Normal path: the generator returned, its value rides on the exception.
        return finished.value
    # Reaching this line means send() produced another yielded value.
    raise RuntimeError(
        "Generator did not exhaust. Your function should yield exactly once."
    )
def _leave_the_onion(coroutines: Sequence, output: Any) -> Any:
    """Feed ``output`` back out through the suspended middlewares.

    The most recently started coroutine (innermost layer) is resumed first.
    Each coroutine's return value becomes the input of the next outer layer,
    and the outermost layer's return value is the final result.
    """
    result = output
    for layer in reversed(coroutines):
        result = _capture_last_message(layer, result)
    return result
def as_decorator(middleware):
    """Return a decorator that wraps a function with the single ``middleware``.

    Shorthand for ``decorate([middleware])``.
    """
    single_layer = [middleware]
    return decorate(single_layer)
def decorate(middlewares):
    """Build a decorator that wraps a function with ``middlewares``.

    :param middlewares: an iterable of middlewares (generator functions and/or
        context managers), or a single such middleware.
    :return: a decorator; applying it to ``func`` returns
        ``wrap(func, middlewares)``.
    :raises TypeError: if ``middlewares`` is neither iterable, callable, nor a
        context manager.
    """
    if isinstance(middlewares, Iterable):
        # Materialise once: a one-shot iterator (e.g. a generator expression)
        # would otherwise be consumed by the first wrap() signature check,
        # leaving later calls and later decorated functions with no middlewares.
        middlewares = list(middlewares)
    elif callable(middlewares) or (
        hasattr(middlewares, "__enter__") and hasattr(middlewares, "__exit__")
    ):
        # A single middleware; context managers are accepted because wrap()
        # supports them as middlewares too.
        middlewares = [middlewares]
    else:
        raise TypeError(
            "middlewares must be a list of coroutines or a single coroutine"
        )

    def decorator(func):
        return wrap(func, middlewares)

    return decorator
def wrap(
    func: Callable[..., Any], middlewares: list, sigcheck: bool = True
) -> Callable[..., Any]:
    """
    Wrap ``func`` so that each call goes through ``middlewares`` (onion model).

    Middlewares run in order on the way in, and each may rewrite the call
    arguments. Then ``func`` is called. Then the middlewares are resumed in
    reverse order on the way out, and each may rewrite the output. A middleware
    is either a generator function that yields exactly once, or a context
    manager instance entered around the rest of the call::

        def func(x, y):
            return x + y

        def middleware1(x, y):
            result = yield (x + 1, y), {}
            return result

        def middleware2(x, y):
            result = yield (x, y + 1), {}
            return result

        wrapped_func = onionizer.wrap(func, [middleware1, middleware2])
        result = wrapped_func(0, 0)

        assert result == 2

    :param func: the function to be wrapped
    :type func: Callable[..., Any]
    :param middlewares: an iterable of middlewares, applied in order
    :type middlewares: list
    :param sigcheck: when true, every generator middleware must declare the
        same parameters as ``func``
    :return: A function that wraps the original function with the middlewares.
    :raises TypeError: if ``func`` is not callable or ``middlewares`` is not
        iterable; at call time, if a middleware does not return a generator.
    :raises ValueError: if ``sigcheck`` is true and a signature mismatches.
    """
    if isinstance(middlewares, Iterable):
        # Materialise once: both the signature check and every call of the
        # wrapped function iterate the middlewares, which would silently
        # exhaust a one-shot iterator such as a generator expression.
        middlewares = list(middlewares)
    _check_validity(func, middlewares, sigcheck)

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        arguments = MixedArgs(args, kwargs)
        coroutines = []
        with ExitStack() as stack:
            # programmatic support for context managers, possibly nested!
            # https://docs.python.org/3/library/contextlib.html#contextlib.ExitStack
            for middleware in middlewares:
                if hasattr(middleware, "__enter__") and hasattr(middleware, "__exit__"):
                    stack.enter_context(middleware)
                    continue
                name = getattr(middleware, "__name__", repr(middleware))
                coroutine = arguments.call_function(middleware)
                # Look up send() explicitly instead of catching AttributeError
                # around the call: an AttributeError raised *inside* the
                # middleware body must propagate, not be misreported.
                send = getattr(coroutine, "send", None)
                if send is None:
                    raise TypeError(
                        f"Middleware {name} is not a coroutine. "
                        f"Did you forget to use a yield statement?"
                    )
                coroutines.append(coroutine)
                try:
                    raw_arguments = send(None)
                except StopIteration:
                    # The generator returned before its single yield.
                    raise RuntimeError(
                        f"Middleware {name} returned without yielding. "
                        f"Your function should yield exactly once."
                    ) from None
                arguments = _refine(raw_arguments, arguments)
            # just reached the core of the onion
            output = arguments.call_function(func)
            # now we go back to the surface
            output = _leave_the_onion(coroutines, output)
        return output

    return wrapped_func
def _check_validity(func, middlewares, sigcheck):
    """Validate the arguments given to ``wrap`` before building the wrapper.

    :raises TypeError: if ``func`` is not callable, or ``middlewares`` is not
        iterable (checked in that order).
    :raises ValueError: propagated from ``_inspect_signatures`` when
        ``sigcheck`` is truthy and a middleware signature mismatches.
    """
    checks = (
        (callable(func), "func must be callable"),
        (isinstance(middlewares, Iterable), "middlewares must be a list of coroutines"),
    )
    for passed, message in checks:
        if not passed:
            raise TypeError(message)
    if not sigcheck:
        return
    _inspect_signatures(func, middlewares)
def _inspect_signatures(func, middlewares):
    """Ensure every generator middleware declares the same parameters as ``func``.

    Skipped middlewares:
      * those whose ``ignore_signature_check`` attribute is ``True``;
      * context managers (objects with both ``__enter__`` and ``__exit__``).

    :raises ValueError: for the first middleware whose parameters differ.
    """
    func_signature = inspect.signature(func)
    func_signature_params = func_signature.parameters
    for middleware in middlewares:
        if getattr(middleware, "ignore_signature_check", False) is True:
            continue
        if all(hasattr(middleware, attr) for attr in ("__enter__", "__exit__")):
            continue
        middleware_signature = inspect.signature(middleware)
        if middleware_signature.parameters != func_signature_params:
            # Callable instances and functools.partial objects have no
            # __name__; fall back to repr so the intended ValueError is raised
            # instead of an AttributeError from building the message.
            func_name = getattr(func, "__name__", repr(func))
            middleware_name = getattr(middleware, "__name__", repr(middleware))
            raise ValueError(
                f"Expected arguments of the target function mismatch "
                f"middleware expected arguments. {func_name}{func_signature} "
                f"differs with {middleware_name}{middleware_signature}"
            )
class ArgsMode(ABC):
    """Base class for objects holding call arguments and applying them to a callable.

    Subclasses must implement ``call_function``. Because it is declared
    abstract, an incomplete subclass (or the base itself) fails at
    instantiation instead of at first use.
    """

    @abstractmethod
    def call_function(self, func: Callable[..., Any]):
        """Call ``func`` with the stored arguments and return its result."""
        raise NotImplementedError
class PositionalArgs(ArgsMode):
    """Arguments forwarded to the callee positionally: ``func(*args)``."""

    def __init__(self, *args):
        self.args = args

    def call_function(self, func: Callable[..., Any]):
        return func(*self.args)

    def __repr__(self) -> str:
        # Mirrors the constructor call, e.g. PositionalArgs(1, 'a').
        return f"{type(self).__name__}({', '.join(map(repr, self.args))})"
class KeywordArgs(ArgsMode):
    """Arguments forwarded to the callee by keyword: ``func(**kwargs)``.

    :param kwargs: mapping of parameter names to values (passed as a single
        argument, not as ``**kwargs``).
    """

    def __init__(self, kwargs):
        self.kwargs = kwargs

    def call_function(self, func: Callable[..., Any]):
        return func(**self.kwargs)

    def __repr__(self) -> str:
        # Mirrors the constructor call, e.g. KeywordArgs({'x': 1}).
        return f"{type(self).__name__}({self.kwargs!r})"
class MixedArgs(ArgsMode):
    """Positional plus keyword arguments: ``func(*args, **kwargs)``.

    :param args: positional arguments (any iterable unpackable with ``*``).
    :param kwargs: mapping of keyword arguments.
    """

    def __init__(self, args, kwargs):
        self.args = args
        self.kwargs = kwargs

    def call_function(self, func: Callable[..., Any]):
        return func(*self.args, **self.kwargs)

    def __repr__(self) -> str:
        # Mirrors the constructor call, e.g. MixedArgs((1, 2), {'z': 3}).
        return f"{type(self).__name__}({self.args!r}, {self.kwargs!r})"
def _refine(arguments, previous_arguments, accept_none=False):
    """Convert the value a middleware yielded into the arguments for the next layer.

    :param arguments: the yielded value, which is one of:
        ``UNCHANGED`` (keep ``previous_arguments``), an ``ArgsMode`` instance
        (used as is), or an ``(args, kwargs)`` pair.
    :param previous_arguments: the arguments currently in effect.
    :param accept_none: when true, a yielded ``None`` also means "unchanged".
    :return: ``previous_arguments``, the given ``ArgsMode``, or a new
        ``MixedArgs`` built from the pair.
    :raises TypeError: if ``arguments`` has none of the accepted shapes.
    """
    if arguments is UNCHANGED or (accept_none and arguments is None):
        return previous_arguments
    if isinstance(arguments, ArgsMode):
        return arguments
    # str/bytes/bytearray are Sequences too, but a 2-character string is never
    # a meaningful (args, kwargs) pair. Reject it here instead of letting it
    # fail obscurely when the next layer is called.
    if (
        not isinstance(arguments, Sequence)
        or isinstance(arguments, (str, bytes, bytearray))
        or len(arguments) != 2
    ):
        raise TypeError(
            "arguments must be a tuple of length 2, "
            "maybe use onionizer.PositionalArgs or onionizer.MixedArgs instead"
        )
    args, kwargs = arguments
    return MixedArgs(args, kwargs)
def preprocessor(func):
    """Turn an argument-rewriting function into a middleware.

    The returned generator function calls ``func`` with the incoming call
    arguments and yields its result as the new arguments. It then returns
    whatever output is sent back, unchanged.

    ``functools.wraps`` sets ``__wrapped__``, and ``inspect.signature`` follows
    it, so the signature check sees ``func``'s parameters rather than
    ``*args, **kwargs``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> OnionGenerator:
        new_arguments = func(*args, **kwargs)
        passthrough = yield new_arguments
        return passthrough

    return wrapper
def postprocessor(func):
    """Turn an output-transforming function into a middleware.

    The returned generator function leaves the call arguments alone (it
    yields ``UNCHANGED``). On the way out it returns ``func(output)``.

    Its signature check is disabled: ``func`` takes the output rather than the
    call arguments, so its signature (exposed through ``functools.wraps``)
    would not match the target's.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> OnionGenerator:
        received = yield UNCHANGED
        transformed = func(received)
        return transformed

    setattr(wrapper, "ignore_signature_check", True)
    return wrapper