Coverage for greyhorse/factory.py: 98%

417 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2026-05-11 15:54 +0300

1from __future__ import annotations 

2 

3import builtins 

4import contextlib 

5import enum 

6import inspect 

7from abc import ABC, abstractmethod 

8from collections.abc import AsyncGenerator, Awaitable, Callable, Generator, Mapping 

9from dataclasses import dataclass, replace 

10from functools import partial, reduce 

11from typing import Any, get_type_hints, override 

12 

13from greyhorse.maybe import Maybe 

14from greyhorse.utils.types import ( 

15 TypeWrapper, 

16 is_awaitable, 

17 is_maybe, 

18 is_optional, 

19 unwrap_maybe, 

20 unwrap_optional, 

21) 

22 

23 

24type FactoryFn[T] = ( 

25 Callable[..., T] 

26 | Callable[..., Awaitable[T]] 

27 | type[T] 

28 | T 

29 | Callable[..., Generator[T, T, None]] 

30 | Callable[..., AsyncGenerator[T, T]] 

31) 

32 

33 

class CachePolicy(enum.IntEnum):
    """Caching policy carried by a factory.

    Values are explicit integers because ``Factory.cacheable`` treats any value
    above zero as cacheable, so ``GRAPH`` is the only non-cacheable policy.
    The precise scope each policy implies is decided by the consumer of the
    factory. NOTE(review): confirm the intended semantics of each member.
    """

    GRAPH = 0
    FRAGMENT = 1
    ANY = 2

38 

39 

class ParamOpt(int, enum.Enum):
    """Which wrapper, if any, was stripped from a parameter annotation."""

    NORMAL = 0  # bare ``T``
    OPTIONAL = 1  # ``T | None``
    MAYBE = 2  # ``Maybe[T]``

44 

45 

@dataclass(slots=True, frozen=True, kw_only=True)
class ParamData:
    """Immutable description of one factory parameter.

    Fields:
    - ``type``: the annotation with any ``| None`` / ``Maybe[...]`` wrapper removed.
    - ``optional``: which of those wrappers was removed.
    - ``value``: a pre-bound argument. ``None`` means "not bound".
    """

    # The field name shadows the builtin inside the class, hence ``builtins.type``
    # in the ``raw_type`` annotation below.
    type: type
    name: str
    optional: ParamOpt
    value: Any | None = None

    @property
    def is_required(self) -> bool:
        """True when the parameter was annotated with neither ``| None`` nor ``Maybe``."""
        return self.optional == ParamOpt.NORMAL

    @property
    def raw_type(self) -> builtins.type:
        """Rebuild the annotation as written, re-applying the stripped wrapper."""
        wrapper = self.optional
        if wrapper == ParamOpt.NORMAL:
            return self.type
        if wrapper == ParamOpt.OPTIONAL:
            return self.type | None
        if wrapper == ParamOpt.MAYBE:
            return Maybe[self.type]
        return None

    @classmethod
    def from_kv(cls, k: str, v: type, value: Any | None = None) -> ParamData:
        """Build a ``ParamData`` from a ``name -> annotation`` pair.

        ``Optional`` is detected before ``Maybe``. The stored type has both
        wrappers peeled off.
        """
        if is_optional(v):
            kind = ParamOpt.OPTIONAL
        elif is_maybe(v):
            kind = ParamOpt.MAYBE
        else:
            kind = ParamOpt.NORMAL
        bare = unwrap_maybe(unwrap_optional(v))
        return cls(type=bare, name=k, optional=kind, value=value)

77 

78 

79class Factory[T](TypeWrapper[T], ABC): 

80 __slots__ = ('_cache_policy', '_name_map', '_params', '_type_map') 

81 

82 scoped: bool = False 

83 is_async: bool = False 

84 

85 def __init__( 

86 self, 

87 params: Mapping[str, type], 

88 args: Mapping[str, Any] | None = None, 

89 cache_policy: CachePolicy = CachePolicy.GRAPH, 

90 ) -> None: 

91 args = args or {} 

92 params_data = [ParamData.from_kv(k, v, args.get(k)) for k, v in params.items()] 

93 self._params: list[ParamData] = params_data 

