Coverage for agentos/agent/tool_agent.py: 25%
281 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 13:55 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 13:55 +0800
1"""
2Tool-Using Agent — 基于 LLM Function Calling 的自主 Agent 循环。
4核心模式:
5 用户任务 → LLM 推理(tool_calls) → 工具执行 → 结果回传 → 循环直到完成
7v1.16.1: +CircuitBreaker +ToolOutputValidator +Metrics integrated into ToolExecutor/Agent.
8v1.3.38: +streaming, retry, checkpoint/resume, tool error handling, mock provider.
9"""
11from __future__ import annotations
13import json
14import os
15import time
16from dataclasses import dataclass, field
17from typing import Any, Callable, Generator, Optional
19from agentos.llm.base import (
20 LLMProvider,
21 Message,
22 MessageRole,
23 CompletionResult,
24 CompletionChoice,
25 CompletionUsage,
26 Tool,
27 ToolCall,
28)
29from agentos.tools.circuit_breaker import CircuitBreaker, CircuitState
30from agentos.tools.validation import ToolOutputValidator, ToolResult, ValidationResult
31from agentos.tools.metrics import MetricsCollector, Counter, Timer
33__all__ = [
34 "ToolAgent",
35 "AgentConfig",
36 "AgentStep",
37 "AgentResult",
38 "ToolExecutor",
39 "MockLLMProvider",
40]
# ── Data types ───────────────────────────────────────────────────
@dataclass
class AgentConfig:
    """Tunable settings for a ``ToolAgent`` run.

    Field order is part of the positional constructor signature and must
    not be rearranged.
    """

    # Upper bound on LLM round-trips in one run.
    max_steps: int = 10
    # Sampling temperature forwarded to the provider on every call.
    temperature: float = 0.0
    # Completion token limit forwarded to the provider on every call.
    max_tokens: int = 4096
    # When true, each step is printed as it is processed.
    verbose: bool = False
    # When true, a tool result that reports an error aborts the run.
    stop_on_error: bool = True
    # Number of extra provider attempts after the first one fails.
    max_retries: int = 2
    # Seconds to sleep between provider attempts.
    retry_delay: float = 0.5
    # Directory for the checkpoint file; an empty string disables checkpoints.
    checkpoint_dir: str = ""
@dataclass
class AgentStep:
    """Record of a single agent iteration: one LLM reply plus its tool runs."""

    # 1-based index of this step within the run.
    step: int
    # Text content of the assistant message for this step.
    thought: str = ""
    # Tool calls requested by the assistant in this step.
    tool_calls: list[ToolCall] = field(default_factory=list)
    # Tool output keyed by the id of the tool call that produced it.
    tool_results: dict[str, str] = field(default_factory=dict)
    # Finish reason reported by the provider for this completion.
    finish_reason: str = ""
    # Token usage reported by the provider for this completion.
    tokens_used: int = 0
    # Cost reported by the provider for this completion.
    cost_usd: float = 0.0
    # Wall-clock time attributed to this step, in milliseconds.
    duration_ms: float = 0.0
@dataclass
class AgentResult:
    """Aggregate outcome of an agent run."""

    # False when the run ended with an error or ran out of steps.
    success: bool = True
    # Final assistant answer; empty when the run did not succeed.
    final_answer: str = ""
    # Steps recorded during the run, in execution order.
    steps: list[AgentStep] = field(default_factory=list)
    # Number of recorded steps.
    total_steps: int = 0
    # Sum of the token usage of the recorded steps.
    total_tokens: int = 0
    # Sum of the cost of the recorded steps.
    total_cost_usd: float = 0.0
    # Wall-clock duration of the run, in milliseconds.
    total_duration_ms: float = 0.0
    # Failure description, or None when the run succeeded.
    error: str | None = None
