Coverage for src/endow/runtime.py: 100%

119 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-09 19:17 +0200

1"""Runtime graph construction and wiring helpers.""" 

2 

3from __future__ import annotations 

4 

5import inspect 

6import sys 

7import typing as t 

8import warnings 

9 

10from .base import Domain, Injectable, Service 

11 

# Sentinel meaning "no runtime input matched". Callers compare against it by
# identity (``is MISSING``), so a runtime input whose value is ``None`` still
# counts as a successful match.
MISSING = object()

13 

14 

15class Graph: 

16 """Resolve injectables and runtime values into a shared object graph.""" 

17 

18 def __init__(self, runtime_inputs: dict[str, t.Any], strict: bool | None = None) -> None: 

19 """Store the root runtime inputs and initialize the instance cache.""" 

20 self.runtime_inputs = runtime_inputs 

21 self.strict = strict 

22 self.instances: dict[type[Injectable], Injectable] = {} 

23 

24 def build[T: Injectable](self, cls: type[T]) -> T: 

25 """Build or reuse an injectable instance of the requested type.""" 

26 cached = self.instances.get(cls) 

27 if cached is not None: 

28 return t.cast(T, cached) 

29 

30 instance, local_inputs, type_inputs = self._make_instance(cls) 

31 self.instances[cls] = instance 

32 self.instances.setdefault(type(instance), instance) 

33 self._wire_instance(instance, local_inputs=local_inputs, type_inputs=type_inputs) 

34 return t.cast(T, instance) 

35 

36 def _make_instance[T: Injectable]( 

37 self, cls: type[T] 

38 ) -> tuple[Injectable, dict[str, t.Any] | None, dict[str, t.Any] | None]: 

39 factory = self._get_factory(cls) 

40 if factory is None: 

41 return cls(), None, None 

42 

43 factory_args, local_inputs, type_inputs = self._collect_factory_args(factory) 

44 instance = factory(**factory_args) 

45 if not isinstance(instance, Injectable): 

46 msg = f"{cls.__name__}.from_env() must return an Injectable instance" 

47 raise TypeError(msg) 

48 return instance, local_inputs, type_inputs 

49 

50 def _wire_instance( 

51 self, 

52 instance: Injectable, 

53 local_inputs: dict[str, t.Any] | None = None, 

54 type_inputs: dict[str, t.Any] | None = None, 

55 ) -> None: 

56 owner_type = type(instance) 

57 for name, annotation in iter_injected_fields(owner_type).items(): 

58 self._check_dependency_direction(owner_type, name, annotation) 

59 setattr(instance, name, self._resolve(name, annotation, local_inputs, type_inputs)) 

60 

61 def _check_dependency_direction(self, owner_type: type[Injectable], name: str, annotation: t.Any) -> None: 

62 if self.strict is None: 

63 return 

64 

65 if not ( 

66 inspect.isclass(owner_type) 

67 and issubclass(owner_type, Service) 

68 and inspect.isclass(annotation) 

69 and issubclass(annotation, Domain) 

70 ): 

71 return 

72 

73 msg = f"Service '{owner_type.__name__}' should not depend on Domain '{annotation.__name__}' via field '{name}'" 

74 if self.strict: 

75 raise TypeError(msg) 

76 warnings.warn(msg, stacklevel=4) 

77 

78 def _resolve( 

79 self, 

80 name: str, 

81 annotation: t.Any, 

82 local_inputs: dict[str, t.Any] | None = None, 

83 type_inputs: dict[str, t.Any] | None = None, 

84 ) -> t.Any: 

85 if inspect.isclass(annotation) and issubclass(annotation, Injectable): 

86 return self.build(annotation) 

87 

88 runtime_value = self._match_runtime_input(name, annotation, local_inputs, type_inputs) 

89 if runtime_value is not MISSING: 

90 return runtime_value 

91 

92 if inspect.isclass(annotation): 

93 msg = f"Missing runtime input for field '{name}'" 

94 raise TypeError(msg) 

95 

96 msg = f"Cannot resolve field '{name}' with annotation {annotation!r}" 

97 raise TypeError(msg) 

98 

99 def _collect_factory_args(self, factory: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]: 

100 args: dict[str, t.Any] = {} 

101 local_inputs: dict[str, t.Any] = {} 

102 type_inputs: dict[str, t.Any] = {} 