94 self._name_map = {pd.name: i for i, pd in enumerate(self._params)} 

95 self._type_map = {pd.type: i for i, pd in enumerate(self._params)} 

96 self._cache_policy = cache_policy 

97 

98 @property 

99 def cache_policy(self) -> CachePolicy: 

100 return self._cache_policy 

101 

102 @property 

103 def cacheable(self) -> bool: 

104 return self._cache_policy > 0 

105 

106 @abstractmethod 

107 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]: ... 

108 

109 @abstractmethod 

110 def destroy(self, instance: T) -> None | Awaitable[None]: ... 

111 

112 def __call__[**P](self, *args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]: 

113 return self.create(*args, **kwargs) 

114 

115 def clone(self) -> Factory[T]: 

116 instance = type(self)({}) 

117 self._clone_into(instance) 

118 return instance 

119 

120 @abstractmethod 

121 def __eq__(self, other: Factory[T]) -> bool: ... 

122 

123 @abstractmethod 

124 def __hash__(self) -> int: ... 

125 

126 @property 

127 def return_type(self) -> type[T]: 

128 return self.__wrapped_type__ 

129 

130 @property 

131 def params_count(self) -> int: 

132 return len(self._params) 

133 

134 # @property 

135 # def params_types(self) -> tuple[type, ...]: 

136 # return tuple(pd.raw_type for pd in self._params if pd.value is None) 

137 

138 @property 

139 def args_count(self) -> int: 

140 return reduce(lambda a, pd: a + (1 if pd.value is not None else 0), self._params, 0) 

141 

142 @property 

143 def all_params(self) -> Mapping[str, type]: 

144 res = {} 

145 for pd in self._params: 

146 res[pd.name] = pd.raw_type 

147 return res 

148 

149 @property 

150 def actual_params(self) -> Mapping[str, type]: 

151 res = {} 

152 for pd in self._params: 

153 if pd.value is None: 

154 res[pd.name] = pd.raw_type 

155 return res 

156 

157 def check_signature(self, signature: type[Callable]) -> bool: 

158 sig_ret_type = signature.__args__[-1] 

159 sig_params = signature.__args__[0 : len(signature.__args__) - 1] 

160 

161 if not issubclass(self.return_type, sig_ret_type): 

162 return False 

163 

164 sig_params_idx = 0 

165 

166 for pd in self._params: 

167 if sig_params_idx < len(sig_params): 

168 sig_param_type = sig_params[sig_params_idx] 

169 fit = False 

170 

171 match pd.optional: 

172 case ParamOpt.NORMAL: 

173 if inspect.isclass(sig_param_type): 

174 fit = issubclass(sig_param_type, pd.type) 

175 else: 

176 fit = sig_param_type is pd.type 

177 case ParamOpt.OPTIONAL: 

178 fit = issubclass(unwrap_optional(sig_param_type), pd.type) 

179 case ParamOpt.MAYBE: 179 ↛ 182line 179 didn't jump to line 182 because the pattern on line 179 always matched

180 fit = issubclass(unwrap_maybe(sig_param_type), pd.type) 

181 

182 if fit: 

183 sig_params_idx += 1 

184 continue 

185 

186 if pd.value is None and pd.optional == ParamOpt.NORMAL: 

187 return False 

188 

189 return sig_params_idx == len(sig_params) 

190 

191 def has_named_param(self, name: str) -> bool: 

192 return name in self._name_map 

193 

194 def has_typed_param[U](self, key: type[U]) -> bool: 

195 return key in self._type_map 

196 

197 def add_named_arg[U](self, name: str, value: U) -> bool: 

198 if name not in self._name_map: 

199 return False 

200 idx = self._name_map[name] 

201 if self._params[idx].value is value: 

202 return False 

203 # if self._params[idx].optional == _ParamOpt.MAYBE: 

204 # value = Maybe(value) 

205 self._params[idx] = replace(self._params[idx], value=value) 

206 return True 

207 

208 def remove_named_arg(self, name: str) -> bool: 

209 if name not in self._name_map: 

210 return False 

211 idx = self._name_map[name] 

212 self._params[idx] = replace(self._params[idx], value=None) 