# ── Tool executor ────────────────────────────────────────────────
class ToolExecutor:
    """Registry of tool handlers plus the logic to run a tool call.

    v1.16.1: optionally integrates a CircuitBreaker, a ToolOutputValidator
    and a MetricsCollector. Every collaborator is optional; when none is
    supplied the executor simply looks up the handler and calls it.
    """

    def __init__(
        self,
        circuit_breaker: Optional[CircuitBreaker] = None,
        validator: Optional[ToolOutputValidator] = None,
        metrics: Optional[MetricsCollector] = None,
    ):
        # Handler and schema are stored separately, both keyed by tool name.
        self._tools: dict[str, Callable[..., str]] = {}
        self._schemas: dict[str, Tool] = {}
        self._cb = circuit_breaker
        self._validator = validator
        self._metrics = metrics

    def register(self, tool: Tool, handler: Callable[..., str]) -> None:
        """Register ``handler`` under ``tool.function.name``.

        Registering the same name again replaces the earlier entry.
        """
        name = tool.function.name
        self._tools[name] = handler
        self._schemas[name] = tool

    def get_schemas(self) -> list[Tool]:
        """Return the registered tool schemas in registration order."""
        return list(self._schemas.values())

    def execute(self, tool_call: ToolCall) -> str:
        """Run one tool call and return its output as a string.

        Pipeline: circuit-breaker gate, handler call, output validation,
        breaker bookkeeping, metrics. This method does not raise for
        handler failures: unknown tools, rejected calls and exceptions are
        all reported as a JSON object containing an ``"error"`` key.

        Args:
            tool_call: The call to run. Its ``name`` selects the handler and
                its ``parsed_arguments`` are passed as keyword arguments.

        Returns:
            The handler output converted with ``str``. When the validator
            reports issues they are appended as `` [validation: ...]``.
            On failure, a JSON error object is returned instead.
        """
        handler = self._tools.get(tool_call.name)
        if handler is None:
            return json.dumps({"error": f"Unknown tool: {tool_call.name}"})

        # 1. Circuit-breaker gate. A rejected call never reached the tool, so
        #    it is deliberately not reported via record_failure(); doing so
        #    would count the breaker's own rejections as fresh failures.
        #    NOTE(review): assumes record_failure() feeds the breaker's
        #    failure count / recovery timing — confirm against CircuitBreaker.
        if self._cb is not None and not self._cb.allow_request():
            return json.dumps({
                "error": f"Circuit breaker OPEN for '{self._cb.name}'",
                "circuit_state": self._cb.state.name,
            })

        t0 = time.monotonic()
        try:
            raw_output = str(handler(**tool_call.parsed_arguments))

            # 2. Output validation: issues are surfaced to the caller as a
            #    suffix, the output itself is never dropped.
            validation_msg = ""
            if self._validator is not None:
                tr = ToolResult(output=raw_output, tool_name=tool_call.name)
                val_result: ValidationResult = self._validator.validate(tr)
                if not val_result.is_valid:
                    issues = "; ".join(i.message for i in val_result.issues)
                    validation_msg = f" [validation: {issues}]"

            # 3. Tell the breaker the call went through.
            if self._cb is not None:
                self._cb.record_success()

            # 4. Metrics. Latency is only recorded for successful calls.
            elapsed = (time.monotonic() - t0) * 1000
            if self._metrics is not None:
                self._metrics.get_counter("tool_calls_total").inc(tool_call.name)
                self._metrics.get_counter("tool_calls_success").inc(tool_call.name)
                self._metrics.get_timer("tool_latency_ms").record(elapsed)

            return raw_output + validation_msg

        except Exception as e:
            # Anything raised above (handler, validator or metrics) is
            # treated as a failed tool call and reported as JSON.
            if self._cb is not None:
                self._cb.record_failure()
            if self._metrics is not None:
                self._metrics.get_counter("tool_calls_total").inc(tool_call.name)
                self._metrics.get_counter("tool_calls_errors").inc(tool_call.name)
            return json.dumps({"error": str(e)})
