Coverage for agentos/agent/tool_agent.py: 27%
244 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
Tool-Using Agent — an autonomous agent loop built on LLM function calling.

Core pattern:
    user task → LLM reasoning (tool_calls) → tool execution → results fed back → repeat until done

v1.3.38: +streaming, retry, checkpoint/resume, tool error handling, mock provider.
"""
10from __future__ import annotations
12import json
13import os
14import time
15from dataclasses import dataclass, field
16from typing import Any, Callable, Generator
18from agentos.llm.base import (
19 LLMProvider,
20 Message,
21 MessageRole,
22 CompletionResult,
23 CompletionChoice,
24 CompletionUsage,
25 Tool,
26 ToolCall,
27)
# Names re-exported by ``from agentos.agent.tool_agent import *``.
__all__ = [
    "ToolAgent",
    "AgentConfig",
    "AgentStep",
    "AgentResult",
    "ToolExecutor",
    "MockLLMProvider",
]
# ── Data types ───────────────────────────────────────────────────
@dataclass
class AgentConfig:
    """Tunable settings for a ToolAgent run."""

    # Upper bound on LLM turns before the run is reported as failed.
    max_steps: int = 10
    # Sampling temperature forwarded to the provider on every call.
    temperature: float = 0.0
    # Completion token limit forwarded to the provider on every call.
    max_tokens: int = 4096
    # When true, each step is printed to stdout.
    verbose: bool = False
    # When true, a tool error aborts the run instead of being fed back to the LLM.
    stop_on_error: bool = True
    # Extra provider attempts made after the first failed call.
    max_retries: int = 2
    # Seconds to sleep between provider attempts.
    retry_delay: float = 0.5
    # Directory holding the checkpoint file; an empty string disables checkpointing.
    checkpoint_dir: str = ""
@dataclass
class AgentStep:
    """Record of one LLM turn, including any tools it triggered."""

    # 1-based position of this turn within the run.
    step: int
    # Text content of the assistant message for this turn.
    thought: str = ""
    # Tool calls requested by the assistant in this turn.
    tool_calls: list[ToolCall] = field(default_factory=list)
    # Tool output strings keyed by tool call id.
    tool_results: dict[str, str] = field(default_factory=dict)
    # Finish reason reported by the provider for this turn.
    finish_reason: str = ""
    # Token count reported by the provider for this turn.
    tokens_used: int = 0
    # Cost reported by the provider for this turn.
    cost_usd: float = 0.0
    # Time attributed to this step, in milliseconds.
    duration_ms: float = 0.0
@dataclass
class AgentResult:
    """Outcome of a complete agent run."""

    # False when the run hit max steps or ended with an exception.
    success: bool = True
    # Final assistant answer; empty on failure.
    final_answer: str = ""
    # Steps recorded during the run, in order.
    steps: list[AgentStep] = field(default_factory=list)
    # Number of recorded steps.
    total_steps: int = 0
    # Sum of tokens over the recorded steps.
    total_tokens: int = 0
    # Sum of cost over the recorded steps.
    total_cost_usd: float = 0.0
    # Elapsed time for the run, in milliseconds.
    total_duration_ms: float = 0.0
    # Failure description, or None on success.
    error: str | None = None
# ── Tool executor ────────────────────────────────────────────────
class ToolExecutor:
    """Registry that maps tool names to handlers and runs tool calls.

    Failures are never raised to the caller: an unknown tool name or an
    exception inside a handler is returned as a JSON string of the form
    ``{"error": "..."}``.
    """

    def __init__(self):
        self._tools: dict[str, Callable[..., str]] = {}
        self._schemas: dict[str, Tool] = {}

    def register(self, tool: Tool, handler: Callable[..., str]) -> None:
        """Register *handler* under the tool's function name, replacing any previous entry."""
        key = tool.function.name
        self._tools[key] = handler
        self._schemas[key] = tool

    def get_schemas(self) -> list[Tool]:
        """Return a new list of the registered tool schemas, in registration order."""
        return [*self._schemas.values()]

    def execute(self, tool_call: ToolCall) -> str:
        """Run the handler for *tool_call* and return its result as a string.

        The parsed arguments are passed to the handler as keyword arguments.
        """
        handler = self._tools.get(tool_call.name)
        if handler is not None:
            try:
                output = handler(**tool_call.parsed_arguments)
                return str(output)
            except Exception as exc:
                # Report the failure as data so the caller decides how to react.
                return json.dumps({"error": str(exc)})
        return json.dumps({"error": f"Unknown tool: {tool_call.name}"})
