Coverage for greyhorse/factory.py: 98%
417 statements
« prev ^ index » next coverage.py v7.11.3, created at 2026-05-11 15:54 +0300
« prev ^ index » next coverage.py v7.11.3, created at 2026-05-11 15:54 +0300
1from __future__ import annotations
3import builtins
4import contextlib
5import enum
6import inspect
7from abc import ABC, abstractmethod
8from collections.abc import AsyncGenerator, Awaitable, Callable, Generator, Mapping
9from dataclasses import dataclass, replace
10from functools import partial, reduce
11from typing import Any, get_type_hints, override
13from greyhorse.maybe import Maybe
14from greyhorse.utils.types import (
15 TypeWrapper,
16 is_awaitable,
17 is_maybe,
18 is_optional,
19 unwrap_maybe,
20 unwrap_optional,
21)
# Every kind of producer that ``into_factory`` knows how to wrap into a
# ``Factory[T]``: a synchronous callable returning T, a callable returning an
# awaitable of T, a class, a ready-made instance, a generator function yielding
# T, or an async generator function yielding T.
type FactoryFn[T] = (
    Callable[..., T]
    | Callable[..., Awaitable[T]]
    | type[T]
    | T
    | Callable[..., Generator[T, T, None]]
    | Callable[..., AsyncGenerator[T, T]]
)
class CachePolicy(enum.IntEnum):
    """Caching policy attached to a factory.

    The members are ordered integers.  ``Factory.cacheable`` treats every
    member whose value is above 0 (i.e. everything except ``GRAPH``) as
    cacheable.  NOTE(review): the names suggest increasingly wide cache scope
    (graph < fragment < any) — confirm against the container that consumes them.
    """

    GRAPH = 0
    FRAGMENT = 1
    ANY = 2
40class ParamOpt(int, enum.Enum):
41 NORMAL = 0
42 OPTIONAL = 1
43 MAYBE = 2
@dataclass(slots=True, frozen=True, kw_only=True)
class ParamData:
    """Immutable description of a single factory parameter.

    ``type`` is the annotation with any ``Optional`` / ``Maybe`` wrapper
    stripped, ``optional`` records which wrapper was present, and ``value``
    holds a pre-bound argument (``None`` is treated as "not bound").
    """

    type: type
    name: str
    optional: ParamOpt
    value: Any | None = None

    @property
    def is_required(self) -> bool:
        """True when the annotation carried neither an Optional nor a Maybe wrapper."""
        return self.optional == ParamOpt.NORMAL

    @property
    def raw_type(self) -> builtins.type:
        """Rebuild the original annotation by re-applying the recorded wrapper."""
        if self.optional == ParamOpt.NORMAL:
            return self.type
        if self.optional == ParamOpt.OPTIONAL:
            return self.type | None
        if self.optional == ParamOpt.MAYBE:
            return Maybe[self.type]
        return None

    @classmethod
    def from_kv(cls, k: str, v: type, value: Any | None = None) -> ParamData:
        """Create a ParamData from a ``name -> annotation`` pair.

        Args:
            k: parameter name.
            v: raw annotation, possibly ``X | None`` or ``Maybe[X]``.
            value: optional pre-bound argument.
        """
        if is_optional(v):
            wrapper = ParamOpt.OPTIONAL
        elif is_maybe(v):
            wrapper = ParamOpt.MAYBE
        else:
            wrapper = ParamOpt.NORMAL
        inner_type = unwrap_maybe(unwrap_optional(v))
        return cls(name=k, type=inner_type, optional=wrapper, value=value)
