Coverage for src/incant/__init__.py: 99%

192 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2023-06-27 02:31 +0200

1from functools import lru_cache 

2from inspect import Parameter, Signature, iscoroutinefunction 

3from typing import ( 

4 Any, 

5 Awaitable, 

6 Callable, 

7 Dict, 

8 List, 

9 Optional, 

10 Sequence, 

11 Set, 

12 Tuple, 

13 Type, 

14 TypeVar, 

15 Union, 

16) 

17 

18from attr import Factory, define, field, frozen 

19 

20from ._codegen import ( 

21 CtxManagerKind, 

22 Invocation, 

23 ParameterDep, 

24 compile_incant_wrapper, 

25 compile_invoke, 

26) 

27from ._compat import NO_OVERRIDE, Override, get_annotated_override, signature 

28 

29 

# Public API of the package.
__all__ = [
    "NO_OVERRIDE",
    "Override",
    "Hook",
    "Incanter",
    "IncantError",
]

# Keep a handle on the builtin `type`: `Incanter.register_by_type` takes a
# parameter named `type`, which shadows the builtin inside that method.
_type = type

# Return type of the functions being prepared and invoked.
R = TypeVar("R")

36 

37 

@frozen
class FactoryDep:
    """A dependency that is satisfied by calling another factory."""

    # The function to call to produce the value.
    factory: Callable
    # Name of the parameter this dependency fulfills.
    arg_name: str
    # Whether the factory's result is a context manager, and of which kind;
    # None means it is not treated as one.
    is_ctx_manager: Optional[CtxManagerKind] = field(default=None)


# A dependency is either produced by a factory or left as a parameter of the
# generated function (to be supplied by the caller).
Dep = Union[FactoryDep, ParameterDep]


# Decides whether a hook applies to a given parameter.
PredicateFn = Callable[[Parameter], bool]

50 

51 

def is_subclass(type, superclass) -> bool:
    """Return ``issubclass(type, superclass)``, or False if that call raises.

    Unlike the builtin, this never raises: values that are not classes (such
    as arbitrary annotations) are simply reported as "not a subclass".
    """
    try:
        outcome = issubclass(type, superclass)
    except Exception:
        outcome = False
    return outcome

58 

59 

@frozen
class Hook:
    """A rule deciding how matching parameters get their values.

    ``predicate`` selects the parameters this hook handles. ``factory`` is
    either None, in which case a matched parameter stays a plain parameter,
    or a pair of (function mapping the parameter to the callable that
    provides its value, context-manager kind of that callable).
    """

    predicate: PredicateFn
    factory: Optional[
        Tuple[Callable[[Parameter], Callable], Optional[CtxManagerKind]]
    ]

    @classmethod
    def for_name(cls, name: str, hook: Optional[Callable]):
        """Build a hook matching parameters whose name is exactly `name`."""

        def _matches(param):
            return param.name == name

        if hook is None:
            return cls(_matches, None)
        return cls(_matches, (lambda _param: hook, None))  # type: ignore

    @classmethod
    def for_type(cls, type: Any, hook: Optional[Callable]):
        """Register by exact type (subclasses won't match)."""

        def _matches(param):
            return param.annotation == type

        if hook is None:
            return cls(_matches, None)
        return cls(_matches, (lambda _param: hook, None))  # type: ignore

76 

77 