# ── MockLLMProvider ──────────────────────────────────────────────
class MockLLMProvider(LLMProvider):
    """Scripted provider for integration tests.

    Responses are plain dicts (see ``text_response`` / ``tool_response``)
    that are replayed in order, one per ``chat`` call. Once the script is
    exhausted every further call answers ``"done"`` with finish reason
    ``"stop"``.
    """

    def __init__(self, responses: list[dict]):
        super().__init__(model="mock", api_key="mock")
        self._responses = responses
        self._cursor = 0
        # One entry per scripted response served: the tool names offered and
        # the index of the response. Calls made after the script is
        # exhausted are not recorded.
        self.calls: list[dict] = []

    def chat(self, messages=None, *, temperature=0, max_tokens=4096, tools=None, **kwargs):
        """Return the next scripted response as a ``CompletionResult``.

        ``messages``, ``temperature``, ``max_tokens`` and extra keyword
        arguments are accepted for interface compatibility and ignored.
        """
        if self._cursor >= len(self._responses):
            return self._build_result({"content": "done", "finish_reason": "stop"})
        index = self._cursor
        resp = self._responses[index]
        self._cursor = index + 1
        self.calls.append({
            "tools": [t.function.name for t in (tools or [])],
            "cursor": index,
        })
        return self._build_result(resp)

    async def achat(self, *args, **kwargs):
        """Async variant; delegates to the synchronous ``chat``."""
        return self.chat(*args, **kwargs)

    @property
    def provider_name(self) -> str:
        return "mock"

    @staticmethod
    def text_response(content: str, finish_reason: str = "stop") -> dict:
        """Build a scripted plain-text response."""
        return {"content": content, "finish_reason": finish_reason}

    @staticmethod
    def tool_response(name: str, arguments: dict, tool_call_id: str = "") -> dict:
        """Build a scripted response that requests one tool call.

        The call id defaults to ``tc_<name>`` when ``tool_call_id`` is empty.
        """
        tid = tool_call_id or f"tc_{name}"
        return {
            "content": "",
            "tool_calls": [ToolCall(id=tid, name=name, arguments=json.dumps(arguments))],
            "finish_reason": "tool_calls",
        }

    def _build_result(self, resp: dict) -> CompletionResult:
        """Wrap a scripted response dict in a ``CompletionResult``.

        Token usage is synthetic: 5 prompt tokens, and completion tokens
        derived from the content length.
        """
        content = resp.get("content", "")
        # A scripted response may carry ``"content": None`` (as tool-call
        # replies often do); measure it as empty instead of raising.
        content_len = len(content or "")
        msg = Message(
            role=MessageRole.ASSISTANT,
            content=content,
            tool_calls=resp.get("tool_calls"),
        )
        choice = CompletionChoice(
            index=0,
            message=msg,
            finish_reason=resp.get("finish_reason", "stop"),
        )
        return CompletionResult(
            id=f"mock_{self._cursor}",
            model="mock-model",
            choices=[choice],
            usage=CompletionUsage(
                prompt_tokens=5,
                completion_tokens=content_len + 3,
                total_tokens=content_len + 8,
            ),
        )