# ── MockLLMProvider ──────────────────────────────────────────────
class MockLLMProvider(LLMProvider):
    """Provider that replays scripted responses, intended for integration tests.

    Each ``chat`` call consumes the next entry of ``responses``. Once the
    script is exhausted, every further call returns a plain ``"done"`` reply
    with finish reason ``"stop"``; those fallback calls are not recorded in
    ``calls``.
    """

    def __init__(self, responses: list[dict]):
        super().__init__(model="mock", api_key="mock")
        self._responses = responses
        self._cursor = 0
        # One entry per scripted call: offered tool names and script position.
        self.calls: list[dict] = []

    def chat(self, messages=None, *, temperature=0, max_tokens=4096, tools=None, **kwargs):
        """Return the next scripted response as a CompletionResult."""
        position = self._cursor
        if position >= len(self._responses):
            return self._build_result({"content": "done", "finish_reason": "stop"})
        scripted = self._responses[position]
        # Advance first so the result id reflects the number of consumed entries.
        self._cursor = position + 1
        offered = [t.function.name for t in (tools or [])]
        self.calls.append({"tools": offered, "cursor": position})
        return self._build_result(scripted)

    async def achat(self, *args, **kwargs):
        """Async variant; delegates to the synchronous ``chat``."""
        return self.chat(*args, **kwargs)

    @property
    def provider_name(self) -> str:
        return "mock"

    @staticmethod
    def text_response(content: str, finish_reason: str = "stop") -> dict:
        """Build a scripted entry for a plain text reply."""
        return {"content": content, "finish_reason": finish_reason}

    @staticmethod
    def tool_response(name: str, arguments: dict, tool_call_id: str = "") -> dict:
        """Build a scripted entry for a single tool call.

        The call id defaults to ``tc_<name>`` when *tool_call_id* is empty.
        """
        call = ToolCall(
            id=tool_call_id or f"tc_{name}",
            name=name,
            arguments=json.dumps(arguments),
        )
        return {"content": "", "tool_calls": [call], "finish_reason": "tool_calls"}

    def _build_result(self, resp: dict) -> CompletionResult:
        """Wrap a scripted entry in the provider's CompletionResult structure."""
        content = resp.get("content", "")
        # Synthetic usage: fixed prompt size plus a count derived from the text length.
        usage = CompletionUsage(
            prompt_tokens=5,
            completion_tokens=len(content) + 3,
            total_tokens=len(content) + 8,
        )
        message = Message(
            role=MessageRole.ASSISTANT,
            content=content,
            tool_calls=resp.get("tool_calls"),
        )
        choice = CompletionChoice(
            index=0,
            message=message,
            finish_reason=resp.get("finish_reason", "stop"),
        )
        return CompletionResult(
            id=f"mock_{self._cursor}",
            model="mock-model",
            choices=[choice],
            usage=usage,
        )
