Coverage for agentos/core/di.py: 33%

79 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Dependency Injection system for NexusAgent. 

3 

4Provides type-safe Agent[Deps, Out] generic base class, 

5RunContext for dependency injection, and Depends() for 

6automatic dependency resolution. 

7""" 

8 

9from __future__ import annotations 

10 

11import uuid 

12from dataclasses import dataclass, field 

13from typing import Any, Callable, Generic, TypeVar, get_type_hints, get_origin, get_args 

14 

# Type variables for the Agent generic: Deps is the type of the dependency
# object carried by RunContext.deps, Out is the type returned by Agent.run().
Deps = TypeVar("Deps")
Out = TypeVar("Out")

18 

19 

def _new_run_id() -> str:
    """Return the first 12 hex characters of a freshly generated UUID4."""
    return uuid.uuid4().hex[:12]


@dataclass
class RunContext(Generic[Deps]):
    """
    Per-invocation state handed to Agent.run().

    Attributes:
        deps: Dependency object made available to the agent for this run.
        agent_name: Name of the agent that owns the run (empty by default).
        run_id: Identifier of this run; generated automatically when omitted.
        metadata: Free-form key/value data attached to the run.
    """

    deps: Deps
    agent_name: str = ""
    run_id: str = field(default_factory=_new_run_id)
    metadata: dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default: Any = None) -> Any:
        """Look up ``key`` in the run metadata, falling back to ``default``."""
        store = self.metadata
        return store.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the run metadata."""
        store = self.metadata
        store[key] = value

43 

44 

class Depends:
    """
    Marker that defers creation of a dependency until it is resolved.

    The wrapped callable is stored untouched; every resolve() call runs it
    again with no arguments and returns whatever it produces.
    Agent.invoke() resolves a Depends instance passed as its ``deps``
    argument before building the run context.

    Usage:
        def get_db() -> Database:
            return Database()

        result = await agent.invoke(Depends(get_db))  # ctx.deps is a Database
    """

    def __init__(self, callable: Callable[..., Any]):
        # The parameter name shadows the builtin, but it is part of the
        # public signature (and attribute name) and therefore kept as is.
        self.callable = callable

    def resolve(self) -> Any:
        """Call the stored factory with no arguments and return its result."""
        provider = self.callable
        return provider()

63 

64 

def inject_tool(tool: Callable[..., Any]) -> Callable[..., Any]:
    """
    Build a class decorator that registers ``tool`` on the decorated class.

    The tool is appended to the class attribute ``_tools``. That list is
    owned by the decorated class itself: when the class merely inherits
    ``_tools`` from a base, the inherited entries are copied into a new list
    first, so the base class and its other subclasses are left untouched.

    Args:
        tool: Callable to register.

    Returns:
        A decorator that records ``tool`` on a class and returns that class.

    Usage:
        @inject_tool(search_tool)
        class MyAgent(Agent):
            ...
    """
    def decorator(cls):
        # Inspect only the class's own namespace: hasattr() would also find a
        # list inherited from a base class, and appending to that shared list
        # would leak the tool to the base and every sibling subclass.
        if "_tools" not in vars(cls):
            cls._tools = list(getattr(cls, "_tools", ()))
        cls._tools.append(tool)
        return cls

    return decorator

80 

81 

def requires_context(*fields: str) -> Callable[..., Any]:
    """
    Build a class decorator that declares required context fields.

    The field names are appended to the class attribute
    ``_required_context``. That list is owned by the decorated class itself:
    when the class merely inherits ``_required_context`` from a base, the
    inherited names are copied into a new list first, so the base class and
    its other subclasses do not pick up the new requirements.

    Args:
        *fields: Names of the context fields the class requires.

    Returns:
        A decorator that records the names on a class and returns that class.

    Usage:
        @requires_context("user_id", "session_id")
        class MyAgent(Agent):
            ...
    """
    def decorator(cls):
        # Inspect only the class's own namespace: hasattr() would also find a
        # list inherited from a base class, and extending that shared list
        # would impose the requirements on the base and sibling subclasses.
        if "_required_context" not in vars(cls):
            cls._required_context = list(getattr(cls, "_required_context", ()))
        cls._required_context.extend(fields)
        return cls

    return decorator

97 

98 