103 module = sys.modules[factory.__module__] 

104 type_hints = t.get_type_hints(factory, globalns=vars(module)) 

105 

106 for parameter in inspect.signature(factory).parameters.values(): 

107 if parameter.kind not in ( 

108 inspect.Parameter.POSITIONAL_OR_KEYWORD, 

109 inspect.Parameter.KEYWORD_ONLY, 

110 ): 

111 continue 

112 

113 annotation = type_hints.get(parameter.name, parameter.annotation) 

114 value = self._match_runtime_input(parameter.name, annotation) 

115 if value is MISSING: 

116 if parameter.default is not inspect._empty: 

117 local_inputs[parameter.name] = parameter.default 

118 continue 

119 msg = f"Missing runtime input '{parameter.name}' for from_env()" 

120 raise TypeError(msg) 

121 args[parameter.name] = value 

122 local_inputs[parameter.name] = value 

123 type_inputs[parameter.name] = value 

124 

125 return args, local_inputs, type_inputs 

126 

127 def _match_runtime_input( 

128 self, 

129 name: str, 

130 annotation: t.Any, 

131 local_inputs: dict[str, t.Any] | None = None, 

132 type_inputs: dict[str, t.Any] | None = None, 

133 ) -> t.Any: 

134 for scope in (local_inputs, self.runtime_inputs): 

135 if not scope: 

136 continue 

137 if name in scope: 

138 return scope[name] 

139 

140 normalized_name = name.lstrip("_") 

141 if normalized_name and normalized_name in scope: 

142 return scope[normalized_name] 

143 

144 if annotation is inspect._empty: 

145 return MISSING 

146 

147 for scope in (self.runtime_inputs, type_inputs): 

148 if not scope: 

149 continue 

150 for candidate in scope.values(): 

151 if matches_annotation(candidate, annotation): 

152 return candidate 

153 

154 return MISSING 

155 

156 @staticmethod 

157 def _get_factory(cls: type[Injectable]) -> t.Any | None: 

158 from .backend import BackendBase 

159 

160 for base in cls.__mro__: 

161 descriptor = base.__dict__.get("from_env") 

162 if descriptor is None: 

163 continue 

164 if base is BackendBase: 

165 continue 

166 return descriptor.__get__(None, cls) 

167 return None 

168 

169 

170def build_graph[T: Injectable]( 

171 cls: type[T], 

172 runtime_inputs: dict[str, t.Any], 

173 strict: bool | None = None, 

174) -> T: 

175 """Build a dependency graph rooted at ``cls``.""" 

176 graph = Graph(runtime_inputs, strict=strict) 

177 return graph.build(cls) 

178 

179 

def iter_injected_fields(cls: type[Injectable]) -> dict[str, t.Any]:
    """Return annotated injected fields declared on ``cls`` and its bases.

    Only annotations declared directly in the body of an ``Injectable``
    subclass (other than ``Injectable`` itself) are collected. Bases are
    visited from the root of the MRO down to ``cls``, so a redeclaration in a
    subclass overrides the inherited annotation while the field keeps its
    original position.

    Each base's annotations are resolved by ``typing.get_type_hints`` in that
    base's own module, so bases defined in other modules resolve correctly.

    Returns:
        Mapping of field name to resolved annotation.
    """
    fields: dict[str, t.Any] = {}

    for base in reversed(cls.__mro__):
        if not issubclass(base, Injectable) or base is Injectable:
            continue

        # ``get_type_hints(base)`` merges annotations from base's entire MRO,
        # so take the set of names ``base`` itself declares separately.
        own_names = inspect.get_annotations(base)
        if not own_names:
            continue

        # No explicit ``globalns``: get_type_hints would apply it to *every*
        # class in base's MRO, resolving annotations of bases from other
        # modules in the wrong namespace. By default each class uses its own.
        resolved = t.get_type_hints(base)
        for name in own_names:
            fields[name] = resolved[name]

    return fields

192 

193 

def matches_annotation(value: t.Any, annotation: t.Any) -> bool:
    """Report whether ``value`` is an instance of ``annotation``.

    Only concrete classes can match; any other annotation (generic aliases,
    unions, strings, ...) yields ``False``.
    """
    if not inspect.isclass(annotation):
        return False
    return isinstance(value, annotation)
196 return inspect.isclass(annotation) and isinstance(value, annotation)