Coverage for agentos/tools/di_container.py: 0%

99 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 07:37 +0800

1""" 

2DIContainer — lightweight dependency injection container. 

3 

4Supports: 

5 - Singleton and transient lifetimes 

6 - Constructor autowiring via type annotations 

7 - Factory registration 

8 - Instance registration 

9 - Scoped child containers (lookups fall back live to the parent's registrations) 

10 - Circular dependency detection 

11""" 

12 

13from __future__ import annotations 

14 

15import inspect 

16from enum import Enum 

17from threading import RLock 

18from typing import Any, Callable, Dict, Optional, Set, Type, TypeVar, get_type_hints 

19 

# Generic type variable: lets DIContainer.resolve() declare that it returns an
# instance of the same type as the interface it was asked for.
T = TypeVar("T")

21 

22 

23# ============================================================================ 

24# Lifetime 

25# ============================================================================ 

26 

class Lifetime(Enum):
    """Instance lifetime policy applied by ``DIContainer`` when resolving."""

    # Built once on first resolve, then the cached instance is reused.
    SINGLETON = "singleton"
    # Built anew on every resolve call.
    TRANSIENT = "transient"

30 

31 

32# ============================================================================ 

33# Registration 

34# ============================================================================ 

35 

class Registration:
    """Record describing how a single interface is to be satisfied.

    This class only stores data; ``DIContainer`` decides how the fields are used.

    Attributes:
        interface: Key under which the registration is stored.
        implementation: Class to construct, or ``None``.
        lifetime: ``Lifetime`` policy for resolved instances.
        factory: Zero-argument callable that builds the instance, or ``None``.
        instance: Pre-registered or cached instance; ``None`` when there is none yet.
        instance_lock: Re-entrant lock owned by this registration, used by the
            container to guard lazy singleton creation.
    """

    __slots__ = (
        "interface",
        "implementation",
        "lifetime",
        "factory",
        "instance",
        "instance_lock",
    )

    def __init__(
        self,
        interface: Type,
        implementation: Optional[Type] = None,
        lifetime: Lifetime = Lifetime.TRANSIENT,
        factory: Optional[Callable[[], Any]] = None,
        instance: Any = None,
    ):
        # One lock per registration rather than one per container.
        self.instance_lock = RLock()
        self.instance, self.factory = instance, factory
        self.lifetime = lifetime
        self.interface, self.implementation = interface, implementation

53 

54 

55# ============================================================================ 

56# Circular Dependency Error 

57# ============================================================================ 

58 

class CircularDependencyError(Exception):
    """Raised when building an interface transitively requires that same interface.

    Attributes:
        chain: The interfaces involved; the final element is the one that was
            requested a second time and closed the loop.
    """

    def __init__(self, chain: list):
        self.chain = chain
        rendered_chain = " → ".join(map(str, chain))
        super().__init__("Circular dependency detected: " + rendered_chain)

63 

64 

65# ============================================================================ 

66# DIContainer 

67# ============================================================================ 

68 