@define
class Incanter:
    """Invoke functions, injecting their dependencies.

    Factories are registered through the ``register_*`` methods and matched
    against function parameters. ``prepare``/``invoke``/``ainvoke`` generate
    wrappers that call those factories and pass the results in.
    ``incant``/``aincant`` call a function with whichever of the supplied
    arguments fit its parameters (by annotated type and by name).
    """

    # Hooks are consulted in order and the first match wins. New hooks are
    # inserted at the front, so later registrations take precedence.
    hook_factory_registry: List[Hook] = Factory(list)
    # Per-instance memoization of generated wrappers. Both caches are cleared
    # whenever a hook is registered, because that can change the result.
    _invoke_cache: Callable = field(
        init=False,
        default=Factory(
            lambda self: lru_cache(None)(self._gen_invoke), takes_self=True
        ),
    )
    _incant_cache: Callable = field(
        init=False,
        default=Factory(
            lambda self: lru_cache(None)(self._gen_incant), takes_self=True
        ),
    )

    def prepare(
        self,
        fn: Callable[..., R],
        hooks: Sequence[Hook] = (),
        is_async: Optional[bool] = None,
        forced_deps: Sequence[Union[Callable, Tuple[Callable, CtxManagerKind]]] = (),
    ) -> Callable[..., R]:
        """Prepare a new function, encapsulating satisfied dependencies.

        hooks: Extra hooks, consulted before the registered ones.
        is_async: Whether the result should be a coroutine function; None
            autodetects it from the dependencies.
        forced_deps: A sequence of dependencies that will be called even if `fn` doesn't require them explicitly.
        """
        # Arguments are normalized to hashable tuples because they are the
        # lru_cache key.
        return self._invoke_cache(
            fn,
            tuple(hooks),
            is_async,
            tuple(f if isinstance(f, tuple) else (f, None) for f in forced_deps),
        )

    def invoke(self, fn: Callable[..., R], *args, **kwargs) -> R:
        """Prepare `fn` synchronously and call it with the given arguments."""
        return self.prepare(fn, is_async=False)(*args, **kwargs)

    async def ainvoke(self, fn: Callable[..., Awaitable[R]], *args, **kwargs) -> R:
        """Prepare `fn` as a coroutine function, call it, and await the result."""
        return await self.prepare(fn, is_async=True)(*args, **kwargs)

    def incant(self, fn: Callable[..., R], *args, **kwargs) -> R:
        """Invoke `fn` the best way we can."""
        return self._incant(fn, args, kwargs)

    async def aincant(self, fn: Callable[..., Awaitable[R]], *args, **kwargs) -> R:
        """Invoke async `fn` the best way we can."""
        return await self._incant(fn, args, kwargs, is_async=True)

    def register_by_name(
        self,
        fn: Optional[Callable] = None,
        *,
        name: Optional[str] = None,
        is_ctx_manager: Optional[CtxManagerKind] = None,
    ):
        """
        Register a factory to be injected by name. Can also be used as a decorator.

        If the name is not provided, the name of the factory will be used.
        """
        if fn is None:
            # Decorator
            return lambda fn: self.register_by_name(
                fn, name=name, is_ctx_manager=is_ctx_manager
            )

        if name is None:
            name = fn.__name__
        self.register_hook(lambda p: p.name == name, fn, is_ctx_manager=is_ctx_manager)
        return fn

    def register_by_type(
        self,
        fn: Union[Callable, Type],
        type: Optional[Type] = None,
        is_ctx_manager: Optional[CtxManagerKind] = None,
    ):
        """
        Register a factory to be injected by type. Can also be used as a
        bare decorator (``@incanter.register_by_type``).

        If the type is not provided, the return annotation from the
        factory will be used (or `fn` itself, when `fn` is a class).
        Parameters annotated with that type or a subclass of it match.

        Raises:
            IncantError: if no type is given and none can be determined.
        """
        if type is None:
            # `type` is shadowed by the parameter, hence the `_type` alias.
            if isinstance(fn, _type):
                type_to_reg = fn
            else:
                sig = signature(fn)
                type_to_reg = sig.return_annotation
                if type_to_reg is Signature.empty:
                    raise IncantError("No return type found, provide a type.")
        else:
            type_to_reg = type
        self.register_hook(
            lambda p: is_subclass(p.annotation, type_to_reg), fn, is_ctx_manager
        )
        return fn

    def register_hook(
        self,
        predicate: PredicateFn,
        factory: Callable,
        is_ctx_manager: Optional[CtxManagerKind] = None,
    ) -> None:
        """Register `factory` for every parameter matching `predicate`."""
        self.register_hook_factory(predicate, lambda _: factory, is_ctx_manager)

    def register_hook_factory(
        self,
        predicate: PredicateFn,
        hook_factory: Callable[[Parameter], Callable],
        is_ctx_manager: Optional[CtxManagerKind] = None,
    ) -> None:
        """Register a hook that builds the factory from the matched parameter.

        The hook takes precedence over earlier registrations, and all cached
        wrappers are discarded.
        """
        self.hook_factory_registry.insert(
            0, Hook(predicate, (hook_factory, is_ctx_manager))
        )
        self._invoke_cache.cache_clear()  # type: ignore
        self._incant_cache.cache_clear()  # type: ignore

    def _incant(
        self,
        fn: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        is_async: bool = False,
    ):
        """The shared entrypoint for ``incant`` and ``aincant``."""
        # Wrappers are cached per (argument classes, keyword names + classes).
        pos_args_types = tuple([a.__class__ for a in args])
        kwargs_by_name_and_type = frozenset(
            [(k, v.__class__) for k, v in kwargs.items()]
        )
        wrapper = self._incant_cache(
            fn, pos_args_types, kwargs_by_name_and_type, is_async
        )

        return wrapper(*args, **kwargs)

    def _gen_incant_plan(
        self, fn, pos_args_types: Tuple[Any, ...], kwargs: Set[Tuple[str, Any]]
    ) -> List[Union[int, str]]:
        """Generate a plan to invoke `fn`, potentially using `args` and `kwargs`.

        Each entry is either the index of a positional argument or the name of
        a keyword argument. For every parameter, the lookup order is:
        1. a keyword argument with the same name and exactly the annotated type;
        2. the first positional argument whose class is a subclass of the
           annotation;
        3. a keyword argument with the same name.
        A parameter that cannot be fulfilled is skipped if it has a default.

        Raises:
            TypeError: if a parameter without a default cannot be fulfilled.
        """
        pos_arg_plan: List[Union[int, str]] = []
        kwarg_names = {kw[0] for kw in kwargs}
        sig = signature(fn)
        for arg_name, arg in sig.parameters.items():
            found = False
            if (
                arg.annotation is not Signature.empty
                and (arg_name, arg.annotation) in kwargs
            ):
                pos_arg_plan.append(arg_name)
                found = True
            if found:
                continue

            if arg.annotation is not Signature.empty:
                for ix, a in enumerate(pos_args_types):
                    if is_subclass(a, arg.annotation):
                        pos_arg_plan.append(ix)
                        found = True
                        break
            if found:
                continue

            if arg_name in kwarg_names:
                pos_arg_plan.append(arg_name)
            elif arg.default is not Signature.empty:
                # An argument with a default we cannot fulfil is ok.
                continue
            else:
                raise TypeError(f"Cannot fulfil argument {arg_name}")
        return pos_arg_plan

    def _gen_incant(
        self,
        fn: Callable,
        pos_args_types: Tuple,
        kwargs_by_name_and_type: Set,
        is_async: Optional[bool] = False,
    ) -> Callable:
        """Compile the wrapper used by ``incant`` (memoized by `_incant_cache`).

        `is_async` is only part of the cache key; it is not used here.
        """
        plan = self._gen_incant_plan(fn, pos_args_types, kwargs_by_name_and_type)
        return compile_incant_wrapper(
            fn, plan, len(pos_args_types), len(kwargs_by_name_and_type)
        )

    def _gen_dep_tree(
        self,
        fn: Callable,
        additional_hooks: Sequence[Hook],
        forced_deps: Sequence[Tuple[Callable, Optional[CtxManagerKind]]] = (),
    ) -> List[Tuple[Callable, Optional[CtxManagerKind], List[Dep]]]:
        """Generate the dependency tree for `fn`.

        The dependency tree is a list of factories and their dependencies,
        ordered so that every factory comes after the factories it depends on.

        The actual function is the last item.

        Raises:
            IncantError: if factories depend on each other in a cycle.
        """
        to_process = [(fn, None), *forced_deps]
        final_nodes: List[Tuple[Callable, Optional[CtxManagerKind], List[Dep]]] = []
        hooks = list(additional_hooks) + self.hook_factory_registry
        # NOTE(review): forced deps are not added to this set, so a forced dep
        # that is also reached through a hook is processed twice and ends up
        # in the tree twice. Confirm whether compile_invoke tolerates that.
        already_processed_hooks = set()
        while to_process:
            # Breadth-first: each pass processes the nodes queued by the
            # previous one.
            _nodes = to_process
            to_process = []
            for node, ctx_mgr_kind in _nodes:
                sig = _signature(node)
                dependents: List[Union[ParameterDep, FactoryDep]] = []
                for name, param in sig.parameters.items():
                    if (
                        node is not fn
                        and param.default is not Signature.empty
                        and param.kind is Parameter.KEYWORD_ONLY
                    ):
                        # Do not expose optional kw-only params of dependencies.
                        continue
                    param_type = param.annotation
                    for hook in hooks:
                        if hook.predicate(param):
                            # Match!
                            if hook.factory is None:
                                dependents.append(
                                    ParameterDep(name, param_type, param.default)
                                )
                            else:
                                factory = hook.factory[0](param)
                                if factory == node:
                                    # A hook cannot satisfy itself.
                                    continue
                                if factory not in already_processed_hooks:
                                    to_process.append((factory, hook.factory[1]))
                                    already_processed_hooks.add(factory)
                                dependents.append(
                                    FactoryDep(factory, name, hook.factory[1])
                                )

                            break
                    else:
                        dependents.append(ParameterDep(name, param_type, param.default))
                final_nodes.insert(0, (node, ctx_mgr_kind, dependents))

        # Dependencies must be invoked before the nodes that consume them,
        # otherwise the generated function references unbound local vars.
        # Sorting by dependency count alone does not guarantee that (e.g. a
        # factory with one factory dep vs. its dependency with two parameter
        # deps). Keep the historical count-based order as a tie-breaker and
        # apply a stable topological sort on top of it.
        dep_nodes = sorted(final_nodes[:-1], key=lambda n: len(n[2]))
        dep_nodes = self._order_dependencies_first(dep_nodes)
        dep_nodes.append(final_nodes[-1])
        return dep_nodes

    @staticmethod
    def _order_dependencies_first(
        nodes: List[Tuple[Callable, Optional[CtxManagerKind], List[Dep]]],
    ) -> List[Tuple[Callable, Optional[CtxManagerKind], List[Dep]]]:
        """Stable topological sort placing each node after its factory deps.

        Among the nodes whose factory dependencies have all been placed, the
        earliest one in `nodes` is picked next, so an ordering that is already
        valid is returned unchanged.

        Raises:
            IncantError: if the remaining nodes depend on each other in a cycle.
        """
        pending = list(nodes)
        # Factories that have not been placed yet.
        unresolved = {node[0] for node in pending}
        ordered = []
        while pending:
            for ix, node in enumerate(pending):
                if not any(
                    isinstance(dep, FactoryDep) and dep.factory in unresolved
                    for dep in node[2]
                ):
                    break
            else:
                raise IncantError(
                    f"Dependency cycle between factories: {[n[0] for n in pending]}"
                )
            ordered.append(pending.pop(ix))
            unresolved.discard(node[0])
        return ordered

    def _gen_invoke(
        self,
        fn: Callable,
        hooks: Tuple[Hook, ...] = (),
        is_async: Optional[bool] = False,
        forced_deps: Tuple[Tuple[Callable, CtxManagerKind], ...] = (),
    ):
        """Build the wrapper returned by `prepare` (memoized by `_invoke_cache`).

        Returns `fn` unchanged when it has no factory dependencies and needs
        no sync/async adaptation.

        Raises:
            TypeError: if a synchronous wrapper is requested but a dependency
                is a coroutine function or an async context manager.
            IncantError: if the same outer argument is declared with
                irreconcilable types.
        """
        dep_tree = self._gen_dep_tree(fn, hooks, forced_deps)
        if len(dep_tree) == 1 and (
            is_async is None or (is_async is iscoroutinefunction(fn))
        ):
            # Nothing we can do for this function.
            return fn

        # is_async = None means autodetect
        if is_async is None:
            is_async = any(
                iscoroutinefunction(factory) or ctx_mgr_kind == "async"
                for factory, ctx_mgr_kind, _ in dep_tree
            )

        invocs: List[Invocation] = []
        # All non-parameter deps become invocations.
        for ix, (factory, ctx_mgr_kind, deps) in enumerate(dep_tree[:-1]):
            if not is_async and (
                iscoroutinefunction(factory) or ctx_mgr_kind == "async"
            ):
                raise TypeError(
                    f"The function would be a coroutine because of {factory}, use `ainvoke` instead"
                )

            # It's possible this is a forced dependency, and nothing downstream actually needs it.
            # In that case, we mark it as forced so it doesn't get its own local var in the generated function.
            is_needed = False
            for _, _, downstream_deps in dep_tree[ix + 1 :]:
                if any(
                    isinstance(d, FactoryDep) and d.factory == factory
                    for d in downstream_deps
                ):
                    is_needed = True
                    break

            invocs.append(
                Invocation(
                    factory,
                    [
                        dep.factory if isinstance(dep, FactoryDep) else dep
                        for dep in deps
                    ],
                    not is_needed,
                    ctx_mgr_kind,
                )
            )

        # Every ParameterDep anywhere in the tree becomes an argument of the
        # generated function.
        outer_args = [
            dep for node in dep_tree for dep in node[2] if isinstance(dep, ParameterDep)
        ]
        # We need to do a pass over the outer args to consolidate duplicates.
        per_outer_arg: Dict[str, List[ParameterDep]] = {}
        for arg in outer_args:
            per_outer_arg.setdefault(arg.arg_name, []).append(arg)

        outer_args.clear()
        for arg_name, args in per_outer_arg.items():
            if len(args) == 1:
                arg_type = args[0].type
                arg_default = args[0].default
            else:
                # If there are multiple competing argument defs,
                # we need to pick a winning type.
                arg_type = Signature.empty
                arg_default = Signature.empty
                for arg in args:
                    try:
                        arg_type = _reconcile_types(arg_type, arg.type)
                    except Exception as exc:
                        raise IncantError(
                            f"Unable to reconcile types {arg_type} and {arg.type} for argument {arg_name}"
                        ) from exc
                    # The last non-empty default wins.
                    if arg.default is not Signature.empty:
                        arg_default = arg.default
            outer_args.append(ParameterDep(arg_name, arg_type, arg_default))

        # outer_args need to be sorted by the presence of a default value
        outer_args.sort(key=lambda a: a.default is not Signature.empty)

        # Factories feeding `fn` directly, paired with the parameter names.
        fn_factory_args = []
        fn_factories = []
        for dep in dep_tree[-1][2]:
            if isinstance(dep, FactoryDep):
                fn_factory_args.append(dep.arg_name)
                fn_factories.append(dep.factory)

        return compile_invoke(
            fn,
            fn_factories,
            fn_factory_args,
            outer_args,
            invocs,
            is_async=is_async,
        )

425 

426 

def _reconcile_types(type_a, type_b):
    """Merge two annotations declared for the same outer argument.

    A missing annotation (``Signature.empty``) defers to the other one;
    otherwise both must be the very same object.

    Raises:
        Exception: if both annotations are present and are not identical.
    """
    if type_a is Signature.empty:
        return type_b
    if type_b is Signature.empty or type_b is type_a:
        return type_a
    raise Exception(f"Unable to reconcile types {type_a!r} and {type_b!r}")

435 

436 

def _signature(f: Callable) -> Signature:
    """Return `f`'s signature with `get_annotated_override` applied to each parameter."""
    base = signature(f)
    overridden = list(map(get_annotated_override, base.parameters.values()))
    return base.replace(parameters=overridden)

442 

443 

class IncantError(Exception):
    """Error raised by incant, e.g. when a factory's type cannot be determined
    or when argument types conflict."""