Coverage for agentos/tools/di_container.py: 0%
99 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 07:37 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 07:37 +0800
1"""
2DIContainer — lightweight dependency injection container.
4Supports:
5 - Singleton and transient lifetimes
6 - Constructor autowiring via type annotations
7 - Factory registration
8 - Instance registration
  - Scoped child containers (lookups not registered locally fall back to the parent)
10 - Circular dependency detection
11"""
13from __future__ import annotations
15import inspect
16from enum import Enum
17from threading import RLock
18from typing import Any, Callable, Dict, Optional, Set, Type, TypeVar, get_type_hints
# Generic type variable: lets resolve() declare that it returns an instance
# of the same type it was asked for.
T = TypeVar("T")
23# ============================================================================
24# Lifetime
25# ============================================================================
class Lifetime(Enum):
    """Caching policy a container applies when resolving a registration."""

    # Built once; the cached object is handed out on every later resolve.
    SINGLETON = "singleton"
    # Built anew on each resolve.
    TRANSIENT = "transient"
32# ============================================================================
33# Registration
34# ============================================================================
class Registration:
    """Plain record describing one interface binding inside a DIContainer.

    It stores what to build (``implementation`` or ``factory``), how long the
    result lives (``lifetime``), any pre-built or cached ``instance``, and a
    lock the container holds while building that instance.
    """

    __slots__ = ("interface", "implementation", "lifetime", "factory", "instance", "instance_lock")

    def __init__(
        self,
        interface: Type,
        implementation: Optional[Type] = None,
        lifetime: Lifetime = Lifetime.TRANSIENT,
        factory: Optional[Callable[[], Any]] = None,
        instance: Any = None,
    ):
        # Re-entrant lock owned by this registration alone (not shared).
        self.instance_lock = RLock()
        # Pre-built or cached object; None is treated as "not built yet".
        self.instance = instance
        # Zero-argument callable that, when set, is used instead of the constructor.
        self.factory = factory
        self.lifetime = lifetime
        # Concrete class to construct (may be None for factory registrations).
        self.implementation = implementation
        # Key the binding is registered under.
        self.interface = interface
55# ============================================================================
56# Circular Dependency Error
57# ============================================================================
class CircularDependencyError(Exception):
    """Raised when resolving an interface leads back to itself.

    ``chain`` is the sequence of interfaces supplied by the caller. It is
    kept unchanged on the exception and rendered into the message with each
    item converted via ``str`` and joined by arrows.
    """

    def __init__(self, chain: list):
        self.chain = chain
        rendered = " → ".join(map(str, chain))
        super().__init__(f"Circular dependency detected: {rendered}")