213 return True 

214 

215 def add_typed_arg[U](self, key: type[U], value: U) -> bool: 

216 if key not in self._type_map: 

217 return False 

218 idx = self._type_map[key] 

219 if self._params[idx].value is value: 

220 return False 

221 # if self._params[idx].optional == _ParamOpt.MAYBE: 

222 # value = Maybe(value) 

223 self._params[idx] = replace(self._params[idx], value=value) 

224 return True 

225 

226 def remove_typed_arg[U](self, key: type[U]) -> bool: 

227 if key not in self._type_map: 

228 return False 

229 idx = self._type_map[key] 

230 self._params[idx] = replace(self._params[idx], value=None) 

231 return True 

232 

233 def _clone_into(self, other: Factory[T]) -> None: 

234 other._params = self._params.copy() # noqa: SLF001 

235 other._name_map = self._name_map.copy() # noqa: SLF001 

236 other._type_map = self._type_map.copy() # noqa: SLF001 

237 other._cache_policy = self._cache_policy # noqa: SLF001 

238 

239 def _prepare_kwargs[**P](self, *args: P.args, **kwargs: P.kwargs) -> dict[str, Any]: 

240 prepared_kwargs = {} 

241 args_idx = 0 

242 

243 for pd in self._params: 

244 if args_idx < len(args): 

245 arg_value = args[args_idx] 

246 arg_type = type(arg_value) 

247 fit = False 

248 

249 match pd.optional: 

250 case ParamOpt.NORMAL: 

251 fit = issubclass(arg_type, pd.type) 

252 case ParamOpt.OPTIONAL: 

253 fit = issubclass(unwrap_optional(arg_type), pd.type) 

254 case ParamOpt.MAYBE: 254 ↛ 258line 254 didn't jump to line 258 because the pattern on line 254 always matched

255 fit = issubclass(unwrap_maybe(arg_type), pd.type) 

256 arg_value = Maybe(arg_value) 

257 

258 if fit: 

259 prepared_kwargs[pd.name] = arg_value 

260 args_idx += 1 

261 continue 

262 

263 match pd.optional: 

264 case ParamOpt.NORMAL: 

265 if pd.name in kwargs: 

266 prepared_kwargs[pd.name] = kwargs[pd.name] 

267 elif pd.value is not None: 

268 prepared_kwargs[pd.name] = pd.value 

269 case ParamOpt.OPTIONAL: 

270 prepared_kwargs[pd.name] = kwargs.get(pd.name, pd.value) 

271 case ParamOpt.MAYBE: 271 ↛ 243line 271 didn't jump to line 243 because the pattern on line 271 always matched

272 if pd.name in kwargs: 

273 prepared_kwargs[pd.name] = Maybe(kwargs[pd.name]) 

274 else: 

275 prepared_kwargs[pd.name] = Maybe(pd.value) 

276 

277 return prepared_kwargs 

278 

279 @classmethod 

280 def from_syncfn( 

281 cls, 

282 fn: Callable[..., T], 

283 cache_policy: CachePolicy | None = None, 

284 hints: Mapping[str, Any] | None = None, 

285 ) -> Factory[T]: 

286 if not hints: 

287 hints = get_type_hints(strip_partial(fn), localns={'T': cls.__wrapped_type__}) 

288 return_type = hints.pop('return', cls.__wrapped_type__) 

289 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {} 

290 return _SyncFnFactory[return_type](fn, hints, **kwargs) 

291 

292 @classmethod 

293 def from_asyncfn( 

294 cls, 

295 fn: Callable[..., Awaitable[T]], 

296 cache_policy: CachePolicy | None = None, 

297 hints: Mapping[str, Any] | None = None, 

298 ) -> Factory[T]: 

299 if not hints: 

300 hints = get_type_hints(strip_partial(fn), localns={'T': cls.__wrapped_type__}) 

301 return_type = hints.pop('return', cls.__wrapped_type__) 

302 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {} 

303 return _AsyncFnFactory[return_type](fn, hints, **kwargs) 

304 

305 @classmethod 

306 def from_class( 

307 cls, 

308 cls_: type[T], 

309 cache_policy: CachePolicy | None = None, 

310 hints: Mapping[str, Any] | None = None, 

311 ) -> Factory[T]: 

