Coverage for greyhorse/app/private/functional/operators.py: 79%

227 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2026-05-12 22:07 +0300

1from __future__ import annotations 

2 

3import contextlib 

4import enum 

5from collections.abc import Collection, Generator, Iterable 

6from dataclasses import dataclass 

7from enum import IntEnum, auto 

8from typing import Any, override 

9 

10from greyhorse.app.abc.functional.context import ( 

11 OperationContext, 

12 OperatorContext as OperatorContextABC, 

13) 

14from greyhorse.app.abc.functional.operators import Operator 

15from greyhorse.app.private.resolving.schemas import ( 

16 PlanAction, 

17 PlanResolveError, 

18 ResolvedData, 

19 ResolveResult, 

20) 

21from greyhorse.factory import Factory, into_factory 

22from greyhorse.result import Err, Ok, Result 

23from greyhorse.utils.types import is_maybe, is_optional, unwrap 

24 

25from ..fragment import Fragment 

26from ..resolving import Resolver, ValueResolver, _type_to_param_name 

27from ..runtime.invoke import invoke_sync 

28from .linker import FragmentLinker, Linkage 

29 

30 

31# --------------------------------------------------------------------------- 

32# Legacy ContextOperator — used by component.py (removal deferred) 

33# --------------------------------------------------------------------------- 

34 

35 

36class ContextOperator[T]: 

37 __slots__ = ( 

38 '_compiled', 

39 '_external_params', 

40 '_functor', 

41 '_resolved_types', 

42 '_resolver', 

43 '_scope', 

44 ) 

45 

46 def __init__( 

47 self, 

48 functor: Factory[T], 

49 resolver: Resolver, 

50 scope: IntEnum | None = None, 

51 external_params: Collection[type] | None = None, 

52 ) -> None: 

53 self._resolver = resolver 

54 self._scope = scope 

55 self._external_params: set[type] = set(external_params or []) 

56 self._resolved_types: dict[type, ResolvedData] = { 

57 functor.return_type: ResolvedData( 

58 type=functor.return_type, factory=functor, deps=dict(functor.actual_params) 

59 ) 

60 } 

61 self._compiled = False 

62 self._functor = functor 

63 

64 @property 

65 def return_type(self) -> type[T]: 

66 return self._functor.return_type 

67 

68 @property 

69 def compiled(self) -> bool: 

70 return self._compiled 

71 

72 @classmethod 

73 def from_context( 

74 cls, 

75 context: OperationContext, 

76 functor: Factory[T], 

77 scope: IntEnum | None, 

78 external_params: Collection[type] | None = None, 

79 ) -> ContextOperator[T]: 

80 resolver = context.context_resolver(scope) 

81 return cls(functor, resolver, scope, external_params) 

82 

83 def compile(self) -> Any: 

84 if self.compiled: 

85 return ResolveResult(resolved=self._resolved_types) 

86 

87 res = self._resolver.resolve_factories(self.return_type, self._scope) 

88 self._resolved_types.update(res.resolved) 

89 self._compiled = not res.unresolved 

90 return res 

91 

92 def get_functor(self) -> Factory[T]: 

93 self.compile() 

94 external_params = { 

95 param_name: param_type 

96 for param_name, param_type in self._functor.actual_params.items() 

97 if param_type in self._external_params or not self._resolver.can_resolve(param_type) 

98 } 

99 

100 if len(external_params) >= len(self._functor.actual_params): 

101 return self._functor 

102 

103 return Factory[self.return_type].from_syncgen( 

104 self._entrypoint, 

105 cache_policy=self._functor.cache_policy, 

106 hints={ 

107 'return': Generator[self.return_type, self.return_type, None], 

108 **external_params, 

109 }, 

110 ) 

111 

112 def _entrypoint(self, *args: Any, **kwargs: Any) -> Generator[T, T, None]: 

113 resolved = self._resolved_types.copy() 

114 with self._resolver as resolver: 

115 for param_idx, (param_name, param_type) in enumerate( 

116 self._functor.actual_params.items() 

117 ): 

118 if param_idx < len(args) and issubclass(type(args[param_idx]), param_type): 

119 resolved[param_type] = ResolvedData( 

120 type=param_type, factory=into_factory(args[param_idx]) 

121 ) 

122 continue 

123 if param_name in kwargs and param_type in self._external_params: 

124 resolved[param_type] = ResolvedData( 

125 type=param_type, factory=into_factory(kwargs[param_name]) 

126 ) 

127 continue 

128 instance = resolver.resolve_value(resolved, self.return_type, self._scope) 

129 assert instance.is_just() 

130 yield instance.unwrap() 

131 

132 

133_LIFECYCLE_METHODS = frozenset({'setup', 'resume', 'pause', 'teardown', '__call__'}) 

134 

135 

136class _State(enum.Enum): 

137 IDLE = auto() 

138 ACTIVE = auto() 