65# ============================================================================
66# DIContainer
67# ============================================================================
class DIContainer:
    """Lightweight dependency injection container.

    Usage:
        container = DIContainer()

        # Register singleton
        container.register(AbstractDB, ConcretePostgres, Lifetime.SINGLETON)

        # Register transient
        container.register(AbstractCache, RedisCache, Lifetime.TRANSIENT)

        # Register instance
        container.register_instance(ConfigService, config_obj)

        # Resolve
        db = container.resolve(AbstractDB)

        # Factory
        container.register_factory(AbstractQueue, lambda: build_queue())

    A container created with a ``parent`` (see :meth:`create_scope`) looks an
    interface up locally first and falls back to its ancestors for anything
    it does not register itself.
    """

    # Parameter kinds that cannot be filled by keyword (``*args`` / ``**kwargs``).
    _VARIADIC_KINDS = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    def __init__(self, parent: Optional["DIContainer"] = None):
        # interface -> Registration, local to this container; guarded by _lock.
        self._registrations: Dict[Any, Registration] = {}
        self._lock = RLock()
        # Container consulted when a lookup misses locally (None for a root).
        self._parent = parent

    # ---------- register ----------

    def register(
        self,
        interface: Type,
        implementation: Optional[Type] = None,
        lifetime: Lifetime = Lifetime.TRANSIENT,
    ) -> None:
        """Bind *interface* to *implementation* (defaults to *interface* itself).

        The implementation is constructed on resolve, with its constructor
        parameters autowired from their type annotations. Registering the
        same interface again replaces the earlier binding in this container.
        """
        if implementation is None:
            implementation = interface
        with self._lock:
            self._registrations[interface] = Registration(
                interface=interface,
                implementation=implementation,
                lifetime=lifetime,
            )

    def register_instance(self, interface: Type, instance: Any) -> None:
        """Bind *interface* to an already-built *instance* (singleton lifetime)."""
        with self._lock:
            self._registrations[interface] = Registration(
                interface=interface,
                implementation=type(instance),
                lifetime=Lifetime.SINGLETON,
                instance=instance,
            )

    def register_factory(self, interface: Type, factory: Callable[[], Any], lifetime: Lifetime = Lifetime.TRANSIENT) -> None:
        """Bind *interface* to a zero-argument *factory*.

        TRANSIENT calls the factory on every resolve. SINGLETON calls it once
        and caches the result. A factory that returns ``None`` is called again
        on each resolve, because ``None`` marks "not built yet".
        """
        with self._lock:
            self._registrations[interface] = Registration(
                interface=interface,
                implementation=None,
                lifetime=lifetime,
                factory=factory,
            )

    # ---------- resolve ----------

    def resolve(self, interface: Type[T]) -> T:
        """Return an instance of *interface*, building its dependencies as needed.

        Raises:
            KeyError: *interface*, or a dependency it requires, is not
                registered in this container or any ancestor.
            CircularDependencyError: the dependency graph loops back on
                itself and no parameter default breaks the loop.
            Any exception raised by a factory or constructor propagates.
        """
        return self._resolve(interface, [])

    def _resolve(self, interface: Type, resolving: list) -> Any:
        """Resolve *interface* given the ordered build path *resolving*.

        *resolving* lists the interfaces currently being built, outermost
        first. It is extended while *interface* is built and restored before
        returning, so a single list can be shared down the whole call tree.
        """
        # Already on the current path: the graph loops back on itself.
        if interface in resolving:
            raise CircularDependencyError(list(resolving) + [interface])

        owner, reg = self._locate(interface)
        resolving.append(interface)
        try:
            # Pre-registered instance, or a singleton that was already built.
            if reg.instance is not None:
                return reg.instance

            if reg.lifetime == Lifetime.SINGLETON:
                with reg.instance_lock:
                    # Re-check: another thread may have built it while we waited.
                    if reg.instance is None:
                        # Build in the container that owns the registration, so a
                        # singleton shared by every scope never captures one child
                        # scope's overrides.
                        reg.instance = owner._build(reg, resolving)
                    return reg.instance

            # Transient: a fresh object each time, wired from this container.
            return self._build(reg, resolving)
        finally:
            # Hand the path back to the caller exactly as it was.
            resolving.pop()

    def _find(self, interface: Any) -> Optional[tuple]:
        """Return ``(owning_container, registration)`` for *interface*, or ``None``.

        Searches this container first, then each ancestor in turn.
        """
        container = self
        while container is not None:
            with container._lock:
                reg = container._registrations.get(interface)
            if reg is not None:
                return container, reg
            container = container._parent
        return None

    def _locate(self, interface: Any) -> tuple:
        """Like :meth:`_find`, but raise ``KeyError`` when nothing is registered."""
        found = self._find(interface)
        if found is None:
            # Typing constructs (e.g. ``int | None``) may not define ``__name__``.
            name = getattr(interface, "__name__", None) or repr(interface)
            raise KeyError(f"No registration for {name}")
        return found

    def _get_registration(self, interface: Type) -> Registration:
        """Return the registration visible from this container for *interface*.

        Raises:
            KeyError: nothing is registered here or in any ancestor.
        """
        return self._locate(interface)[1]

    def _build(self, reg: Registration, resolving: list) -> Any:
        """Create a new object for *reg* without reading or updating its cache.

        A factory, if present, is called with no arguments. Otherwise the
        implementation's constructor is called with keyword arguments
        autowired from its type annotations:

        * parameters without an annotation get their default, or are omitted;
        * ``*args`` / ``**kwargs`` parameters are never filled;
        * an annotated parameter with a default uses that default when its
          type is not registered, or when resolving it would be circular;
        * every other annotated parameter is resolved, and its errors propagate.
        """
        # Factory takes priority
        if reg.factory is not None:
            return reg.factory()

        # Constructor injection
        impl = reg.implementation or reg.interface
        hints = self._safe_get_type_hints(impl.__init__)
        kwargs: Dict[str, Any] = {}

        for param_name, param in inspect.signature(impl.__init__).parameters.items():
            if param_name == "self" or param.kind in self._VARIADIC_KINDS:
                continue
            has_default = param.default is not inspect.Parameter.empty
            param_type = hints.get(param_name)

            if param_type is None:
                if has_default:
                    kwargs[param_name] = param.default
                continue

            # Optional dependency that nobody registered: keep the default
            # instead of attempting (and failing) a lookup.
            if has_default and not self.is_registered(param_type):
                kwargs[param_name] = param.default
                continue

            try:
                kwargs[param_name] = self._resolve(param_type, resolving)
            except CircularDependencyError:
                if not has_default:
                    raise
                # An optional dependency breaks the cycle.
                kwargs[param_name] = param.default

        return impl(**kwargs)

    @staticmethod
    def _safe_get_type_hints(func) -> Dict[str, Any]:
        """Best-effort ``get_type_hints``: return ``{}`` if annotations can't be evaluated."""
        try:
            return get_type_hints(func)
        except Exception:
            # Deliberately broad (e.g. an unresolvable forward reference):
            # autowiring then falls back to parameter defaults.
            return {}

    # ---------- scoped ----------

    def create_scope(self) -> "DIContainer":
        """Create a child container whose lookups fall back to this one.

        Registrations made on the child shadow this container's without
        changing it. The child is not a snapshot: it consults this container
        at resolve time, so later registrations here are visible to it.
        Singletons registered on an ancestor are built by, and cached on,
        that ancestor, so all scopes share one instance. Transient ancestor
        registrations are wired from the scope that resolves them.
        """
        return DIContainer(parent=self)

    # ---------- check ----------

    def is_registered(self, interface: Type) -> bool:
        """Return True if *interface* is registered here or in any ancestor."""
        return self._find(interface) is not None