# ── Tool-Using Agent ─────────────────────────────────────────────
class ToolAgent:
    """Autonomous agent built on LLM function calling.

    Each turn sends the conversation to the provider. Tool calls in the reply
    are run through the ToolExecutor and their outputs are appended as TOOL
    messages. The loop ends when the model replies without tool calls, when
    it reports finish reason ``"stop"``, or when ``max_steps`` is exhausted.

    Example:
        from agentos.agent import ToolAgent, ToolExecutor
        from agentos.llm import create_provider, Tool

        provider = create_provider("openai")
        executor = ToolExecutor()
        executor.register(
            Tool.from_function("get_weather", "Get the weather", {"city": ...}),
            lambda city: f"{city}: 22°C sunny"
        )
        agent = ToolAgent(provider, executor)
        result = agent.run("What is the weather in Beijing?")
        print(result.final_answer)
    """

    _CHECKPOINT_FILENAME = "agent_checkpoint.json"

    def __init__(
        self,
        provider: LLMProvider,
        tool_executor: ToolExecutor,
        *,
        config: AgentConfig | None = None,
        system_prompt: str = "",
    ):
        """Create an agent.

        Args:
            provider: LLM provider used for every turn.
            tool_executor: Registry that supplies tool schemas and runs tool calls.
            config: Run settings; a default AgentConfig is used when omitted.
            system_prompt: System message; a built-in prompt is used when empty.
        """
        self._provider = provider
        self._executor = tool_executor
        self._config = config or AgentConfig()
        self._system_prompt = system_prompt or (
            "你是一个智能助手。你可以使用工具来获取信息。"
            "当你可以给出最终答案时,直接回答,不要再调用工具。"
            "用中文回答。"
        )

    # ── Synchronous ───────────────────────────────────────────

    def run(self, task: str) -> AgentResult:
        """Run the agent loop for *task* and return the outcome.

        Failures are reported through ``AgentResult.success`` and
        ``AgentResult.error`` rather than raised.
        """
        t0 = time.monotonic()
        tools = self._executor.get_schemas()
        return self._run_loop(self._initial_messages(task), task, tools, [], 1, t0)

    def _run_loop(
        self, messages, task, tools, steps, start_step, t0,
    ) -> AgentResult:
        """Drive the loop from *start_step*; shared by ``run`` and ``resume``."""
        final_answer = ""
        total_tokens = 0
        total_cost = 0.0

        try:
            for step_num in range(start_step, self._config.max_steps + 1):
                result = self._call_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                steps.append(step)
                if done:
                    final_answer = final
                    break
                self._record_turn(messages, result, step, task, step_num)
            else:
                # The loop ran out of steps without a final answer.
                return self._make_result(
                    False, "", steps, total_tokens, total_cost, t0,
                    self._max_steps_error(),
                )
        except Exception as e:
            # Boundary of the run: convert any failure into a result object.
            return self._make_result(
                False, "", steps, total_tokens, total_cost, t0, self._describe_error(e),
            )

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Streaming ─────────────────────────────────────────────

    def run_stream(self, task: str) -> Generator[AgentStep, None, AgentResult]:
        """Run the agent loop, yielding each AgentStep as it completes.

        The final AgentResult is the generator's return value
        (``StopIteration.value``).
        """
        t0 = time.monotonic()
        steps: list[AgentStep] = []
        tools = self._executor.get_schemas()
        messages = self._initial_messages(task)
        total_tokens = 0
        total_cost = 0.0
        final_answer = ""

        try:
            for step_num in range(1, self._config.max_steps + 1):
                result = self._call_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                # Record before yielding so the bookkeeping never lags behind
                # what the consumer has already seen.
                steps.append(step)
                yield step
                if done:
                    final_answer = final
                    break
                self._record_turn(messages, result, step, task, step_num)
            else:
                return self._make_result(
                    False, "", steps, total_tokens, total_cost, t0,
                    self._max_steps_error(),
                )
        except Exception as e:
            return self._make_result(
                False, "", steps, total_tokens, total_cost, t0, self._describe_error(e),
            )

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Asynchronous ──────────────────────────────────────────

    async def arun(self, task: str) -> AgentResult:
        """Async counterpart of ``run``; provider calls are awaited.

        Tool handlers are still executed synchronously.
        """
        t0 = time.monotonic()
        steps: list[AgentStep] = []
        tools = self._executor.get_schemas()
        messages = self._initial_messages(task)
        final_answer = ""
        total_tokens = 0
        total_cost = 0.0

        try:
            for step_num in range(1, self._config.max_steps + 1):
                result = await self._acall_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                steps.append(step)
                if done:
                    final_answer = final
                    break
                self._record_turn(messages, result, step, task, step_num)
            else:
                return self._make_result(
                    False, "", steps, total_tokens, total_cost, t0,
                    self._max_steps_error(),
                )
        except Exception as e:
            return self._make_result(
                False, "", steps, total_tokens, total_cost, t0, self._describe_error(e),
            )

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Shared step logic ─────────────────────────────────────

    def _initial_messages(self, task: str) -> list[Message]:
        """Build the opening conversation: system prompt followed by the user task."""
        return [
            Message(role=MessageRole.SYSTEM, content=self._system_prompt),
            Message(role=MessageRole.USER, content=task),
        ]

    def _record_turn(self, messages, result, step, task, step_num) -> None:
        """Append the assistant turn and its tool outputs, then checkpoint."""
        messages.append(result.choices[0].message)
        messages.extend(
            Message(
                role=MessageRole.TOOL,
                content=step.tool_results.get(tc.id, ""),
                tool_call_id=tc.id,
            )
            for tc in step.tool_calls
        )
        self._checkpoint(messages, task, step_num)

    def _max_steps_error(self) -> str:
        """Error text used when the step budget runs out."""
        return f"Reached max steps ({self._config.max_steps}) without final answer"

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return the exception message, or its class name when the message is empty."""
        return str(exc) or type(exc).__name__

    @staticmethod
    def _is_tool_error(tool_result: str) -> bool:
        """Tell whether a tool output is an error payload.

        Only a JSON object with a non-null ``"error"`` key counts. Plain text
        that merely contains the word "error" is treated as a normal result.
        """
        try:
            payload = json.loads(tool_result)
        except (TypeError, ValueError):
            return False
        return isinstance(payload, dict) and payload.get("error") is not None

    def _process_step(
        self, result: CompletionResult, step_num: int,
    ) -> tuple[AgentStep, bool, str]:
        """Turn one LLM result into an AgentStep, running any requested tools.

        Returns:
            (step, done, final_answer). ``done`` is true when the reply has no
            tool calls or its finish reason is ``"stop"``; ``final_answer`` is
            then the assistant content, otherwise an empty string.

        Raises:
            RuntimeError: A tool returned an error payload and
                ``stop_on_error`` is enabled.
        """
        step_t0 = time.monotonic()
        choice = result.choices[0]
        assistant_msg = choice.message
        # Content may be missing on tool-call turns; normalise to a string.
        content = assistant_msg.content or ""
        tool_calls = assistant_msg.tool_calls or []

        step = AgentStep(
            step=step_num,
            thought=content,
            tool_calls=tool_calls,
            finish_reason=choice.finish_reason,
            tokens_used=result.usage.total_tokens,
            cost_usd=result.usage.cost_usd,
        )

        try:
            for tc in tool_calls:
                output = self._executor.execute(tc)
                step.tool_results[tc.id] = output
                if self._config.stop_on_error and self._is_tool_error(output):
                    raise RuntimeError(f"Tool '{tc.name}' error: {output}")
        finally:
            # Measured after the tools ran so the figure covers tool execution.
            # The LLM call itself happens before this method and is not included.
            step.duration_ms = (time.monotonic() - step_t0) * 1000
            if self._config.verbose:
                self._log_step(step)

        if not tool_calls or choice.finish_reason == "stop":
            return step, True, content
        return step, False, ""

    def _make_result(
        self, success: bool, answer: str, steps: list[AgentStep],
        total_tokens: int, total_cost: float, t0: float, error: str | None = None,
    ) -> AgentResult:
        """Build an AgentResult; the duration is measured from *t0* to now."""
        return AgentResult(
            success=success,
            final_answer=answer,
            steps=steps,
            total_steps=len(steps),
            total_tokens=total_tokens,
            total_cost_usd=total_cost,
            total_duration_ms=(time.monotonic() - t0) * 1000,
            error=error,
        )

    # ── Checkpoint / Resume ───────────────────────────────────

    def _checkpoint_path(self) -> str:
        """Path of the checkpoint file inside the configured directory."""
        return os.path.join(self._config.checkpoint_dir, self._CHECKPOINT_FILENAME)

    def resume(self) -> AgentResult:
        """Continue a run from the saved checkpoint.

        The step counter restarts after the checkpointed step. The returned
        result only contains the steps executed after resuming.

        Raises:
            ValueError: ``checkpoint_dir`` is not configured.
            FileNotFoundError: No checkpoint file exists.
        """
        if not self._config.checkpoint_dir:
            raise ValueError("checkpoint_dir not configured")
        ckpt_path = self._checkpoint_path()
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"No checkpoint found at {ckpt_path}")
        try:
            with open(ckpt_path, encoding="utf-8") as f:
                data = json.load(f)
        except UnicodeDecodeError:
            # Older checkpoints were written with the platform default encoding.
            with open(ckpt_path) as f:
                data = json.load(f)

        task = data["task"]
        start_step = data["step"] + 1
        messages = [
            Message(
                role=MessageRole(m["role"]),
                content=m["content"],
                tool_call_id=m.get("tool_call_id"),
                tool_calls=[ToolCall(**tc) for tc in m["tool_calls"]] if m.get("tool_calls") else None,
            )
            for m in data["messages"]
        ]
        t0 = time.monotonic()
        tools = self._executor.get_schemas()
        steps: list[AgentStep] = []
        return self._run_loop(messages, task, tools, steps, start_step, t0)

    def _checkpoint(self, messages: list[Message], task: str, step: int) -> None:
        """Persist the conversation so ``resume`` can continue after *step*.

        Does nothing when ``checkpoint_dir`` is empty.
        """
        if not self._config.checkpoint_dir:
            return
        os.makedirs(self._config.checkpoint_dir, exist_ok=True)
        ckpt_path = self._checkpoint_path()
        data = {
            "task": task,
            "step": step,
            "messages": [
                {
                    "role": m.role.value,
                    "content": m.content,
                    "tool_call_id": m.tool_call_id,
                    "tool_calls": [
                        {"id": tc.id, "name": tc.name, "arguments": tc.arguments}
                        for tc in m.tool_calls
                    ] if m.tool_calls else None,
                }
                for m in messages
            ],
        }
        # Write to a sibling file and swap it in, so an interrupted write never
        # leaves a truncated checkpoint behind.
        tmp_path = f"{ckpt_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
        os.replace(tmp_path, ckpt_path)

    # ── Internal helpers ──────────────────────────────────────

    def _call_with_retry(self, messages: list[Message], tools: list[Tool]) -> CompletionResult:
        """Call the provider, retrying up to ``max_retries`` times on any exception.

        At least one attempt is always made. The last exception is re-raised
        once the retries are used up.
        """
        attempt = 0
        while True:
            try:
                return self._provider.chat(
                    messages,
                    temperature=self._config.temperature,
                    max_tokens=self._config.max_tokens,
                    tools=tools or None,
                )
            except Exception:
                if attempt >= self._config.max_retries:
                    raise
                attempt += 1
                time.sleep(self._config.retry_delay)

    async def _acall_with_retry(self, messages: list[Message], tools: list[Tool]) -> CompletionResult:
        """Async variant of ``_call_with_retry`` using ``achat`` and ``asyncio.sleep``."""
        import asyncio

        attempt = 0
        while True:
            try:
                return await self._provider.achat(
                    messages,
                    temperature=self._config.temperature,
                    max_tokens=self._config.max_tokens,
                    tools=tools or None,
                )
            except Exception:
                if attempt >= self._config.max_retries:
                    raise
                attempt += 1
                await asyncio.sleep(self._config.retry_delay)

    def _log_step(self, step: AgentStep) -> None:
        """Print a human-readable summary of *step* to stdout."""
        print(f"\n── Step {step.step} ({step.duration_ms:.0f}ms, {step.tokens_used}t, ${step.cost_usd:.6f}) ──")
        if step.thought:
            print(f" Thought: {step.thought}")
        for tc in step.tool_calls:
            # Tool output is truncated to keep the log line short.
            result_preview = step.tool_results.get(tc.id, "")[:100]
            print(f" Tool: {tc.name}({tc.arguments}) → {result_preview}")
        print(f" Finish: {step.finish_reason}")