Coverage for agentos/core/di.py: 33%
79 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Dependency Injection system for NexusAgent.
4Provides type-safe Agent[Deps, Out] generic base class,
5RunContext for dependency injection, and Depends() for
6automatic dependency resolution.
7"""
9from __future__ import annotations
11import uuid
12from dataclasses import dataclass, field
13from typing import Any, Callable, Generic, TypeVar, get_type_hints, get_origin, get_args
# Generic parameters of Agent: Deps is the dependency type, Out the output type.
Deps, Out = TypeVar("Deps"), TypeVar("Out")


@dataclass
class RunContext(Generic[Deps]):
    """
    Per-invocation context handed to ``Agent.run()``.

    Attributes:
        deps: Dependencies injected for this run.
        agent_name: Name of the agent being run ("" unless supplied).
        run_id: First 12 hex characters of a fresh ``uuid4``, generated
            independently for every instance.
        metadata: Free-form key/value data attached to the run; each
            instance gets its own dict.
    """

    deps: Deps
    agent_name: str = ""
    run_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    metadata: dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default: Any = None) -> Any:
        """Return ``metadata[key]``, or *default* when *key* is absent."""
        store = self.metadata
        return store.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key* in ``metadata``, overwriting any prior value."""
        store = self.metadata
        store[key] = value
