Coverage for src/endow/runtime.py: 100%

121 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 13:36 +0200

1"""Runtime graph construction and wiring helpers.""" 

2 

3from __future__ import annotations 

4 

5import inspect 

6import sys 

7import typing as t 

8import warnings 

9 

10from .base import Domain, Injectable, Service 

11 

# Module-private "nothing found" sentinel. Lookups compare against it with
# ``is`` so that ``None`` and other falsy values remain legitimate inputs.
MISSING: t.Final = object()

13 

14 

15class Graph: 

16 """Resolve injectables and runtime values into a shared object graph.""" 

17 

18 def __init__(self, runtime_inputs: dict[str, t.Any], strict: bool | None = None) -> None: 

19 """Store the root runtime inputs and initialize the instance cache.""" 

20 self.runtime_inputs = runtime_inputs 

21 self.strict = strict 

22 self.instances: dict[type[Injectable], Injectable] = {} 

23 

24 def build[T: Injectable](self, cls: type[T]) -> T: 

25 """Build or reuse an injectable instance of the requested type.""" 

26 cached = self.instances.get(cls) 

27 if cached is not None: 

28 return t.cast(T, cached) 

29 

30 instance, local_inputs, type_inputs = self._make_instance(cls) 

31 self.instances[cls] = instance 

32 self.instances.setdefault(type(instance), instance) 

33 self._wire_instance(instance, local_inputs=local_inputs, type_inputs=type_inputs) 

34 return t.cast(T, instance) 

35 

36 def _make_instance[T: Injectable]( 

37 self, cls: type[T] 

38 ) -> tuple[Injectable, dict[str, t.Any] | None, dict[str, t.Any] | None]: 

39 factory = self._get_factory(cls) 

40 if factory is None: 

41 return cls(), None, None 

42 

43 factory_args, local_inputs, type_inputs = self._collect_factory_args(factory) 

44 instance = factory(**factory_args) 

45 if not isinstance(instance, Injectable): 

46 msg = f"{cls.__name__}.from_env() must return an Injectable instance" 

47 raise TypeError(msg) 

48 return instance, local_inputs, type_inputs 

49 

50 def _wire_instance( 

51 self, 

52 instance: Injectable, 

53 local_inputs: dict[str, t.Any] | None = None, 

54 type_inputs: dict[str, t.Any] | None = None, 

55 ) -> None: 

56 owner_type = type(instance) 

57 for name, annotation in iter_injected_fields(owner_type).items(): 

58 if inspect.getattr_static(instance, name, MISSING) is not MISSING: 

59 continue 

60 self._check_dependency_direction(owner_type, name, annotation) 

61 setattr(instance, name, self._resolve(name, annotation, local_inputs, type_inputs)) 

62 

63 def _check_dependency_direction(self, owner_type: type[Injectable], name: str, annotation: t.Any) -> None: 

64 if self.strict is None: 

65 return 

66 

67 if not ( 

68 inspect.isclass(owner_type) 

69 and issubclass(owner_type, Service) 

70 and inspect.isclass(annotation) 

71 and issubclass(annotation, Domain) 

72 ): 

73 return 

74 

75 msg = f"Service '{owner_type.__name__}' should not depend on Domain '{annotation.__name__}' via field '{name}'" 

76 if self.strict: 

77 raise TypeError(msg) 

78 warnings.warn(msg, stacklevel=4) 

79 

80 def _resolve( 

81 self, 

82 name: str, 

83 annotation: t.Any, 

84 local_inputs: dict[str, t.Any] | None = None, 

85 type_inputs: dict[str, t.Any] | None = None, 

86 ) -> t.Any: 

87 if inspect.isclass(annotation) and issubclass(annotation, Injectable): 

88 return self.build(annotation) 

89 

90 runtime_value = self._match_runtime_input(name, annotation, local_inputs, type_inputs) 

91 if runtime_value is not MISSING: 

92 return runtime_value 

93 

94 if inspect.isclass(annotation): 

95 msg = f"Missing runtime input for field '{name}'" 

96 raise TypeError(msg) 

97 

98 msg = f"Cannot resolve field '{name}' with annotation {annotation!r}" 

99 raise TypeError(msg) 