79class Factory[T](TypeWrapper[T], ABC):
80 __slots__ = ('_cache_policy', '_name_map', '_params', '_type_map')
82 scoped: bool = False
83 is_async: bool = False
85 def __init__(
86 self,
87 params: Mapping[str, type],
88 args: Mapping[str, Any] | None = None,
89 cache_policy: CachePolicy = CachePolicy.GRAPH,
90 ) -> None:
91 args = args or {}
92 params_data = [ParamData.from_kv(k, v, args.get(k)) for k, v in params.items()]
93 self._params: list[ParamData] = params_data
94 self._name_map = {pd.name: i for i, pd in enumerate(self._params)}
95 self._type_map = {pd.type: i for i, pd in enumerate(self._params)}
96 self._cache_policy = cache_policy
98 @property
99 def cache_policy(self) -> CachePolicy:
100 return self._cache_policy
102 @property
103 def cacheable(self) -> bool:
104 return self._cache_policy > 0
106 @abstractmethod
107 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]: ...
109 @abstractmethod
110 def destroy(self, instance: T) -> None | Awaitable[None]: ...
112 def __call__[**P](self, *args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
113 return self.create(*args, **kwargs)
115 def clone(self) -> Factory[T]:
116 instance = type(self)({})
117 self._clone_into(instance)
118 return instance
120 @abstractmethod
121 def __eq__(self, other: Factory[T]) -> bool: ...
123 @abstractmethod
124 def __hash__(self) -> int: ...
126 @property
127 def return_type(self) -> type[T]:
128 return self.__wrapped_type__
130 @property
131 def params_count(self) -> int:
132 return len(self._params)
134 # @property
135 # def params_types(self) -> tuple[type, ...]:
136 # return tuple(pd.raw_type for pd in self._params if pd.value is None)
138 @property
139 def args_count(self) -> int:
140 return reduce(lambda a, pd: a + (1 if pd.value is not None else 0), self._params, 0)
142 @property
143 def all_params(self) -> Mapping[str, type]:
144 res = {}
145 for pd in self._params:
146 res[pd.name] = pd.raw_type
147 return res
149 @property
150 def actual_params(self) -> Mapping[str, type]:
151 res = {}
152 for pd in self._params:
153 if pd.value is None:
154 res[pd.name] = pd.raw_type
155 return res
157 def check_signature(self, signature: type[Callable]) -> bool:
158 sig_ret_type = signature.__args__[-1]
159 sig_params = signature.__args__[0 : len(signature.__args__) - 1]
161 if not issubclass(self.return_type, sig_ret_type):
162 return False
164 sig_params_idx = 0
166 for pd in self._params:
167 if sig_params_idx < len(sig_params):
168 sig_param_type = sig_params[sig_params_idx]
169 fit = False
171 match pd.optional:
172 case ParamOpt.NORMAL:
173 if inspect.isclass(sig_param_type):
174 fit = issubclass(sig_param_type, pd.type)
175 else:
176 fit = sig_param_type is pd.type
177 case ParamOpt.OPTIONAL:
178 fit = issubclass(unwrap_optional(sig_param_type), pd.type)
179 case ParamOpt.MAYBE: 179 ↛ 182line 179 didn't jump to line 182 because the pattern on line 179 always matched
180 fit = issubclass(unwrap_maybe(sig_param_type), pd.type)
182 if fit:
183 sig_params_idx += 1
184 continue
186 if pd.value is None and pd.optional == ParamOpt.NORMAL:
187 return False
189 return sig_params_idx == len(sig_params)
191 def has_named_param(self, name: str) -> bool:
192 return name in self._name_map
194 def has_typed_param[U](self, key: type[U]) -> bool:
195 return key in self._type_map
197 def add_named_arg[U](self, name: str, value: U) -> bool:
198 if name not in self._name_map:
199 return False
200 idx = self._name_map[name]
201 if self._params[idx].value is value:
202 return False
203 # if self._params[idx].optional == _ParamOpt.MAYBE:
204 # value = Maybe(value)
205 self._params[idx] = replace(self._params[idx], value=value)
206 return True
208 def remove_named_arg(self, name: str) -> bool:
209 if name not in self._name_map:
210 return False
211 idx = self._name_map[name]
212 self._params[idx] = replace(self._params[idx], value=None)
213 return True
215 def add_typed_arg[U](self, key: type[U], value: U) -> bool:
216 if key not in self._type_map:
217 return False
218 idx = self._type_map[key]
219 if self._params[idx].value is value:
220 return False
221 # if self._params[idx].optional == _ParamOpt.MAYBE:
222 # value = Maybe(value)
223 self._params[idx] = replace(self._params[idx], value=value)
224 return True
226 def remove_typed_arg[U](self, key: type[U]) -> bool:
227 if key not in self._type_map:
228 return False
229 idx = self._type_map[key]
230 self._params[idx] = replace(self._params[idx], value=None)
231 return True
233 def _clone_into(self, other: Factory[T]) -> None:
234 other._params = self._params.copy() # noqa: SLF001
235 other._name_map = self._name_map.copy() # noqa: SLF001
236 other._type_map = self._type_map.copy() # noqa: SLF001
237 other._cache_policy = self._cache_policy # noqa: SLF001
239 def _prepare_kwargs[**P](self, *args: P.args, **kwargs: P.kwargs) -> dict[str, Any]:
240 prepared_kwargs = {}
241 args_idx = 0
243 for pd in self._params:
244 if args_idx < len(args):
245 arg_value = args[args_idx]
246 arg_type = type(arg_value)
247 fit = False
249 match pd.optional:
250 case ParamOpt.NORMAL:
251 fit = issubclass(arg_type, pd.type)
252 case ParamOpt.OPTIONAL:
253 fit = issubclass(unwrap_optional(arg_type), pd.type)
254 case ParamOpt.MAYBE: 254 ↛ 258line 254 didn't jump to line 258 because the pattern on line 254 always matched
255 fit = issubclass(unwrap_maybe(arg_type), pd.type)
256 arg_value = Maybe(arg_value)
258 if fit:
259 prepared_kwargs[pd.name] = arg_value
260 args_idx += 1
261 continue
263 match pd.optional:
264 case ParamOpt.NORMAL:
265 if pd.name in kwargs:
266 prepared_kwargs[pd.name] = kwargs[pd.name]
267 elif pd.value is not None:
268 prepared_kwargs[pd.name] = pd.value
269 case ParamOpt.OPTIONAL:
270 prepared_kwargs[pd.name] = kwargs.get(pd.name, pd.value)
271 case ParamOpt.MAYBE: 271 ↛ 243line 271 didn't jump to line 243 because the pattern on line 271 always matched
272 if pd.name in kwargs:
273 prepared_kwargs[pd.name] = Maybe(kwargs[pd.name])
274 else:
275 prepared_kwargs[pd.name] = Maybe(pd.value)
277 return prepared_kwargs
279 @classmethod
280 def from_syncfn(
281 cls,
282 fn: Callable[..., T],
283 cache_policy: CachePolicy | None = None,
284 hints: Mapping[str, Any] | None = None,
285 ) -> Factory[T]:
286 if not hints:
287 hints = get_type_hints(strip_partial(fn), localns={'T': cls.__wrapped_type__})
288 return_type = hints.pop('return', cls.__wrapped_type__)
289 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {}
290 return _SyncFnFactory[return_type](fn, hints, **kwargs)
292 @classmethod
293 def from_asyncfn(
294 cls,
295 fn: Callable[..., Awaitable[T]],
296 cache_policy: CachePolicy | None = None,
297 hints: Mapping[str, Any] | None = None,
298 ) -> Factory[T]:
299 if not hints:
300 hints = get_type_hints(strip_partial(fn), localns={'T': cls.__wrapped_type__})
301 return_type = hints.pop('return', cls.__wrapped_type__)
302 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {}
303 return _AsyncFnFactory[return_type](fn, hints, **kwargs)
305 @classmethod
306 def from_class(
307 cls,
308 cls_: type[T],
309 cache_policy: CachePolicy | None = None,
310 hints: Mapping[str, Any] | None = None,
311 ) -> Factory[T]:
312 if not hints:
313 hints = get_type_hints(cls_.__init__, localns={'T': cls.__wrapped_type__})
314 hints.pop('return', None)
315 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {}
316 return _ClassFactory[cls_](cls_, hints, **kwargs)
318 @classmethod
319 def from_instance(cls, instance: T) -> Factory[T]:
320 instance_type = type(instance)
321 return _InstanceFactory[instance_type](instance)
323 @classmethod
324 def from_syncgen(
325 cls,
326 gen_fn: Callable[..., Generator[T, T, None]],
327 cache_policy: CachePolicy | None = None,
328 hints: Mapping[str, Any] | None = None,
329 ) -> Factory[T]:
330 if not hints:
331 hints = get_type_hints(strip_partial(gen_fn), localns={'T': cls.__wrapped_type__})
332 return_type = hints.pop('return', cls.__wrapped_type__)
333 if hasattr(return_type, '__origin__') and hasattr(return_type, '__args__'): 333 ↛ 337line 333 didn't jump to line 337 because the condition on line 333 was always true
334 match return_type.__origin__.__name__:
335 case 'Generator' | 'Iterator' | 'Iterable': 335 ↛ 337line 335 didn't jump to line 337 because the pattern on line 335 always matched
336 return_type = return_type.__args__[0]
337 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {}
338 return _SyncGenFactory[return_type](gen_fn, hints, **kwargs)
340 @classmethod
341 def from_asyncgen(
342 cls,
343 gen_fn: Callable[..., AsyncGenerator[T, T]],
344 cache_policy: CachePolicy | None = None,
345 hints: Mapping[str, Any] | None = None,
346 ) -> Factory[T]:
347 if not hints:
348 hints = get_type_hints(strip_partial(gen_fn), localns={'T': cls.__wrapped_type__})
349 return_type = hints.pop('return', cls.__wrapped_type__)
350 if hasattr(return_type, '__origin__') and hasattr(return_type, '__args__'): 350 ↛ 354line 350 didn't jump to line 354 because the condition on line 350 was always true
351 match return_type.__origin__.__name__:
352 case 'AsyncGenerator' | 'AsyncIterator' | 'AsyncIterable': 352 ↛ 354line 352 didn't jump to line 354 because the pattern on line 352 always matched
353 return_type = return_type.__args__[0]
354 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {}
355 return _AsyncGenFactory[return_type](gen_fn, hints, **kwargs)
358def into_factory[T](
359 fn: FactoryFn[T],
360 return_hint: type[T] | None = None,
361 cache_policy: CachePolicy | None = None,
362) -> Factory[T]:
363 orig_fn = strip_partial(fn)
364 return_hint = return_hint if return_hint is not None else type
366 if callable(fn):
367 unwrapped_fn = inspect.unwrap(orig_fn)
369 if inspect.isgeneratorfunction(unwrapped_fn):
370 return Factory[return_hint].from_syncgen(fn, cache_policy)
371 if inspect.isasyncgenfunction(unwrapped_fn):
372 return Factory[return_hint].from_asyncgen(fn, cache_policy)
373 if inspect.isfunction(orig_fn) or inspect.ismethod(orig_fn):
374 if is_awaitable(orig_fn):
375 return Factory[return_hint].from_asyncfn(fn, cache_policy)
376 return Factory[return_hint].from_syncfn(fn, cache_policy)
377 if inspect.isclass(orig_fn):
378 return Factory[return_hint].from_class(fn, cache_policy)
379 return Factory[return_hint].from_instance(fn)
382def strip_partial[T](fn: FactoryFn[T]) -> FactoryFn[T]:
383 orig_fn = fn
384 while isinstance(orig_fn, partial):
385 orig_fn = orig_fn.func
386 return orig_fn
389class _SyncFnFactory[T](Factory[T]):
390 __slots__ = ('_fn',)
392 def __init__(
393 self,
394 fn: Callable[..., T],
395 params: Mapping[str, type],
396 args: Mapping[str, Any] | None = None,
397 cache_policy: CachePolicy = CachePolicy.GRAPH,
398 ) -> None:
399 super().__init__(params, args, cache_policy)
400 self._fn = fn
402 @override
403 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T:
404 prepared_kwargs = self._prepare_kwargs(*args, **kwargs)
405 return self._fn(**prepared_kwargs)
407 @override
408 def destroy(self, instance: T) -> None:
409 del instance
411 @override
412 def clone(self) -> Factory[T]:
413 instance = type(self)(self._fn, {})
414 self._clone_into(instance)
415 return instance
417 @override
418 def __eq__(self, other: Factory[T]) -> bool:
419 return isinstance(other, _SyncFnFactory) and self._fn == other._fn
421 @override
422 def __hash__(self) -> int:
423 return hash(self._fn)
426class _AsyncFnFactory[T](Factory[T]):
427 __slots__ = ('_fn',)
429 is_async = True
431 def __init__(
432 self,
433 fn: Callable[..., Awaitable[T]],
434 params: Mapping[str, type],
435 args: Mapping[str, Any] | None = None,
436 cache_policy: CachePolicy = CachePolicy.GRAPH,
437 ) -> None:
438 super().__init__(params, args, cache_policy)
439 self._fn = fn
441 @override
442 async def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T:
443 prepared_kwargs = self._prepare_kwargs(*args, **kwargs)
444 return await self._fn(**prepared_kwargs)
446 @override
447 async def destroy(self, instance: T) -> None:
448 del instance
450 @override
451 async def __call__[**P](self, *args: P.args, **kwargs: P.kwargs) -> T:
452 return await self.create(*args, **kwargs)
454 @override
455 def clone(self) -> Factory[T]:
456 instance = type(self)(self._fn, {})
457 self._clone_into(instance)
458 return instance
460 @override
461 def __eq__(self, other: Factory[T]) -> bool:
462 return isinstance(other, _AsyncFnFactory) and self._fn == other._fn
464 @override
465 def __hash__(self) -> int:
466 return hash(self._fn)
469class _ClassFactory[T](Factory[T]):
470 __slots__ = ('_cls',)
472 def __init__(
473 self,
474 cls: type[T],
475 params: Mapping[str, type],
476 args: Mapping[str, Any] | None = None,
477 cache_policy: CachePolicy = CachePolicy.FRAGMENT,
478 ) -> None:
479 super().__init__(params, args, cache_policy)
480 self._cls = cls
482 @override
483 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T:
484 prepared_kwargs = self._prepare_kwargs(*args, **kwargs)
485 return self._cls(**prepared_kwargs)
487 @override
488 def destroy(self, instance: T) -> None:
489 del instance
491 @override
492 def clone(self) -> Factory[T]:
493 instance = type(self)(self._cls, {})
494 self._clone_into(instance)
495 return instance
497 @override
498 def __eq__(self, other: Factory[T]) -> bool:
499 return isinstance(other, _ClassFactory) and self._cls == other._cls
501 @override
502 def __hash__(self) -> int:
503 return hash(self._cls)
506class _InstanceFactory[T](Factory[T]):
507 __slots__ = ('_instance',)
509 def __init__(self, instance: T) -> None:
510 super().__init__({}, None, CachePolicy.ANY)
511 self._instance = instance
513 @override
514 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T:
515 return self._instance
517 @override
518 def destroy(self, instance: T) -> None: ...
520 @override
521 def clone(self) -> Factory[T]:
522 instance = type(self)(self._instance)
523 self._clone_into(instance)
524 return instance
526 @override
527 def __eq__(self, other: Factory[T]) -> bool:
528 return isinstance(other, _InstanceFactory) and self._instance == other._instance
530 @override
531 def __hash__(self) -> int:
532 return hash(self._instance)
535class _SyncGenFactory[T](Factory[T]):
536 __slots__ = ('_gen_fn', '_gens')
538 scoped = True
540 def __init__(
541 self,
542 generator_fn: Callable[..., Generator[T, T, None]],
543 params: Mapping[str, type],
544 args: Mapping[str, Any] | None = None,
545 cache_policy: CachePolicy = CachePolicy.GRAPH,
546 ) -> None:
547 super().__init__(params, args, cache_policy)
548 self._gen_fn = generator_fn
549 self._gens: dict[int, Generator[T, T, None]] = {}
551 @override
552 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T:
553 prepared_kwargs = self._prepare_kwargs(*args, **kwargs)
554 gen = self._gen_fn(**prepared_kwargs)
556 try:
557 instance = next(gen)
559 except StopIteration as e:
560 instance = e.value
562 self._gens[id(instance)] = gen
563 return instance
565 @override
566 def destroy(self, instance: T) -> None:
567 key = id(instance)
568 if not (gen := self._gens.pop(key, None)):
569 return
571 with contextlib.suppress(StopIteration):
572 gen.send(instance)
574 @override
575 def clone(self) -> Factory[T]:
576 instance = type(self)(self._gen_fn, {})
577 self._clone_into(instance)
578 return instance
580 @override
581 def __eq__(self, other: Factory[T]) -> bool:
582 return isinstance(other, _SyncGenFactory) and self._gen_fn == other._gen_fn
584 @override
585 def __hash__(self) -> int:
586 return hash(self._gen_fn)
589class _AsyncGenFactory[T](Factory[T]):
590 __slots__ = ('_gen_fn', '_gens')
592 scoped = True
593 is_async = True
595 def __init__(
596 self,
597 generator_fn: Callable[..., AsyncGenerator[T, T]],
598 params: Mapping[str, type],
599 args: Mapping[str, Any] | None = None,
600 cache_policy: CachePolicy = CachePolicy.GRAPH,
601 ) -> None:
602 super().__init__(params, args, cache_policy)
603 self._gen_fn = generator_fn
604 self._gens: dict[int, AsyncGenerator[T, T]] = {}
606 @override
607 async def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T:
608 prepared_kwargs = self._prepare_kwargs(*args, **kwargs)
609 gen = self._gen_fn(**prepared_kwargs)
611 try:
612 instance = await anext(gen)
614 except StopAsyncIteration:
615 instance = None
617 self._gens[id(instance)] = gen
618 return instance
620 @override
621 async def destroy(self, instance: T) -> None:
622 key = id(instance)
623 if not (gen := self._gens.pop(key, None)):
624 return
626 with contextlib.suppress(StopAsyncIteration):
627 await gen.asend(instance)
629 @override
630 async def __call__[**P](self, *args: P.args, **kwargs: P.kwargs) -> T:
631 return await self.create(*args, **kwargs)
633 @override
634 def clone(self) -> Factory[T]:
635 instance = type(self)(self._gen_fn, {})
636 self._clone_into(instance)
637 return instance
639 @override
640 def __eq__(self, other: Factory[T]) -> bool:
641 return isinstance(other, _AsyncGenFactory) and self._gen_fn == other._gen_fn
643 @override
644 def __hash__(self) -> int:
645 return hash(self._gen_fn)