139 RESUMED = auto() 

140 

141 

142@dataclass(slots=True, kw_only=True) 

143class _ParamLayout: 

144 """Pre-computed operator __init__ parameter classification.""" 

145 

146 plain_type_map: dict[str, type] # name → unwrapped type 

147 required_types: list[type] 

148 optional_types: list[type] 

149 

150 

def _validate_operator_class(cls: type[Operator]) -> None:
    """Ensure ``cls`` defines at least one lifecycle method of its own.

    The MRO is walked from ``cls`` up to (but excluding) ``Operator``; the
    check passes as soon as one class on the way defines any name listed in
    ``_LIFECYCLE_METHODS`` directly in its own ``__dict__``.

    Raises:
        ValueError: no class below ``Operator`` defines a lifecycle method.
    """
    for ancestor in cls.__mro__:
        if ancestor is Operator:
            break
        # Only names defined on the class itself count, not inherited ones.
        if not _LIFECYCLE_METHODS.isdisjoint(ancestor.__dict__):
            return

    expected = ', '.join(sorted(_LIFECYCLE_METHODS))
    raise ValueError(f'{cls.__qualname__} must override at least one of: ' + expected)

161 

162 

def _build_param_layout(param_info: dict[str, type]) -> _ParamLayout:
    """Classify constructor parameters into required and optional types.

    Args:
        param_info: Mapping of parameter name to its declared (raw) type.

    Returns:
        A ``_ParamLayout`` holding the unwrapped type of every parameter and
        the unwrapped types split by whether the raw type is Maybe/Optional.
    """
    names_to_plain: dict[str, type] = {}
    required: list[type] = []
    optional: list[type] = []

    for param_name, annotation in param_info.items():
        plain = unwrap(annotation)
        names_to_plain[param_name] = plain
        is_soft = is_maybe(annotation) or is_optional(annotation)
        (optional if is_soft else required).append(plain)

    return _ParamLayout(
        plain_type_map=names_to_plain,
        required_types=required,
        optional_types=optional,
    )

179 

180 

def _check_name_collisions(linkage: Linkage) -> Result[None, PlanResolveError]:
    """Detect two distinct provider types that map to the same parameter name.

    Every ``ProviderCall`` action of every resolved plan is inspected; the
    first name already claimed by a different type yields a ``NameCollision``
    error result. ``Ok(None)`` is returned when no collision exists.
    """
    claimed: dict[str, type] = {}
    for _, plan in linkage.resolved.values():
        for action in plan.actions:
            if not isinstance(action, PlanAction.ProviderCall):
                continue
            target = action.target_type
            param_name = _type_to_param_name(target)
            # setdefault records the first claimant and hands back whoever owns the name.
            owner = claimed.setdefault(param_name, target)
            if owner is not target:
                return PlanResolveError.NameCollision(
                    name=param_name,
                    type_a=owner.__qualname__,
                    type_b=target.__qualname__,
                ).to_result()
    return Ok(None)

195 

196 

197def _resolve_instances(linkage: Linkage, external: dict[str, Any]) -> dict[type, Any]: 

198 resolvers: dict[Fragment, ValueResolver] = {} 

199 instances: dict[type, Any] = {} 

200 try: 

201 # Open resolvers in topological order (dependencies first) so that 

202 # reversed(resolvers) during teardown closes dependents before dependencies. 

203 for t in linkage.order: 

204 frag, _ = linkage.resolved[t] 

205 if frag not in resolvers: 

206 resolver = ValueResolver(frag._bucket) # noqa: SLF001 

207 resolver.__enter__() 

208 resolvers[frag] = resolver 

209 

210 for t in linkage.order: 

211 frag, plan = linkage.resolved[t] 

212 resolver = resolvers[frag] 

213 

214 kwargs: dict[str, Any] = {} 

215 for action in plan.actions: 

216 if isinstance(action, PlanAction.ProviderCall): 

217 dep_name = _type_to_param_name(action.target_type) 

218 if action.target_type in instances: 

219 kwargs[dep_name] = instances[action.target_type] 

220 elif dep_name in external: 

221 kwargs[dep_name] = external[dep_name] 

222 

223 match resolver.resolve(plan, t, **kwargs): 

224 case Ok(value): 

225 instances[t] = value 

226 case Err(err): 

227 raise RuntimeError(f'Failed to resolve {t.__qualname__}: {err.message}') 

228 

229 except Exception: 

230 for resolver in reversed(list(resolvers.values())): 

231 with contextlib.suppress(Exception): 

232 resolver.__exit__(None, None, None) 

233 raise 

234 

235 return instances, resolvers # type: ignore[return-value] 

236 

237 

238class OperatorContext(OperatorContextABC): 

239 __slots__ = ( 

240 '_instances', 

241 '_linker', 

242 '_operator', 

243 '_operator_factory', 

244 '_param_info', 

245 '_resolvers', 

246 '_scope', 

247 '_state', 

248 ) 