312 if not hints: 

313 hints = get_type_hints(cls_.__init__, localns={'T': cls.__wrapped_type__}) 

314 hints.pop('return', None) 

315 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {} 

316 return _ClassFactory[cls_](cls_, hints, **kwargs) 

317 

318 @classmethod 

319 def from_instance(cls, instance: T) -> Factory[T]: 

320 instance_type = type(instance) 

321 return _InstanceFactory[instance_type](instance) 

322 

323 @classmethod 

324 def from_syncgen( 

325 cls, 

326 gen_fn: Callable[..., Generator[T, T, None]], 

327 cache_policy: CachePolicy | None = None, 

328 hints: Mapping[str, Any] | None = None, 

329 ) -> Factory[T]: 

330 if not hints: 

331 hints = get_type_hints(strip_partial(gen_fn), localns={'T': cls.__wrapped_type__}) 

332 return_type = hints.pop('return', cls.__wrapped_type__) 

333 if hasattr(return_type, '__origin__') and hasattr(return_type, '__args__'): 333 ↛ 337line 333 didn't jump to line 337 because the condition on line 333 was always true

334 match return_type.__origin__.__name__: 

335 case 'Generator' | 'Iterator' | 'Iterable': 335 ↛ 337line 335 didn't jump to line 337 because the pattern on line 335 always matched

336 return_type = return_type.__args__[0] 

337 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {} 

338 return _SyncGenFactory[return_type](gen_fn, hints, **kwargs) 

339 

340 @classmethod 

341 def from_asyncgen( 

342 cls, 

343 gen_fn: Callable[..., AsyncGenerator[T, T]], 

344 cache_policy: CachePolicy | None = None, 

345 hints: Mapping[str, Any] | None = None, 

346 ) -> Factory[T]: 

347 if not hints: 

348 hints = get_type_hints(strip_partial(gen_fn), localns={'T': cls.__wrapped_type__}) 

349 return_type = hints.pop('return', cls.__wrapped_type__) 

350 if hasattr(return_type, '__origin__') and hasattr(return_type, '__args__'): 350 ↛ 354line 350 didn't jump to line 354 because the condition on line 350 was always true

351 match return_type.__origin__.__name__: 

352 case 'AsyncGenerator' | 'AsyncIterator' | 'AsyncIterable': 352 ↛ 354line 352 didn't jump to line 354 because the pattern on line 352 always matched

353 return_type = return_type.__args__[0] 

354 kwargs = {'cache_policy': cache_policy} if cache_policy is not None else {} 

355 return _AsyncGenFactory[return_type](gen_fn, hints, **kwargs) 

356 

357 

358def into_factory[T]( 

359 fn: FactoryFn[T], 

360 return_hint: type[T] | None = None, 

361 cache_policy: CachePolicy | None = None, 

362) -> Factory[T]: 

363 orig_fn = strip_partial(fn) 

364 return_hint = return_hint if return_hint is not None else type 

365 

366 if callable(fn): 

367 unwrapped_fn = inspect.unwrap(orig_fn) 

368 

369 if inspect.isgeneratorfunction(unwrapped_fn): 

370 return Factory[return_hint].from_syncgen(fn, cache_policy) 

371 if inspect.isasyncgenfunction(unwrapped_fn): 

372 return Factory[return_hint].from_asyncgen(fn, cache_policy) 

373 if inspect.isfunction(orig_fn) or inspect.ismethod(orig_fn): 

374 if is_awaitable(orig_fn): 

375 return Factory[return_hint].from_asyncfn(fn, cache_policy) 

376 return Factory[return_hint].from_syncfn(fn, cache_policy) 

377 if inspect.isclass(orig_fn): 

378 return Factory[return_hint].from_class(fn, cache_policy) 

379 return Factory[return_hint].from_instance(fn) 

380 

381 

382def strip_partial[T](fn: FactoryFn[T]) -> FactoryFn[T]: 

383 orig_fn = fn 

384 while isinstance(orig_fn, partial): 

385 orig_fn = orig_fn.func 

386 return orig_fn 

387 

388 