# ── Tool-Using Agent ─────────────────────────────────────────────
class ToolAgent:
    """Autonomous agent built on LLM function calling.

    Loop: send the conversation to the provider, run any tool calls it
    requests, feed the results back, and repeat until the model answers
    without tool calls or the step budget is exhausted.

    Usage:
        from agentos.agent import ToolAgent, ToolExecutor
        from agentos.llm import create_provider, Tool

        provider = create_provider("openai")
        executor = ToolExecutor()
        executor.register(
            Tool.from_function("get_weather", "Get the weather", {"city": ...}),
            lambda city: f"{city}: 22°C sunny"
        )
        agent = ToolAgent(provider, executor)
        result = agent.run("What is the weather in Beijing?")
        print(result.final_answer)
    """

    def __init__(
        self,
        provider: LLMProvider,
        tool_executor: ToolExecutor,
        *,
        config: AgentConfig | None = None,
        system_prompt: str = "",
        metrics: Optional[MetricsCollector] = None,
    ):
        self._provider = provider
        self._executor = tool_executor
        self._config = config or AgentConfig()
        # An empty system_prompt falls back to the built-in default prompt.
        self._system_prompt = system_prompt or (
            "你是一个智能助手。你可以使用工具来获取信息。"
            "当你可以给出最终答案时,直接回答,不要再调用工具。"
            "用中文回答。"
        )
        self._metrics = metrics

    # ── Synchronous ───────────────────────────────────────────

    def run(self, task: str) -> AgentResult:
        """Run the agent loop for ``task`` and return the aggregate result.

        Failures are reported through ``AgentResult.success`` / ``error``
        rather than raised.
        """
        t0 = time.monotonic()
        steps: list[AgentStep] = []
        tools = self._executor.get_schemas()
        messages = self._initial_messages(task)
        return self._run_loop(messages, task, tools, steps, 1, t0)

    def _run_loop(
        self, messages, task, tools, steps, start_step, t0,
    ) -> AgentResult:
        """Drive the step loop from ``start_step`` up to ``max_steps``.

        Shared by ``run`` and ``resume``. ``messages`` and ``steps`` are
        extended in place.
        """
        final_answer = ""
        total_tokens = 0
        total_cost = 0.0

        try:
            for step_num in range(start_step, self._config.max_steps + 1):
                result = self._call_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                steps.append(step)
                if done:
                    final_answer = final
                    break
                self._extend_history(messages, result, step)
                self._checkpoint(messages, task, step_num)
            else:
                # Loop ran out of steps without a final answer.
                return self._max_steps_result(steps, total_tokens, total_cost, t0)
        except Exception as e:
            return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e))

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Streaming ─────────────────────────────────────────────

    def run_stream(self, task: str) -> Generator[AgentStep, None, AgentResult]:
        """Like ``run``, but yield each ``AgentStep`` as it completes.

        The final ``AgentResult`` is the generator's return value (available
        as ``StopIteration.value``).
        """
        t0 = time.monotonic()
        steps: list[AgentStep] = []
        tools = self._executor.get_schemas()
        messages = self._initial_messages(task)
        total_tokens = 0
        total_cost = 0.0
        final_answer = ""

        try:
            for step_num in range(1, self._config.max_steps + 1):
                result = self._call_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                steps.append(step)
                yield step
                if done:
                    final_answer = final
                    break
                self._extend_history(messages, result, step)
                self._checkpoint(messages, task, step_num)
            else:
                return self._max_steps_result(steps, total_tokens, total_cost, t0)
        except Exception as e:
            return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e))

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Asynchronous ──────────────────────────────────────────

    async def arun(self, task: str) -> AgentResult:
        """Async counterpart of ``run``; only the provider call is awaited.

        Tool handlers and checkpoint writes still run synchronously.
        """
        t0 = time.monotonic()
        steps: list[AgentStep] = []
        tools = self._executor.get_schemas()
        messages = self._initial_messages(task)
        final_answer = ""
        total_tokens = 0
        total_cost = 0.0

        try:
            for step_num in range(1, self._config.max_steps + 1):
                result = await self._acall_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                steps.append(step)
                if done:
                    final_answer = final
                    break
                self._extend_history(messages, result, step)
                self._checkpoint(messages, task, step_num)
            else:
                return self._max_steps_result(steps, total_tokens, total_cost, t0)
        except Exception as e:
            return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e))

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Shared step logic ─────────────────────────────────────

    def _initial_messages(self, task: str) -> list[Message]:
        """Build the opening conversation: system prompt, then the user task."""
        return [
            Message(role=MessageRole.SYSTEM, content=self._system_prompt),
            Message(role=MessageRole.USER, content=task),
        ]

    def _extend_history(self, messages: list[Message], result: CompletionResult, step: AgentStep) -> None:
        """Append the assistant reply and one TOOL message per tool call."""
        messages.append(result.choices[0].message)
        for tc in step.tool_calls:
            messages.append(Message(
                role=MessageRole.TOOL,
                content=step.tool_results.get(tc.id, ""),
                tool_call_id=tc.id,
            ))

    def _max_steps_result(
        self, steps: list[AgentStep], total_tokens: int, total_cost: float, t0: float,
    ) -> AgentResult:
        """Build the failure result used when the step budget is exhausted."""
        return self._make_result(
            False, "", steps, total_tokens, total_cost, t0,
            f"Reached max steps ({self._config.max_steps}) without final answer",
        )

    @staticmethod
    def _is_error_output(output: str) -> bool:
        """Return True when ``output`` is a JSON object with an ``"error"`` key.

        This is the error envelope ``ToolExecutor.execute`` produces. A plain
        substring test would also flag ordinary tool output that merely
        contains the word "error".
        """
        try:
            payload = json.loads(output)
        except (TypeError, ValueError):
            return False
        return isinstance(payload, dict) and "error" in payload

    def _process_step(
        self, result: CompletionResult, step_num: int,
    ) -> tuple[AgentStep, bool, str]:
        """Turn one completion into an ``AgentStep`` and run its tool calls.

        Returns:
            ``(step, done, final_answer)``. ``done`` is True when the reply
            has no tool calls or its finish reason is ``"stop"``; in that
            case ``final_answer`` is the assistant content, otherwise ``""``.

        Raises:
            RuntimeError: A tool reported an error and ``stop_on_error`` is set.
        """
        step_t0 = time.monotonic()
        choice = result.choices[0]
        assistant_msg = choice.message

        step = AgentStep(
            step=step_num,
            thought=assistant_msg.content,
            tool_calls=assistant_msg.tool_calls or [],
            finish_reason=choice.finish_reason,
            tokens_used=result.usage.total_tokens,
            cost_usd=result.usage.cost_usd,
        )

        # Metrics: track LLM calls and tokens.
        if self._metrics is not None:
            self._metrics.get_counter("llm_calls_total").inc()
            self._metrics.get_counter("llm_tokens_total").inc(result.usage.total_tokens)
            self._metrics.get_counter("agent_steps_total").inc()

        try:
            # No tool calls: the content is the final answer.
            if not assistant_msg.tool_calls:
                return step, True, assistant_msg.content

            for tc in assistant_msg.tool_calls:
                tool_result = self._executor.execute(tc)
                step.tool_results[tc.id] = tool_result
                if self._config.stop_on_error and self._is_error_output(tool_result):
                    raise RuntimeError(f"Tool '{tc.name}' error: {tool_result}")
        finally:
            # Measured and logged after the tools ran, so the duration covers
            # tool execution and the log shows the tool outputs. Covers the
            # processing of this step only, not the provider call before it.
            step.duration_ms = (time.monotonic() - step_t0) * 1000
            if self._config.verbose:
                self._log_step(step)

        # finish_reason == "stop" ends the run even though tools were called.
        if choice.finish_reason == "stop":
            return step, True, assistant_msg.content

        return step, False, ""

    def _make_result(
        self, success: bool, answer: str, steps: list[AgentStep],
        total_tokens: int, total_cost: float, t0: float, error: str | None = None,
    ) -> AgentResult:
        """Assemble an ``AgentResult``; the duration is measured from ``t0``."""
        return AgentResult(
            success=success,
            final_answer=answer,
            steps=steps,
            total_steps=len(steps),
            total_tokens=total_tokens,
            total_cost_usd=total_cost,
            total_duration_ms=(time.monotonic() - t0) * 1000,
            error=error,
        )

    # ── Checkpoint / Resume ───────────────────────────────────

    def resume(self) -> AgentResult:
        """Continue a run from the checkpoint file in ``checkpoint_dir``.

        The loop restarts at the step after the checkpointed one. Steps and
        totals from before the checkpoint are not restored, so the returned
        result only covers the resumed part.

        Raises:
            ValueError: ``checkpoint_dir`` is not configured.
            FileNotFoundError: No checkpoint file exists.
        """
        if not self._config.checkpoint_dir:
            raise ValueError("checkpoint_dir not configured")
        ckpt_path = os.path.join(self._config.checkpoint_dir, "agent_checkpoint.json")
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"No checkpoint found at {ckpt_path}")
        # Checkpoints hold non-ASCII text; read them as UTF-8 explicitly
        # instead of relying on the platform default encoding.
        with open(ckpt_path, encoding="utf-8") as f:
            data = json.load(f)
        task = data["task"]
        start_step = data["step"] + 1
        messages = [
            Message(
                role=MessageRole(m["role"]),
                content=m["content"],
                tool_call_id=m.get("tool_call_id"),
                tool_calls=[ToolCall(**tc) for tc in m["tool_calls"]] if m.get("tool_calls") else None,
            )
            for m in data["messages"]
        ]
        t0 = time.monotonic()
        tools = self._executor.get_schemas()
        steps: list[AgentStep] = []
        return self._run_loop(messages, task, tools, steps, start_step, t0)

    def _checkpoint(self, messages: list[Message], task: str, step: int) -> None:
        """Persist the conversation after ``step``; no-op without ``checkpoint_dir``.

        The directory is created on demand. The file is written to a
        temporary path and then renamed, so an interrupted write cannot
        leave a truncated checkpoint behind.
        """
        ckpt_dir = self._config.checkpoint_dir
        if not ckpt_dir:
            return
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, "agent_checkpoint.json")
        data = {
            "task": task,
            "step": step,
            "messages": [
                {
                    "role": m.role.value,
                    "content": m.content,
                    "tool_call_id": m.tool_call_id,
                    "tool_calls": [
                        {"id": tc.id, "name": tc.name, "arguments": tc.arguments}
                        for tc in m.tool_calls
                    ] if m.tool_calls else None,
                }
                for m in messages
            ],
        }
        tmp_path = ckpt_path + ".tmp"
        # ensure_ascii=False emits raw non-ASCII characters, so the encoding
        # must be pinned rather than left to the locale.
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
        os.replace(tmp_path, ckpt_path)

    # ── Internal helpers ──────────────────────────────────────

    def _call_with_retry(self, messages: list[Message], tools: list[Tool]) -> CompletionResult:
        """Call the provider, retrying up to ``max_retries`` times on any error.

        At least one attempt is always made, even if ``max_retries`` is
        negative. The last error is re-raised once attempts are exhausted.
        """
        attempts = max(self._config.max_retries, 0) + 1
        last_error = None
        for attempt in range(attempts):
            try:
                return self._provider.chat(
                    messages,
                    temperature=self._config.temperature,
                    max_tokens=self._config.max_tokens,
                    tools=tools if tools else None,
                )
            except Exception as e:
                last_error = e
                # No sleep after the final attempt.
                if attempt < attempts - 1:
                    time.sleep(self._config.retry_delay)
        raise last_error  # type: ignore

    async def _acall_with_retry(self, messages: list[Message], tools: list[Tool]) -> CompletionResult:
        """Async counterpart of ``_call_with_retry`` using ``asyncio.sleep``."""
        import asyncio

        attempts = max(self._config.max_retries, 0) + 1
        last_error = None
        for attempt in range(attempts):
            try:
                return await self._provider.achat(
                    messages,
                    temperature=self._config.temperature,
                    max_tokens=self._config.max_tokens,
                    tools=tools if tools else None,
                )
            except Exception as e:
                last_error = e
                if attempt < attempts - 1:
                    await asyncio.sleep(self._config.retry_delay)
        raise last_error  # type: ignore

    def _log_step(self, step: AgentStep) -> None:
        """Print a human-readable summary of ``step`` to stdout."""
        print(f"\n── Step {step.step} ({step.duration_ms:.0f}ms, {step.tokens_used}t, ${step.cost_usd:.6f}) ──")
        if step.thought:
            print(f"  Thought: {step.thought}")
        if step.tool_calls:
            for tc in step.tool_calls:
                # Long tool output is truncated to keep the log readable.
                result_preview = step.tool_results.get(tc.id, "")[:100]
                print(f"  Tool: {tc.name}({tc.arguments}) → {result_preview}")
        print(f"  Finish: {step.finish_reason}")