Coverage for src/endow/runtime.py: 100%
121 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 13:36 +0200
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-15 13:36 +0200
1"""Runtime graph construction and wiring helpers."""
3from __future__ import annotations
5import inspect
6import sys
7import typing as t
8import warnings
10from .base import Domain, Injectable, Service
# Sentinel meaning "no value found". It is compared by identity (``is MISSING``)
# so that legitimate values such as ``None`` can still be returned by lookups.
MISSING = object()
15class Graph:
16 """Resolve injectables and runtime values into a shared object graph."""
18 def __init__(self, runtime_inputs: dict[str, t.Any], strict: bool | None = None) -> None:
19 """Store the root runtime inputs and initialize the instance cache."""
20 self.runtime_inputs = runtime_inputs
21 self.strict = strict
22 self.instances: dict[type[Injectable], Injectable] = {}
24 def build[T: Injectable](self, cls: type[T]) -> T:
25 """Build or reuse an injectable instance of the requested type."""
26 cached = self.instances.get(cls)
27 if cached is not None:
28 return t.cast(T, cached)
30 instance, local_inputs, type_inputs = self._make_instance(cls)
31 self.instances[cls] = instance
32 self.instances.setdefault(type(instance), instance)
33 self._wire_instance(instance, local_inputs=local_inputs, type_inputs=type_inputs)
34 return t.cast(T, instance)
36 def _make_instance[T: Injectable](
37 self, cls: type[T]
38 ) -> tuple[Injectable, dict[str, t.Any] | None, dict[str, t.Any] | None]:
39 factory = self._get_factory(cls)
40 if factory is None:
41 return cls(), None, None
43 factory_args, local_inputs, type_inputs = self._collect_factory_args(factory)
44 instance = factory(**factory_args)
45 if not isinstance(instance, Injectable):
46 msg = f"{cls.__name__}.from_env() must return an Injectable instance"
47 raise TypeError(msg)
48 return instance, local_inputs, type_inputs
50 def _wire_instance(
51 self,
52 instance: Injectable,
53 local_inputs: dict[str, t.Any] | None = None,
54 type_inputs: dict[str, t.Any] | None = None,
55 ) -> None:
56 owner_type = type(instance)
57 for name, annotation in iter_injected_fields(owner_type).items():
58 if inspect.getattr_static(instance, name, MISSING) is not MISSING:
59 continue
60 self._check_dependency_direction(owner_type, name, annotation)
61 setattr(instance, name, self._resolve(name, annotation, local_inputs, type_inputs))
63 def _check_dependency_direction(self, owner_type: type[Injectable], name: str, annotation: t.Any) -> None:
64 if self.strict is None:
65 return
67 if not (
68 inspect.isclass(owner_type)
69 and issubclass(owner_type, Service)
70 and inspect.isclass(annotation)
71 and issubclass(annotation, Domain)
72 ):
73 return
75 msg = f"Service '{owner_type.__name__}' should not depend on Domain '{annotation.__name__}' via field '{name}'"
76 if self.strict:
77 raise TypeError(msg)
78 warnings.warn(msg, stacklevel=4)
80 def _resolve(
81 self,
82 name: str,
83 annotation: t.Any,
84 local_inputs: dict[str, t.Any] | None = None,
85 type_inputs: dict[str, t.Any] | None = None,
86 ) -> t.Any:
87 if inspect.isclass(annotation) and issubclass(annotation, Injectable):
88 return self.build(annotation)
90 runtime_value = self._match_runtime_input(name, annotation, local_inputs, type_inputs)
91 if runtime_value is not MISSING:
92 return runtime_value
94 if inspect.isclass(annotation):
95 msg = f"Missing runtime input for field '{name}'"
96 raise TypeError(msg)
98 msg = f"Cannot resolve field '{name}' with annotation {annotation!r}"
99 raise TypeError(msg)
101 def _collect_factory_args(self, factory: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
102 args: dict[str, t.Any] = {}
103 local_inputs: dict[str, t.Any] = {}
104 type_inputs: dict[str, t.Any] = {}
105 module = sys.modules[factory.__module__]
106 type_hints = t.get_type_hints(factory, globalns=vars(module))
108 for parameter in inspect.signature(factory).parameters.values():
109 if parameter.kind not in (
110 inspect.Parameter.POSITIONAL_OR_KEYWORD,
111 inspect.Parameter.KEYWORD_ONLY,
112 ):
113 continue
115 annotation = type_hints.get(parameter.name, parameter.annotation)
116 value = self._match_runtime_input(parameter.name, annotation)
117 if value is MISSING:
118 if parameter.default is not inspect._empty:
119 local_inputs[parameter.name] = parameter.default
120 continue
121 msg = f"Missing runtime input '{parameter.name}' for from_env()"
122 raise TypeError(msg)
123 args[parameter.name] = value
124 local_inputs[parameter.name] = value
125 type_inputs[parameter.name] = value
127 return args, local_inputs, type_inputs
129 def _match_runtime_input(
130 self,
131 name: str,
132 annotation: t.Any,
133 local_inputs: dict[str, t.Any] | None = None,
134 type_inputs: dict[str, t.Any] | None = None,
135 ) -> t.Any:
136 for scope in (local_inputs, self.runtime_inputs):
137 if not scope:
138 continue
139 if name in scope:
140 return scope[name]
142 normalized_name = name.lstrip("_")
143 if normalized_name and normalized_name in scope:
144 return scope[normalized_name]
146 if annotation is inspect._empty:
147 return MISSING
149 for scope in (self.runtime_inputs, type_inputs):
150 if not scope:
151 continue
152 for candidate in scope.values():
153 if matches_annotation(candidate, annotation):
154 return candidate
156 return MISSING
158 @staticmethod
159 def _get_factory(cls: type[Injectable]) -> t.Any | None:
160 from .backend import BackendBase
162 for base in cls.__mro__:
163 descriptor = base.__dict__.get("from_env")
164 if descriptor is None:
165 continue
166 if base is BackendBase:
167 continue
168 return descriptor.__get__(None, cls)
169 return None
172def build_graph[T: Injectable](
173 cls: type[T],
174 runtime_inputs: dict[str, t.Any],
175 strict: bool | None = None,
176) -> T:
177 """Build a dependency graph rooted at ``cls``."""
178 graph = Graph(runtime_inputs, strict=strict)
179 return graph.build(cls)
def iter_injected_fields(cls: type[Injectable]) -> dict[str, t.Any]:
    """Return annotated injected fields declared on ``cls`` and its bases.

    Walks ``cls.__mro__`` from the most basic class up to ``cls`` and, for
    every proper subclass of ``Injectable``, merges its resolved type hints so
    that hints gathered for more derived classes override earlier ones.

    Args:
        cls: Class whose annotated fields should be collected.

    Returns:
        Mapping of field name to resolved annotation. Empty when no class in
        the MRO is a proper subclass of ``Injectable``.

    Raises:
        NameError: If an annotation cannot be resolved in either namespace
            tried below.
    """
    fields: dict[str, t.Any] = {}

    for base in reversed(cls.__mro__):
        if not issubclass(base, Injectable) or base is Injectable:
            continue

        module = sys.modules[base.__module__]
        try:
            hints = t.get_type_hints(base, globalns=vars(module))
        except NameError:
            # ``get_type_hints`` evaluates the annotations of every class in
            # ``base.__mro__`` against the single namespace it is given. A
            # string annotation inherited from a class defined in another
            # module can therefore be unresolvable from ``base``'s module.
            # Retry without ``globalns`` so each class is resolved against the
            # module that defines it.
            hints = t.get_type_hints(base)
        fields.update(hints)

    return fields
def matches_annotation(value: t.Any, annotation: t.Any) -> bool:
    """Return whether ``value`` matches a concrete class annotation.

    Args:
        value: Candidate runtime value.
        annotation: Annotation to test against; only classes can match.

    Returns:
        ``True`` when ``annotation`` is a class and ``value`` is an instance of
        it. ``False`` for non-class annotations and for classes that do not
        support instance checks.
    """
    if not inspect.isclass(annotation):
        return False
    try:
        return isinstance(value, annotation)
    except TypeError:
        # Some classes reject ``isinstance``: ``typing.Any`` (a class since
        # Python 3.11) and Protocols that are not ``@runtime_checkable``.
        # They are not concrete class annotations, so report "no match"
        # rather than letting an unrelated TypeError escape.
        return False