389class _SyncFnFactory[T](Factory[T]): 

390 __slots__ = ('_fn',) 

391 

392 def __init__( 

393 self, 

394 fn: Callable[..., T], 

395 params: Mapping[str, type], 

396 args: Mapping[str, Any] | None = None, 

397 cache_policy: CachePolicy = CachePolicy.GRAPH, 

398 ) -> None: 

399 super().__init__(params, args, cache_policy) 

400 self._fn = fn 

401 

402 @override 

403 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T: 

404 prepared_kwargs = self._prepare_kwargs(*args, **kwargs) 

405 return self._fn(**prepared_kwargs) 

406 

407 @override 

408 def destroy(self, instance: T) -> None: 

409 del instance 

410 

411 @override 

412 def clone(self) -> Factory[T]: 

413 instance = type(self)(self._fn, {}) 

414 self._clone_into(instance) 

415 return instance 

416 

417 @override 

418 def __eq__(self, other: Factory[T]) -> bool: 

419 return isinstance(other, _SyncFnFactory) and self._fn == other._fn 

420 

421 @override 

422 def __hash__(self) -> int: 

423 return hash(self._fn) 

424 

425 

426class _AsyncFnFactory[T](Factory[T]): 

427 __slots__ = ('_fn',) 

428 

429 is_async = True 

430 

431 def __init__( 

432 self, 

433 fn: Callable[..., Awaitable[T]], 

434 params: Mapping[str, type], 

435 args: Mapping[str, Any] | None = None, 

436 cache_policy: CachePolicy = CachePolicy.GRAPH, 

437 ) -> None: 

438 super().__init__(params, args, cache_policy) 

439 self._fn = fn 

440 

441 @override 

442 async def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T: 

443 prepared_kwargs = self._prepare_kwargs(*args, **kwargs) 

444 return await self._fn(**prepared_kwargs) 

445 

446 @override 

447 async def destroy(self, instance: T) -> None: 

448 del instance 

449 

450 @override 

451 async def __call__[**P](self, *args: P.args, **kwargs: P.kwargs) -> T: 

452 return await self.create(*args, **kwargs) 

453 

454 @override 

455 def clone(self) -> Factory[T]: 

456 instance = type(self)(self._fn, {}) 

457 self._clone_into(instance) 

458 return instance 

459 

460 @override 

461 def __eq__(self, other: Factory[T]) -> bool: 

462 return isinstance(other, _AsyncFnFactory) and self._fn == other._fn 

463 

464 @override 

465 def __hash__(self) -> int: 

466 return hash(self._fn) 

467 

468 

469class _ClassFactory[T](Factory[T]): 

470 __slots__ = ('_cls',) 

471 

472 def __init__( 

473 self, 

474 cls: type[T], 

475 params: Mapping[str, type], 

476 args: Mapping[str, Any] | None = None, 

477 cache_policy: CachePolicy = CachePolicy.FRAGMENT, 

478 ) -> None: 

479 super().__init__(params, args, cache_policy) 

480 self._cls = cls 

481 

482 @override 

483 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T: 

484 prepared_kwargs = self._prepare_kwargs(*args, **kwargs) 

485 return self._cls(**prepared_kwargs) 

486 

487 @override 

488 def destroy(self, instance: T) -> None: 

489 del instance 

490 

491 @override 

492 def clone(self) -> Factory[T]: 

493 instance = type(self)(self._cls, {}) 

494 self._clone_into(instance) 

495 return instance 

496 

497 @override 

498 def __eq__(self, other: Factory[T]) -> bool: 

499 return isinstance(other, _ClassFactory) and self._cls == other._cls 

500 

501 @override 

502 def __hash__(self) -> int: 

503 return hash(self._cls) 

504 

505 

506class _InstanceFactory[T](Factory[T]): 

507 __slots__ = ('_instance',) 

508 

509 def __init__(self, instance: T) -> None: 

510 super().__init__({}, None, CachePolicy.ANY) 

511 self._instance = instance 

512 

513 @override 

514 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T: 

515 return self._instance 

516 

517 @override 

518 def destroy(self, instance: T) -> None: ... 

519 

520 @override 

