Coverage for src/endow/runtime.py: 100%
119 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-09 19:17 +0200
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-09 19:17 +0200
1"""Runtime graph construction and wiring helpers."""
3from __future__ import annotations
5import inspect
6import sys
7import typing as t
8import warnings
10from .base import Domain, Injectable, Service
# Sentinel meaning "no value found". Compared by identity (``is MISSING``) so
# that ``None`` and other falsy values remain usable as real runtime inputs.
MISSING: t.Final = object()
15class Graph:
16 """Resolve injectables and runtime values into a shared object graph."""
18 def __init__(self, runtime_inputs: dict[str, t.Any], strict: bool | None = None) -> None:
19 """Store the root runtime inputs and initialize the instance cache."""
20 self.runtime_inputs = runtime_inputs
21 self.strict = strict
22 self.instances: dict[type[Injectable], Injectable] = {}
24 def build[T: Injectable](self, cls: type[T]) -> T:
25 """Build or reuse an injectable instance of the requested type."""
26 cached = self.instances.get(cls)
27 if cached is not None:
28 return t.cast(T, cached)
30 instance, local_inputs, type_inputs = self._make_instance(cls)
31 self.instances[cls] = instance
32 self.instances.setdefault(type(instance), instance)
33 self._wire_instance(instance, local_inputs=local_inputs, type_inputs=type_inputs)
34 return t.cast(T, instance)
36 def _make_instance[T: Injectable](
37 self, cls: type[T]
38 ) -> tuple[Injectable, dict[str, t.Any] | None, dict[str, t.Any] | None]:
39 factory = self._get_factory(cls)
40 if factory is None:
41 return cls(), None, None
43 factory_args, local_inputs, type_inputs = self._collect_factory_args(factory)
44 instance = factory(**factory_args)
45 if not isinstance(instance, Injectable):
46 msg = f"{cls.__name__}.from_env() must return an Injectable instance"
47 raise TypeError(msg)
48 return instance, local_inputs, type_inputs
50 def _wire_instance(
51 self,
52 instance: Injectable,
53 local_inputs: dict[str, t.Any] | None = None,
54 type_inputs: dict[str, t.Any] | None = None,
55 ) -> None:
56 owner_type = type(instance)
57 for name, annotation in iter_injected_fields(owner_type).items():
58 self._check_dependency_direction(owner_type, name, annotation)
59 setattr(instance, name, self._resolve(name, annotation, local_inputs, type_inputs))
61 def _check_dependency_direction(self, owner_type: type[Injectable], name: str, annotation: t.Any) -> None:
62 if self.strict is None:
63 return
65 if not (
66 inspect.isclass(owner_type)
67 and issubclass(owner_type, Service)
68 and inspect.isclass(annotation)
69 and issubclass(annotation, Domain)
70 ):
71 return
73 msg = f"Service '{owner_type.__name__}' should not depend on Domain '{annotation.__name__}' via field '{name}'"
74 if self.strict:
75 raise TypeError(msg)
76 warnings.warn(msg, stacklevel=4)
78 def _resolve(
79 self,
80 name: str,
81 annotation: t.Any,
82 local_inputs: dict[str, t.Any] | None = None,
83 type_inputs: dict[str, t.Any] | None = None,
84 ) -> t.Any:
85 if inspect.isclass(annotation) and issubclass(annotation, Injectable):
86 return self.build(annotation)
88 runtime_value = self._match_runtime_input(name, annotation, local_inputs, type_inputs)
89 if runtime_value is not MISSING:
90 return runtime_value
92 if inspect.isclass(annotation):
93 msg = f"Missing runtime input for field '{name}'"
94 raise TypeError(msg)
96 msg = f"Cannot resolve field '{name}' with annotation {annotation!r}"
97 raise TypeError(msg)
99 def _collect_factory_args(self, factory: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
100 args: dict[str, t.Any] = {}
101 local_inputs: dict[str, t.Any] = {}
102 type_inputs: dict[str, t.Any] = {}
103 module = sys.modules[factory.__module__]
104 type_hints = t.get_type_hints(factory, globalns=vars(module))
106 for parameter in inspect.signature(factory).parameters.values():
107 if parameter.kind not in (
108 inspect.Parameter.POSITIONAL_OR_KEYWORD,
109 inspect.Parameter.KEYWORD_ONLY,
110 ):
111 continue
113 annotation = type_hints.get(parameter.name, parameter.annotation)
114 value = self._match_runtime_input(parameter.name, annotation)
115 if value is MISSING:
116 if parameter.default is not inspect._empty:
117 local_inputs[parameter.name] = parameter.default
118 continue
119 msg = f"Missing runtime input '{parameter.name}' for from_env()"
120 raise TypeError(msg)
121 args[parameter.name] = value
122 local_inputs[parameter.name] = value
123 type_inputs[parameter.name] = value
125 return args, local_inputs, type_inputs
127 def _match_runtime_input(
128 self,
129 name: str,
130 annotation: t.Any,
131 local_inputs: dict[str, t.Any] | None = None,
132 type_inputs: dict[str, t.Any] | None = None,
133 ) -> t.Any:
134 for scope in (local_inputs, self.runtime_inputs):
135 if not scope:
136 continue
137 if name in scope:
138 return scope[name]
140 normalized_name = name.lstrip("_")
141 if normalized_name and normalized_name in scope:
142 return scope[normalized_name]
144 if annotation is inspect._empty:
145 return MISSING
147 for scope in (self.runtime_inputs, type_inputs):
148 if not scope:
149 continue
150 for candidate in scope.values():
151 if matches_annotation(candidate, annotation):
152 return candidate
154 return MISSING
156 @staticmethod
157 def _get_factory(cls: type[Injectable]) -> t.Any | None:
158 from .backend import BackendBase
160 for base in cls.__mro__:
161 descriptor = base.__dict__.get("from_env")
162 if descriptor is None:
163 continue
164 if base is BackendBase:
165 continue
166 return descriptor.__get__(None, cls)
167 return None
170def build_graph[T: Injectable](
171 cls: type[T],
172 runtime_inputs: dict[str, t.Any],
173 strict: bool | None = None,
174) -> T:
175 """Build a dependency graph rooted at ``cls``."""
176 graph = Graph(runtime_inputs, strict=strict)
177 return graph.build(cls)
def iter_injected_fields(cls: type[Injectable]) -> dict[str, t.Any]:
    """Return annotated injected fields declared on ``cls`` and its bases.

    Walks ``cls.__mro__`` from the most basic class to ``cls`` and merges the
    resolved type hints of every ``Injectable`` subclass found (``Injectable``
    itself is not visited directly), so later classes override earlier ones.

    Args:
        cls: The class whose injected fields are wanted.

    Returns:
        A mapping of field name to resolved annotation.
    """
    fields: dict[str, t.Any] = {}

    for base in reversed(cls.__mro__):
        if not issubclass(base, Injectable) or base is Injectable:
            continue

        # Deliberately no ``globalns``: get_type_hints() itself merges the
        # annotations along ``base.__mro__`` and, by default, evaluates each
        # class's annotations in that class's own module. Supplying a single
        # module's globals would make inherited string annotations from a base
        # defined in another module resolve in the wrong namespace.
        fields.update(t.get_type_hints(base))

    return fields
def matches_annotation(value: t.Any, annotation: t.Any) -> bool:
    """Return whether ``value`` matches a concrete class annotation.

    Args:
        value: The candidate object.
        annotation: The annotation to test against.

    Returns:
        ``True`` only when ``annotation`` is a class and ``value`` is an
        instance of it. Annotations that are not classes never match.
    """
    return inspect.isclass(annotation) and isinstance(value, annotation)