class Agent(Generic[Deps, Out]):
    """
    Base class for all agents.

    Type-safe generic: Agent[Deps, Out]
    - Deps: Type of dependencies
    - Out: Type of output

    Subclasses override run(); callers use invoke(), which builds the
    RunContext, checks the required metadata fields, runs the agent and
    validates the result against the declared output type.

    Usage:
        class MyAgent(Agent[str, str]):
            async def run(self, ctx: RunContext[str]) -> str:
                return f"Hello, {ctx.deps}!"

        agent = MyAgent()
        result = await agent.invoke("World")
    """

    def __init__(self, name: str = ""):
        """
        Initialise the agent.

        Args:
            name: Agent name; defaults to the class name when empty.
        """
        self.name = name or self.__class__.__name__
        # Copy the class-level registrations so that changes made through an
        # instance cannot modify the lists shared by the whole class.
        self._tools: list[Callable[..., Any]] = list(
            getattr(self.__class__, '_tools', ())
        )
        self._required_context: list[str] = list(
            getattr(self.__class__, '_required_context', ())
        )

    async def run(self, ctx: RunContext[Deps]) -> Out:
        """
        Main agent logic. Override in subclass.

        Args:
            ctx: Runtime context with dependencies

        Returns:
            Agent output (validated against Out by invoke())

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Subclass must implement run()")

    async def invoke(self, deps: Deps, **metadata) -> Out:
        """
        Invoke the agent with dependencies.

        Args:
            deps: Dependencies to inject; a Depends instance is resolved
                first and its result is injected instead.
            **metadata: Additional metadata stored on the run context

        Returns:
            Agent output, after output validation

        Raises:
            ValueError: If a required context field is missing from metadata.
        """
        # Resolve Depends if needed
        if isinstance(deps, Depends):
            deps = deps.resolve()

        # Create context
        ctx = RunContext[Deps](
            deps=deps,
            agent_name=self.name,
            metadata=metadata,
        )

        # Validate required context (fails on the first missing field)
        for required in self._required_context:
            if required not in ctx.metadata:
                raise ValueError(f"Required context field missing: {required}")

        # Run agent, then validate output type (if it can be determined)
        result = await self.run(ctx)
        return self._validate_output(result)

    def _declared_output_type(self) -> Any:
        """Return the declared output type, or None if it cannot be found."""
        cls = type(self)

        # A class-level annotation named 'Out' takes precedence.
        try:
            hints = get_type_hints(cls)
        except (NameError, TypeError):
            # Unresolvable annotations on the class (e.g. names imported only
            # for type checking) must not break output validation.
            hints = {}
        declared = hints.get('Out')
        if declared is not None:
            return declared

        # The parameterized base (Agent[Deps, Out]) is not part of __mro__,
        # which only holds plain classes; it is recorded in __orig_bases__.
        # vars() reads each class's own entry instead of an inherited one.
        for klass in cls.__mro__:
            for base in vars(klass).get('__orig_bases__', ()):
                if get_origin(base) is Agent:
                    args = get_args(base)
                    if len(args) >= 2:
                        # May still be an unbound TypeVar for generic
                        # intermediate classes; callers must handle that.
                        return args[1]
        return None

    def _validate_output(self, result: Any) -> Out:
        """
        Validate output against declared type.

        If the declared output type is a Pydantic BaseModel subclass and the
        result is not already an instance of it, the result is converted:
        dicts are passed as keyword arguments to the model, anything else
        goes through ``model_validate``. In every other case (no declared
        type, non-Pydantic type, Pydantic not installed) the result is
        returned unchanged.

        Args:
            result: Value returned by run().

        Returns:
            The validated (possibly converted) result.
        """
        out_type = self._declared_output_type()
        if out_type is None:
            return result

        try:
            from pydantic import BaseModel
        except ImportError:
            # Pydantic is optional; without it no validation is performed.
            return result

        # get_origin() filters out parameterized aliases such as list[int],
        # which pass isinstance(..., type) on some Python versions but make
        # issubclass() raise.
        is_model = (
            get_origin(out_type) is None
            and isinstance(out_type, type)
            and issubclass(out_type, BaseModel)
        )
        if is_model and not isinstance(result, out_type):
            # Try to validate/convert
            if isinstance(result, dict):
                result = out_type(**result)
            else:
                result = out_type.model_validate(result)

        return result

    def get_tools(self) -> list[Callable[..., Any]]:
        """Get registered tools (a copy; mutating it does not affect the agent)."""
        return self._tools.copy()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name!r})"