521 def clone(self) -> Factory[T]: 

522 instance = type(self)(self._instance) 

523 self._clone_into(instance) 

524 return instance 

525 

526 @override 

527 def __eq__(self, other: Factory[T]) -> bool: 

528 return isinstance(other, _InstanceFactory) and self._instance == other._instance 

529 

530 @override 

531 def __hash__(self) -> int: 

532 return hash(self._instance) 

533 

534 

535class _SyncGenFactory[T](Factory[T]): 

536 __slots__ = ('_gen_fn', '_gens') 

537 

538 scoped = True 

539 

540 def __init__( 

541 self, 

542 generator_fn: Callable[..., Generator[T, T, None]], 

543 params: Mapping[str, type], 

544 args: Mapping[str, Any] | None = None, 

545 cache_policy: CachePolicy = CachePolicy.GRAPH, 

546 ) -> None: 

547 super().__init__(params, args, cache_policy) 

548 self._gen_fn = generator_fn 

549 self._gens: dict[int, Generator[T, T, None]] = {} 

550 

551 @override 

552 def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T: 

553 prepared_kwargs = self._prepare_kwargs(*args, **kwargs) 

554 gen = self._gen_fn(**prepared_kwargs) 

555 

556 try: 

557 instance = next(gen) 

558 

559 except StopIteration as e: 

560 instance = e.value 

561 

562 self._gens[id(instance)] = gen 

563 return instance 

564 

565 @override 

566 def destroy(self, instance: T) -> None: 

567 key = id(instance) 

568 if not (gen := self._gens.pop(key, None)): 

569 return 

570 

571 with contextlib.suppress(StopIteration): 

572 gen.send(instance) 

573 

574 @override 

575 def clone(self) -> Factory[T]: 

576 instance = type(self)(self._gen_fn, {}) 

577 self._clone_into(instance) 

578 return instance 

579 

580 @override 

581 def __eq__(self, other: Factory[T]) -> bool: 

582 return isinstance(other, _SyncGenFactory) and self._gen_fn == other._gen_fn 

583 

584 @override 

585 def __hash__(self) -> int: 

586 return hash(self._gen_fn) 

587 

588 

589class _AsyncGenFactory[T](Factory[T]): 

590 __slots__ = ('_gen_fn', '_gens') 

591 

592 scoped = True 

593 is_async = True 

594 

595 def __init__( 

596 self, 

597 generator_fn: Callable[..., AsyncGenerator[T, T]], 

598 params: Mapping[str, type], 

599 args: Mapping[str, Any] | None = None, 

600 cache_policy: CachePolicy = CachePolicy.GRAPH, 

601 ) -> None: 

602 super().__init__(params, args, cache_policy) 

603 self._gen_fn = generator_fn 

604 self._gens: dict[int, AsyncGenerator[T, T]] = {} 

605 

606 @override 

607 async def create[**P](self, *args: P.args, **kwargs: P.kwargs) -> T: 

608 prepared_kwargs = self._prepare_kwargs(*args, **kwargs) 

609 gen = self._gen_fn(**prepared_kwargs) 

610 

611 try: 

612 instance = await anext(gen) 

613 

614 except StopAsyncIteration: 

615 instance = None 

616 

617 self._gens[id(instance)] = gen 

618 return instance 

619 

620 @override 

621 async def destroy(self, instance: T) -> None: 

622 key = id(instance) 

623 if not (gen := self._gens.pop(key, None)): 

624 return 

625 

626 with contextlib.suppress(StopAsyncIteration): 

627 await gen.asend(instance) 

628 

629 @override 

630 async def __call__[**P](self, *args: P.args, **kwargs: P.kwargs) -> T: 

631 return await self.create(*args, **kwargs) 

632 

633 @override 

634 def clone(self) -> Factory[T]: 

635 instance = type(self)(self._gen_fn, {}) 

636 self._clone_into(instance) 

637 return instance 

638 

639 @override 

640 def __eq__(self, other: Factory[T]) -> bool: 

641 return isinstance(other, _AsyncGenFactory) and self._gen_fn == other._gen_fn 

642 

643 @override 

644 def __hash__(self) -> int: 

645 return hash(self._gen_fn)