249 

250 def __init__( 

251 self, 

252 operator_class: type[Operator], 

253 fragments: Iterable[Fragment], 

254 scope: IntEnum | None = None, 

255 ) -> None: 

256 _validate_operator_class(operator_class) 

257 self._operator_factory: Factory[Operator] = Factory.from_class(operator_class) 

258 self._param_info: dict[str, type] = dict(self._operator_factory.actual_params) 

259 self._linker = FragmentLinker(fragments) 

260 self._scope = scope 

261 self._resolvers: dict[Fragment, ValueResolver] = {} 

262 self._instances: dict[type, Any] = {} 

263 self._operator: Operator | None = None 

264 self._state = _State.IDLE 

265 

266 @property 

267 @override 

268 def operator(self) -> Operator: 

269 if self._operator is None: 

270 raise RuntimeError('not set up') 

271 return self._operator 

272 

273 @override 

274 def setup(self, **external: Any) -> None: 

275 if self._state != _State.IDLE: 

276 raise RuntimeError('already set up') 

277 

278 layout = _build_param_layout(self._param_info) 

279 

280 # 1. Link dependencies across fragments 

281 match self._linker.link(layout.required_types, layout.optional_types, self._scope): 

282 case Err(err): 

283 raise RuntimeError(err.message) 

284 case Ok(linkage): 

285 pass 

286 

287 # 2. Check pending: required types not in any fragment must come from external 

288 if linkage.pending: 

289 unsatisfied = [ 

290 pc 

291 for pc in linkage.pending 

292 if _type_to_param_name(pc.target_type) not in external 

293 ] 

294 if unsatisfied: 

295 types = ', '.join(pc.target_type.__qualname__ for pc in unsatisfied) 

296 raise RuntimeError(f'Unresolved providers: {types}') 

297 

298 # 3. Detect provider name collisions early (before touching buckets) 

299 match _check_name_collisions(linkage): 

300 case Err(err): 

301 raise RuntimeError(err.message) 

302 case _: 

303 pass 

304 

305 # 4-5. Create one resolver per fragment and resolve in topological order 

306 instances, resolvers = _resolve_instances(linkage, external) 

307 

308 # 6. Build operator constructor kwargs 

309 op_kwargs: dict[str, Any] = {} 

310 for name in self._param_info: 

311 plain = layout.plain_type_map[name] 

312 if plain in instances: 

313 op_kwargs[name] = instances[plain] 

314 else: 

315 dep_name = _type_to_param_name(plain) 

316 if dep_name in external: 

317 op_kwargs[name] = external[dep_name] 

318 # optional param absent → _prepare_kwargs fills Nothing / None 

319 

320 # 7. Instantiate operator and run its setup 

321 self._operator = None 

322 try: 

323 self._operator = invoke_sync(self._operator_factory.create, **op_kwargs) 

324 invoke_sync(self._operator.setup) 

325 except Exception: 

326 if self._operator is not None: 

327 with contextlib.suppress(Exception): 

328 invoke_sync(self._operator.teardown) 

329 for resolver in reversed(list(resolvers.values())): 

330 with contextlib.suppress(Exception): 

331 resolver.__exit__(None, None, None) 

332 raise 

333 

334 # 8. Commit — only reached on success 

335 self._resolvers = resolvers 

336 self._instances = instances 

337 self._state = _State.ACTIVE 

338 

339 @override 

340 def resume(self) -> None: 

341 if self._state != _State.ACTIVE: 

342 raise RuntimeError('not active') 

343 invoke_sync(self._operator.resume) 

344 self._state = _State.RESUMED 

345 

346 @override 

347 def pause(self) -> None: 

348 if self._state != _State.RESUMED: 

349 raise RuntimeError('not resumed') 

350 invoke_sync(self._operator.pause) 

351 self._state = _State.ACTIVE 

352 

353 def call(self, *args: Any, **kwargs: Any) -> Any: 

354 if self._state != _State.RESUMED: 

355 raise RuntimeError('not resumed') 

356 return invoke_sync(self._operator.__call__, *args, **kwargs) 

357 

358 @override 

359 def teardown(self) -> None: 

360 if self._state == _State.IDLE: 

361 return 

362 

363 if self._state == _State.RESUMED: 

364 with contextlib.suppress(Exception): 

365 invoke_sync(self._operator.pause) 

366 

367 with contextlib.suppress(Exception): 

368 invoke_sync(self._operator.teardown) 

369 

370 for resolver in reversed(list(self._resolvers.values())): 

371 with contextlib.suppress(Exception): 

372 resolver.__exit__(None, None, None) 

373 

374 self._operator = None 

375 self._resolvers = {} 

376 self._instances = {} 

377 self._state = _State.IDLE 

378 

379 

# Names exported by ``from ... import *``; only OperatorContext is listed.
__all__ = ['OperatorContext']