class Depends:
    """
    Marker wrapping a zero-argument factory that produces a dependency.

    When ``Agent.invoke()`` receives a ``Depends`` instance as its ``deps``
    argument, it calls :meth:`resolve` and passes the factory's return value
    on as ``ctx.deps``. The factory therefore runs once per ``invoke()``
    call. Using ``Depends(...)`` as a type argument (``Agent[Depends(...), X]``)
    does NOT trigger any resolution.

    Usage:
        def get_db() -> Database:
            return Database()

        class MyAgent(Agent[Database, str]):
            async def run(self, ctx: RunContext[Database]) -> str:
                db = ctx.deps  # Database instance built by get_db()
                ...

        result = await MyAgent().invoke(Depends(get_db))
    """

    def __init__(self, callable: Callable[..., Any]):
        # The parameter name shadows the builtin ``callable``; it is kept
        # because callers may pass it by keyword (``Depends(callable=...)``).
        self.callable = callable

    def resolve(self) -> Any:
        """Call the wrapped factory with no arguments and return its result."""
        return self.callable()

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.callable!r})"
def inject_tool(tool: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator to inject a tool into an agent.

    The tool is recorded in the class attribute ``_tools``, which
    ``Agent.__init__`` reads. The decorated class always receives a fresh
    list: inherited tools first, then *tool*. Decorating a subclass
    therefore never mutates the list owned by a base class, and the
    registration does not leak into the base or its other subclasses.

    Args:
        tool: The callable to register on the decorated class.

    Returns:
        A class decorator that sets ``cls._tools`` and returns ``cls``.

    Usage:
        @inject_tool(search_tool)
        class MyAgent(Agent):
            ...
    """
    def decorator(cls):
        # getattr also finds a list inherited from a base class; build a new
        # list instead of appending in place so the change stays local.
        cls._tools = [*getattr(cls, "_tools", ()), tool]
        return cls
    return decorator
def requires_context(*fields: str) -> Callable[..., Any]:
    """
    Decorator to declare required context fields.

    The field names are recorded in the class attribute
    ``_required_context``, which ``Agent.__init__`` reads. The decorated
    class always receives a fresh list: inherited names first, then
    *fields* in order. Decorating a subclass therefore never adds
    requirements to its base class or sibling classes.

    Args:
        *fields: Names of metadata keys that must be present.

    Returns:
        A class decorator that sets ``cls._required_context`` and returns ``cls``.

    Usage:
        @requires_context("user_id", "session_id")
        class MyAgent(Agent):
            ...
    """
    def decorator(cls):
        # Copy rather than extend in place: getattr may return a base
        # class's list, and mutating it would leak the requirement upward.
        cls._required_context = [*getattr(cls, "_required_context", ()), *fields]
        return cls
    return decorator
class Agent(Generic[Deps, Out]):
    """
    Base class for all agents.

    Type-safe generic: Agent[Deps, Out]
    - Deps: Type of dependencies
    - Out: Type of output

    Usage:
        class MyAgent(Agent[str, str]):
            async def run(self, ctx: RunContext[str]) -> str:
                return f"Hello, {ctx.deps}!"

        agent = MyAgent()
        result = await agent.invoke("World")
    """

    def __init__(self, name: str = ""):
        self.name = name or self.__class__.__name__
        # Copy the class-level registries (set by @inject_tool and
        # @requires_context) so changes made through one instance never
        # leak back into the class or into other instances.
        self._tools: list[Callable[..., Any]] = list(
            getattr(self.__class__, "_tools", ())
        )
        self._required_context: list[str] = list(
            getattr(self.__class__, "_required_context", ())
        )

    async def run(self, ctx: RunContext[Deps]) -> Out:
        """
        Main agent logic. Override in subclass.

        Args:
            ctx: Runtime context with dependencies

        Returns:
            Agent output (validated against Out by invoke())
        """
        raise NotImplementedError("Subclass must implement run()")

    async def invoke(self, deps: Deps, **metadata: Any) -> Out:
        """
        Invoke the agent with dependencies.

        Args:
            deps: Dependencies to inject. A ``Depends`` instance is resolved
                by calling its factory before the run.
            **metadata: Additional metadata, exposed as ``ctx.metadata``.

        Returns:
            The output of ``run()``, passed through ``_validate_output``.

        Raises:
            ValueError: If a field declared via ``@requires_context`` is
                missing from *metadata*.
        """
        # Reject the call before resolving dependencies, so a missing
        # context field never triggers a (possibly expensive) factory.
        for key in self._required_context:
            if key not in metadata:
                raise ValueError(f"Required context field missing: {key}")

        if isinstance(deps, Depends):
            deps = deps.resolve()

        ctx = RunContext[Deps](
            deps=deps,
            agent_name=self.name,
            metadata=metadata,
        )

        result = await self.run(ctx)
        return self._validate_output(result)

    @classmethod
    def _declared_out_type(cls) -> Any:
        """
        Return the ``Out`` argument of the nearest ``Agent[Deps, Out]`` base, or None.

        The parametrized base (e.g. ``Agent[str, MyModel]``) is recorded in
        each class's own ``__orig_bases__``. ``__mro__`` only contains plain
        classes, for which ``get_origin`` is always None. For a generic
        intermediate class (``class Mid(Agent[Deps, Out])``), this returns
        the ``Out`` TypeVar itself.
        """
        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                if get_origin(base) is Agent:
                    args = get_args(base)
                    if len(args) >= 2:
                        return args[1]
        return None

    def _validate_output(self, result: Any) -> Out:
        """
        Validate output against the declared ``Out`` type.

        If ``Out`` is a Pydantic ``BaseModel`` subclass and *result* is not
        already an instance, it is converted. Dicts are passed as keyword
        arguments; anything else goes through ``model_validate``. Every
        other ``Out`` (or no Pydantic install) returns *result* unchanged.
        """
        try:
            hints = get_type_hints(self.__class__)
        except (NameError, TypeError):
            # Unresolvable subclass annotations (e.g. TYPE_CHECKING-only
            # imports) must not break every invoke() call.
            hints = {}
        out_type = hints.get("Out")
        if out_type is None:
            out_type = self._declared_out_type()
        if out_type is None:
            return result

        try:
            from pydantic import BaseModel
        except ImportError:
            return result  # Pydantic unavailable: nothing to validate with.

        # get_origin() filters out aliases like list[str], which count as
        # `type` instances on Python 3.9/3.10 and would make issubclass raise.
        if not isinstance(out_type, type) or get_origin(out_type) is not None:
            return result
        if not issubclass(out_type, BaseModel) or isinstance(result, out_type):
            return result
        if isinstance(result, dict):
            return out_type(**result)
        return out_type.model_validate(result)

    def get_tools(self) -> list[Callable[..., Any]]:
        """Get a copy of the registered tools."""
        return self._tools.copy()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name!r})"