class DIContainer:
    """Lightweight dependency injection container.

    Usage:
        container = DIContainer()

        # Register singleton
        container.register(AbstractDB, ConcretePostgres, Lifetime.SINGLETON)

        # Register transient
        container.register(AbstractCache, RedisCache, Lifetime.TRANSIENT)

        # Register instance
        container.register_instance(ConfigService, config_obj)

        # Resolve
        db = container.resolve(AbstractDB)

        # Factory
        container.register_factory(AbstractQueue, lambda: build_queue())

    Scopes: a container built with ``parent`` (see :meth:`create_scope`) checks
    its own registrations first, then delegates lookups up the parent chain.
    """

    class _RegistrationNotFound(KeyError):
        """No registration exists for the requested interface in this container chain.

        Subclasses ``KeyError`` so existing ``except KeyError`` callers keep
        working. Autowiring can then tell a missing registration apart from an
        unrelated ``KeyError`` raised inside a constructor or factory.
        """

    def __init__(self, parent: Optional["DIContainer"] = None):
        # interface -> Registration for this container only; guarded by self._lock.
        self._registrations: Dict[Any, Registration] = {}
        self._lock = RLock()
        # Container consulted when an interface is not registered here.
        self._parent = parent

    # ---------- register ----------

    def register(
        self,
        interface: Type,
        implementation: Optional[Type] = None,
        lifetime: Lifetime = Lifetime.TRANSIENT,
    ) -> None:
        """Register ``implementation`` (default: ``interface`` itself) for ``interface``.

        Registering the same interface again replaces the previous
        registration in this container.
        """
        if implementation is None:
            implementation = interface
        with self._lock:
            self._registrations[interface] = Registration(
                interface=interface,
                implementation=implementation,
                lifetime=lifetime,
            )

    def register_instance(self, interface: Type, instance: Any) -> None:
        """Register a pre-built ``instance`` that every resolve of ``interface`` returns."""
        with self._lock:
            self._registrations[interface] = Registration(
                interface=interface,
                implementation=type(instance),
                lifetime=Lifetime.SINGLETON,
                instance=instance,
            )

    def register_factory(self, interface: Type, factory: Callable[[], Any], lifetime: Lifetime = Lifetime.TRANSIENT) -> None:
        """Register a zero-argument ``factory`` for ``interface``.

        With ``Lifetime.SINGLETON`` the factory's result is cached after the
        first call. With ``Lifetime.TRANSIENT`` the factory runs on every resolve.
        """
        with self._lock:
            self._registrations[interface] = Registration(
                interface=interface,
                implementation=None,
                lifetime=lifetime,
                factory=factory,
            )

    # ---------- resolve ----------

    def resolve(self, interface: Type[T]) -> T:
        """Resolve and return an instance of the given interface.

        Raises:
            KeyError: ``interface``, or a required dependency without a default,
                is not registered in this container or any parent.
            CircularDependencyError: building ``interface`` requires itself.
        """
        return self._resolve(interface, ())

    def _resolve(self, interface: Type, resolving: tuple = ()) -> Any:
        # ``resolving`` is the ordered chain of interfaces currently being
        # built, outermost first. Any iterable is accepted, but a tuple keeps
        # the order so the cycle report reads A → B → A. It also avoids the
        # add/discard/copy bookkeeping a shared mutable set needs.
        path = tuple(resolving)
        if interface in path:
            raise CircularDependencyError([*path, interface])

        reg = self._get_registration(interface)
        path += (interface,)

        # Registered instance, or a singleton that has already been built.
        # NOTE(review): ``None`` doubles as "not built yet", so a singleton
        # factory that returns None is re-invoked on every resolve.
        if reg.instance is not None:
            return reg.instance

        if reg.lifetime == Lifetime.SINGLETON:
            with reg.instance_lock:
                # Re-check under the lock: another thread may have built it
                # while this one was waiting.
                if reg.instance is None:
                    reg.instance = self._build(reg, path)
                return reg.instance

        # Transient: a new instance on every call.
        return self._build(reg, path)

    def _get_registration(self, interface: Type) -> Registration:
        """Return the registration for ``interface`` from this container or its nearest ancestor.

        Raises:
            KeyError: no container in the chain registers ``interface``.
        """
        with self._lock:
            reg = self._registrations.get(interface)
        if reg is not None:
            return reg
        if self._parent is not None:
            return self._parent._get_registration(interface)
        # Typing constructs such as Optional[X] may lack ``__name__``.
        # Fall back to repr so the error raised is still a KeyError.
        name = getattr(interface, "__name__", repr(interface))
        raise self._RegistrationNotFound(f"No registration for {name}")

    def _build(self, reg: Registration, resolving: tuple) -> Any:
        """Create a new instance for ``reg`` via its factory or constructor autowiring."""
        # Factory takes priority
        if reg.factory is not None:
            return reg.factory()

        # Constructor injection: fill each annotated __init__ parameter by
        # resolving its annotated type from the container.
        impl = reg.implementation or reg.interface
        init = impl.__init__
        hints = self._safe_get_type_hints(init)
        kwargs: Dict[str, Any] = {}

        for param_name, param in inspect.signature(init).parameters.items():
            # *args / **kwargs cannot be supplied by keyword, so skip them.
            if param_name == "self" or param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue
            has_default = param.default is not inspect.Parameter.empty
            param_type = hints.get(param_name)
            if param_type is None:
                # Unannotated parameter: pass its default if it has one,
                # otherwise leave it out.
                if has_default:
                    kwargs[param_name] = param.default
                continue
            try:
                kwargs[param_name] = self._resolve(param_type, resolving)
            except (self._RegistrationNotFound, CircularDependencyError):
                # Optional dependency: use the declared default instead.
                # Only missing registrations and cycles are caught, so an
                # unrelated KeyError from a constructor still propagates.
                if not has_default:
                    raise
                kwargs[param_name] = param.default

        return impl(**kwargs)

    @staticmethod
    def _safe_get_type_hints(func) -> Dict[str, Any]:
        # Best-effort: unresolvable annotations (or objects get_type_hints
        # rejects) mean "no hints", not a hard failure.
        try:
            return get_type_hints(func)
        except Exception:
            return {}

    # ---------- scoped ----------

    def create_scope(self) -> "DIContainer":
        """Create a child container whose lookups fall back to this one.

        Registrations made on the child shadow this container's. Anything not
        registered on the child is looked up live in this container, not in a
        snapshot, so later registrations here are also visible to the child.
        Singletons registered here are shared, because the child uses this
        container's registration objects.
        """
        return DIContainer(parent=self)

    # ---------- check ----------

    def is_registered(self, interface: Type) -> bool:
        """Return True if ``interface`` is registered here or in any parent container."""
        with self._lock:
            if interface in self._registrations:
                return True
        if self._parent is not None:
            return self._parent.is_registered(interface)
        return False