100 

101 def _collect_factory_args(self, factory: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]: 

102 args: dict[str, t.Any] = {} 

103 local_inputs: dict[str, t.Any] = {} 

104 type_inputs: dict[str, t.Any] = {} 

105 module = sys.modules[factory.__module__] 

106 type_hints = t.get_type_hints(factory, globalns=vars(module)) 

107 

108 for parameter in inspect.signature(factory).parameters.values(): 

109 if parameter.kind not in ( 

110 inspect.Parameter.POSITIONAL_OR_KEYWORD, 

111 inspect.Parameter.KEYWORD_ONLY, 

112 ): 

113 continue 

114 

115 annotation = type_hints.get(parameter.name, parameter.annotation) 

116 value = self._match_runtime_input(parameter.name, annotation) 

117 if value is MISSING: 

118 if parameter.default is not inspect._empty: 

119 local_inputs[parameter.name] = parameter.default 

120 continue 

121 msg = f"Missing runtime input '{parameter.name}' for from_env()" 

122 raise TypeError(msg) 

123 args[parameter.name] = value 

124 local_inputs[parameter.name] = value 

125 type_inputs[parameter.name] = value 

126 

127 return args, local_inputs, type_inputs 

128 

129 def _match_runtime_input( 

130 self, 

131 name: str, 

132 annotation: t.Any, 

133 local_inputs: dict[str, t.Any] | None = None, 

134 type_inputs: dict[str, t.Any] | None = None, 

135 ) -> t.Any: 

136 for scope in (local_inputs, self.runtime_inputs): 

137 if not scope: 

138 continue 

139 if name in scope: 

140 return scope[name] 

141 

142 normalized_name = name.lstrip("_") 

143 if normalized_name and normalized_name in scope: 

144 return scope[normalized_name] 

145 

146 if annotation is inspect._empty: 

147 return MISSING 

148 

149 for scope in (self.runtime_inputs, type_inputs): 

150 if not scope: 

151 continue 

152 for candidate in scope.values(): 

153 if matches_annotation(candidate, annotation): 

154 return candidate 

155 

156 return MISSING 

157 

158 @staticmethod 

159 def _get_factory(cls: type[Injectable]) -> t.Any | None: 

160 from .backend import BackendBase 

161 

162 for base in cls.__mro__: 

163 descriptor = base.__dict__.get("from_env") 

164 if descriptor is None: 

165 continue 

166 if base is BackendBase: 

167 continue 

168 return descriptor.__get__(None, cls) 

169 return None 

170 

171 

172def build_graph[T: Injectable]( 

173 cls: type[T], 

174 runtime_inputs: dict[str, t.Any], 

175 strict: bool | None = None, 

176) -> T: 

177 """Build a dependency graph rooted at ``cls``.""" 

178 graph = Graph(runtime_inputs, strict=strict) 

179 return graph.build(cls) 

180 

181 

def iter_injected_fields(cls: type[Injectable]) -> dict[str, t.Any]:
    """Return annotated injected fields declared on ``cls`` and its bases.

    Annotations from the whole MRO are merged; when a name is annotated in more
    than one class, the annotation from the class nearest to ``cls`` wins.

    Args:
        cls: The injectable class to inspect.

    Returns:
        A new mapping of field name to evaluated annotation. It is empty when
        ``cls`` is ``Injectable`` itself or not an ``Injectable`` subclass.

    Raises:
        NameError: If a string annotation names something that is not
            resolvable in the module of the class that declares it.
    """
    if not issubclass(cls, Injectable) or cls is Injectable:
        return {}

    # ``get_type_hints`` on a class already walks the full MRO. ``globalns`` is
    # deliberately left unset: typing then evaluates each class's (possibly
    # stringized, under ``from __future__ import annotations``) annotations in
    # the globals of the module where that class was defined. Forcing a single
    # module's globals for every base raises NameError when a subclass lives in
    # a different module from the base that declares a field.
    return t.get_type_hints(cls)

194 

195 

196def matches_annotation(value: t.Any, annotation: t.Any) -> bool: 

197 """Return whether ``value`` matches a concrete class annotation.""" 

198 return inspect.isclass(annotation) and isinstance(value, annotation)