Coverage for agentos/agent/tool_agent.py: 25%

281 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 13:55 +0800

1""" 

2Tool-Using Agent — 基于 LLM Function Calling 的自主 Agent 循环。 

3 

4核心模式: 

5 用户任务 → LLM 推理(tool_calls) → 工具执行 → 结果回传 → 循环直到完成 

6 

7v1.16.1: +CircuitBreaker +ToolOutputValidator +Metrics integrated into ToolExecutor/Agent. 

8v1.3.38: +streaming, retry, checkpoint/resume, tool error handling, mock provider. 

9""" 

10 

11from __future__ import annotations 

12 

13import json 

14import os 

15import time 

16from dataclasses import dataclass, field 

17from typing import Any, Callable, Generator, Optional 

18 

19from agentos.llm.base import ( 

20 LLMProvider, 

21 Message, 

22 MessageRole, 

23 CompletionResult, 

24 CompletionChoice, 

25 CompletionUsage, 

26 Tool, 

27 ToolCall, 

28) 

29from agentos.tools.circuit_breaker import CircuitBreaker, CircuitState 

30from agentos.tools.validation import ToolOutputValidator, ToolResult, ValidationResult 

31from agentos.tools.metrics import MetricsCollector, Counter, Timer 

32 

# Names exported by ``from <module> import *``; everything else is internal.
__all__ = [
    "ToolAgent",
    "AgentConfig",
    "AgentStep",
    "AgentResult",
    "ToolExecutor",
    "MockLLMProvider",
]

41 

42 

43# ── 数据类型 ───────────────────────────────────────────────────── 

44 

@dataclass
class AgentConfig:
    """Settings that control a single agent run."""

    # Highest step number the loop will execute before giving up.
    max_steps: int = 10
    # Sampling temperature forwarded to the provider on every call.
    temperature: float = 0.0
    # Completion token limit forwarded to the provider on every call.
    max_tokens: int = 4096
    # When true, every processed step is printed to stdout.
    verbose: bool = False
    # When true, a tool result flagged as an error aborts the run.
    stop_on_error: bool = True
    # Extra provider attempts made after the first one fails.
    max_retries: int = 2
    # Seconds to sleep between provider attempts.
    retry_delay: float = 0.5
    # Directory holding the checkpoint file; empty disables checkpointing.
    checkpoint_dir: str = ""

55 

56 

@dataclass
class AgentStep:
    """Record of one loop iteration: the model reply plus any tool activity."""

    # Step number assigned by the driving loop.
    step: int
    # Text content of the assistant message.
    thought: str = ""
    # Tool calls requested by the assistant in this step (may be empty).
    tool_calls: list[ToolCall] = field(default_factory=list)
    # Tool outputs keyed by the id of the tool call that produced them.
    tool_results: dict[str, str] = field(default_factory=dict)
    # Finish reason reported by the provider for this completion.
    finish_reason: str = ""
    # Total tokens reported by the provider for this completion.
    tokens_used: int = 0
    # Cost reported by the provider for this completion.
    cost_usd: float = 0.0
    # Time spent processing the step, in milliseconds.
    duration_ms: float = 0.0

67 

68 

@dataclass
class AgentResult:
    """Outcome of an agent run, including the per-step trace and totals."""

    # False when the run hit the step limit or raised an exception.
    success: bool = True
    # Content of the terminating assistant message; empty on failure.
    final_answer: str = ""
    # Steps recorded during the run, in execution order.
    steps: list[AgentStep] = field(default_factory=list)
    # Number of recorded steps.
    total_steps: int = 0
    # Sum of tokens across recorded steps.
    total_tokens: int = 0
    # Sum of cost across recorded steps.
    total_cost_usd: float = 0.0
    # Elapsed time for the whole run, in milliseconds.
    total_duration_ms: float = 0.0
    # Failure description, or None when the run succeeded.
    error: str | None = None

79 

80 

81# ── 工具执行器 ─────────────────────────────────────────────────── 

82 

class ToolExecutor:
    """Registry and dispatcher for tool handlers.

    Each call can optionally be guarded by a circuit breaker, checked by an
    output validator and counted by a metrics collector. All three
    collaborators are optional; when none is supplied the executor simply
    looks up the handler and calls it.
    """

    def __init__(
        self,
        circuit_breaker: Optional[CircuitBreaker] = None,
        validator: Optional[ToolOutputValidator] = None,
        metrics: Optional[MetricsCollector] = None,
    ):
        """Create an empty registry with the given optional collaborators."""
        self._tools: dict[str, Callable[..., str]] = {}
        self._schemas: dict[str, Tool] = {}
        self._cb = circuit_breaker
        self._validator = validator
        self._metrics = metrics

    def register(self, tool: Tool, handler: Callable[..., str]) -> None:
        """Register ``handler`` under the function name declared by ``tool``.

        Registering the same name twice replaces the earlier entry.
        """
        name = tool.function.name
        self._tools[name] = handler
        self._schemas[name] = tool

    def get_schemas(self) -> list[Tool]:
        """Return the registered tool schemas as a new list."""
        return list(self._schemas.values())

    def execute(self, tool_call: ToolCall) -> str:
        """Run one tool call and return its output as a string.

        Pipeline: circuit-breaker gate -> handler -> validator -> metrics.

        Failures never raise. Unknown tools, rejected requests and handler
        exceptions are all reported as a JSON object with an ``"error"`` key.
        When the validator reports issues, they are appended to the raw
        output as a ``" [validation: ...]"`` suffix.
        """
        handler = self._tools.get(tool_call.name)
        if handler is None:
            return json.dumps({"error": f"Unknown tool: {tool_call.name}"})

        # 1. Circuit-breaker gate. A rejected request is NOT reported as a
        #    failure: the handler never ran, and feeding rejections back into
        #    the breaker would let traffic alone keep it from recovering.
        if self._cb is not None and not self._cb.allow_request():
            return json.dumps({
                "error": f"Circuit breaker OPEN for '{self._cb.name}'",
                "circuit_state": self._cb.state.name,
            })

        t0 = time.monotonic()
        try:
            raw_output = str(handler(**tool_call.parsed_arguments))

            # 2. Output validation: issues annotate the output, they do not
            #    turn the call into a failure.
            validation_msg = ""
            if self._validator is not None:
                tr = ToolResult(output=raw_output, tool_name=tool_call.name)
                val_result: ValidationResult = self._validator.validate(tr)
                if not val_result.is_valid:
                    issues = "; ".join(i.message for i in val_result.issues)
                    validation_msg = f" [validation: {issues}]"

            # 3. Tell the breaker the handler completed.
            if self._cb is not None:
                self._cb.record_success()

            # 4. Metrics for the successful call.
            # NOTE(review): inc() receives the tool name here; confirm that
            # this matches the counter API (label vs. amount).
            elapsed = (time.monotonic() - t0) * 1000
            if self._metrics is not None:
                self._metrics.get_counter("tool_calls_total").inc(tool_call.name)
                self._metrics.get_counter("tool_calls_success").inc(tool_call.name)
                self._metrics.get_timer("tool_latency_ms").record(elapsed)

            return raw_output + validation_msg

        except Exception as e:
            # Anything raised above (argument parsing, the handler itself,
            # validation, metrics) counts as a failed call.
            if self._cb is not None:
                self._cb.record_failure()
            if self._metrics is not None:
                self._metrics.get_counter("tool_calls_total").inc(tool_call.name)
                self._metrics.get_counter("tool_calls_errors").inc(tool_call.name)
            return json.dumps({"error": str(e)})

158 

159 

160# ── MockLLMProvider ────────────────────────────────────────────── 

161 

class MockLLMProvider(LLMProvider):
    """Scripted provider for integration tests.

    Replays the supplied response dicts in order, one per ``chat`` call, and
    records which tools were offered on each scripted call in ``calls``.
    """

    def __init__(self, responses: list[dict]):
        """Store the scripted ``responses`` and reset the replay position."""
        super().__init__(model="mock", api_key="mock")
        self._responses = responses
        self._cursor = 0
        self.calls: list[dict] = []

    def chat(self, messages=None, *, temperature=0, max_tokens=4096, tools=None, **kwargs):
        """Return the next scripted response.

        Once the script is exhausted, every call yields a plain "done"
        reply with finish reason "stop"; such calls are not recorded.
        """
        index = self._cursor
        if index >= len(self._responses):
            return self._build_result({"content": "done", "finish_reason": "stop"})

        # Advance first so the result id reflects the post-increment cursor.
        self._cursor = index + 1
        offered = [t.function.name for t in tools] if tools else []
        self.calls.append({"tools": offered, "cursor": index})
        return self._build_result(self._responses[index])

    async def achat(self, *args, **kwargs):
        """Async variant that delegates to the synchronous ``chat``."""
        return self.chat(*args, **kwargs)

    @property
    def provider_name(self) -> str:
        """Provider identifier; always "mock"."""
        return "mock"

    @staticmethod
    def text_response(content: str, finish_reason: str = "stop") -> dict:
        """Build a scripted plain-text response."""
        return {"content": content, "finish_reason": finish_reason}

    @staticmethod
    def tool_response(name: str, arguments: dict, tool_call_id: str = "") -> dict:
        """Build a scripted response requesting one tool call.

        The call id defaults to ``tc_<name>`` when none is given.
        """
        call_id = tool_call_id if tool_call_id else f"tc_{name}"
        call = ToolCall(id=call_id, name=name, arguments=json.dumps(arguments))
        return {
            "content": "",
            "tool_calls": [call],
            "finish_reason": "tool_calls",
        }

    def _build_result(self, resp: dict) -> CompletionResult:
        """Wrap a scripted response dict in a ``CompletionResult``."""
        content = resp.get("content", "")
        message = Message(
            role=MessageRole.ASSISTANT,
            content=content,
            tool_calls=resp.get("tool_calls"),
        )
        # Synthetic usage: a fixed prompt size plus a completion size derived
        # from the content length.
        usage = CompletionUsage(
            prompt_tokens=5,
            completion_tokens=len(content) + 3,
            total_tokens=len(content) + 8,
        )
        return CompletionResult(
            id=f"mock_{self._cursor}",
            model="mock-model",
            choices=[
                CompletionChoice(
                    index=0,
                    message=message,
                    finish_reason=resp.get("finish_reason", "stop"),
                ),
            ],
            usage=usage,
        )

219 

220 

221# ── Tool-Using Agent ───────────────────────────────────────────── 

222 

class ToolAgent:
    """Autonomous agent built on LLM function calling.

    Each step sends the conversation to the provider. If the reply requests
    tool calls, they are executed and their outputs appended to the
    conversation. The loop ends when a reply carries no tool calls (or has
    finish reason "stop"), when ``max_steps`` is exhausted, or on error.

    Usage:
        from agentos.agent import ToolAgent, ToolExecutor
        from agentos.llm import create_provider, Tool

        provider = create_provider("openai")
        executor = ToolExecutor()
        executor.register(
            Tool.from_function("get_weather", "Get the weather", {"city": ...}),
            lambda city: f"{city}: 22°C sunny"
        )
        agent = ToolAgent(provider, executor)
        result = agent.run("What is the weather in Beijing?")
        print(result.final_answer)
    """

    # File name used for the checkpoint inside ``AgentConfig.checkpoint_dir``.
    _CHECKPOINT_FILE = "agent_checkpoint.json"

    def __init__(
        self,
        provider: LLMProvider,
        tool_executor: ToolExecutor,
        *,
        config: AgentConfig | None = None,
        system_prompt: str = "",
        metrics: Optional[MetricsCollector] = None,
    ):
        """Bind the agent to a provider and a tool executor.

        Args:
            provider: LLM backend used for every completion.
            tool_executor: Supplies tool schemas and executes tool calls.
            config: Run settings; a default ``AgentConfig`` when omitted.
            system_prompt: System message; a built-in prompt when empty.
            metrics: Optional collector for LLM call/token/step counters.
        """
        self._provider = provider
        self._executor = tool_executor
        self._config = config or AgentConfig()
        self._system_prompt = system_prompt or (
            "你是一个智能助手。你可以使用工具来获取信息。"
            "当你可以给出最终答案时,直接回答,不要再调用工具。"
            "用中文回答。"
        )
        self._metrics = metrics

    # ── Synchronous ───────────────────────────────────────────

    def run(self, task: str) -> AgentResult:
        """Run the agent loop for ``task`` and return the final result."""
        t0 = time.monotonic()
        steps: list[AgentStep] = []
        tools = self._executor.get_schemas()
        return self._run_loop(self._initial_messages(task), task, tools, steps, 1, t0)

    def _run_loop(
        self, messages, task, tools, steps, start_step, t0,
    ) -> AgentResult:
        """Drive the step loop from ``start_step`` up to ``max_steps``.

        Never raises: any exception is converted into a failed result whose
        ``error`` is the exception text.
        """
        final_answer = ""
        total_tokens = 0
        total_cost = 0.0

        try:
            for step_num in range(start_step, self._config.max_steps + 1):
                result = self._call_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                steps.append(step)
                if done:
                    final_answer = final
                    break
                self._extend_history(messages, result, step)
                self._checkpoint(messages, task, step_num)
            else:
                # The loop ran out of steps without a terminating reply.
                return self._make_result(
                    False, "", steps, total_tokens, total_cost, t0,
                    self._max_steps_error(),
                )
        except Exception as e:
            return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e))

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Streaming ─────────────────────────────────────────────

    def run_stream(self, task: str) -> Generator[AgentStep, None, AgentResult]:
        """Like ``run`` but yield each step as soon as it is processed.

        The final ``AgentResult`` is the generator's return value.
        """
        t0 = time.monotonic()
        steps: list[AgentStep] = []
        tools = self._executor.get_schemas()
        messages = self._initial_messages(task)
        total_tokens = 0
        total_cost = 0.0
        final_answer = ""

        try:
            for step_num in range(1, self._config.max_steps + 1):
                result = self._call_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                yield step
                steps.append(step)
                if done:
                    final_answer = final
                    break
                self._extend_history(messages, result, step)
                self._checkpoint(messages, task, step_num)
            else:
                return self._make_result(
                    False, "", steps, total_tokens, total_cost, t0,
                    self._max_steps_error(),
                )
        except Exception as e:
            return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e))

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Asynchronous ──────────────────────────────────────────

    async def arun(self, task: str) -> AgentResult:
        """Async variant of ``run``; only the provider call is awaited."""
        t0 = time.monotonic()
        steps: list[AgentStep] = []
        tools = self._executor.get_schemas()
        messages = self._initial_messages(task)
        final_answer = ""
        total_tokens = 0
        total_cost = 0.0

        try:
            for step_num in range(1, self._config.max_steps + 1):
                result = await self._acall_with_retry(messages, tools)
                step, done, final = self._process_step(result, step_num)
                total_tokens += step.tokens_used
                total_cost += step.cost_usd
                steps.append(step)
                if done:
                    final_answer = final
                    break
                self._extend_history(messages, result, step)
                self._checkpoint(messages, task, step_num)
            else:
                return self._make_result(
                    False, "", steps, total_tokens, total_cost, t0,
                    self._max_steps_error(),
                )
        except Exception as e:
            return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e))

        return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0)

    # ── Shared step logic ─────────────────────────────────────

    def _initial_messages(self, task: str) -> list[Message]:
        """Build the opening conversation: system prompt plus the user task."""
        return [
            Message(role=MessageRole.SYSTEM, content=self._system_prompt),
            Message(role=MessageRole.USER, content=task),
        ]

    @staticmethod
    def _extend_history(messages: list[Message], result: CompletionResult, step: AgentStep) -> None:
        """Append the assistant reply and one TOOL message per tool call."""
        messages.append(result.choices[0].message)
        for tc in step.tool_calls:
            messages.append(Message(
                role=MessageRole.TOOL,
                content=step.tool_results.get(tc.id, ""),
                tool_call_id=tc.id,
            ))

    def _max_steps_error(self) -> str:
        """Error text used when the step budget is exhausted."""
        return f"Reached max steps ({self._config.max_steps}) without final answer"

    @staticmethod
    def _is_error_result(tool_result: str) -> bool:
        """Return True when ``tool_result`` is a JSON object with an "error" key.

        A plain substring test would misfire on any ordinary tool output
        that merely mentions the word "error", so the payload is parsed.
        """
        try:
            payload = json.loads(tool_result)
        except (TypeError, ValueError):
            return False
        return isinstance(payload, dict) and "error" in payload

    def _process_step(
        self, result: CompletionResult, step_num: int,
    ) -> tuple[AgentStep, bool, str]:
        """Turn one completion into an ``AgentStep`` and execute its tools.

        Returns:
            ``(step, done, final_answer)``. ``done`` is True when the reply
            has no tool calls or its finish reason is "stop";
            ``final_answer`` is then the assistant content, otherwise "".

        Raises:
            RuntimeError: A tool returned an error payload while
                ``stop_on_error`` is enabled.
        """
        step_t0 = time.monotonic()
        choice = result.choices[0]
        assistant_msg = choice.message

        step = AgentStep(
            step=step_num,
            thought=assistant_msg.content,
            tool_calls=assistant_msg.tool_calls or [],
            finish_reason=choice.finish_reason,
            tokens_used=result.usage.total_tokens,
            cost_usd=result.usage.cost_usd,
        )

        # Metrics: track LLM calls and tokens.
        if self._metrics is not None:
            self._metrics.get_counter("llm_calls_total").inc()
            self._metrics.get_counter("llm_tokens_total").inc(result.usage.total_tokens)
            self._metrics.get_counter("agent_steps_total").inc()

        try:
            for tc in step.tool_calls:
                tool_result = self._executor.execute(tc)
                step.tool_results[tc.id] = tool_result
                if self._config.stop_on_error and self._is_error_result(tool_result):
                    raise RuntimeError(f"Tool '{tc.name}' error: {tool_result}")
        finally:
            # Measure and log after the tools ran, so the duration covers
            # tool execution and the log shows the tool results, including
            # on the error path.
            step.duration_ms = (time.monotonic() - step_t0) * 1000
            if self._config.verbose:
                self._log_step(step)

        done = not step.tool_calls or choice.finish_reason == "stop"
        return step, done, assistant_msg.content if done else ""

    def _make_result(
        self, success: bool, answer: str, steps: list[AgentStep],
        total_tokens: int, total_cost: float, t0: float, error: str | None = None,
    ) -> AgentResult:
        """Assemble an ``AgentResult``; the duration is measured from ``t0``."""
        return AgentResult(
            success=success,
            final_answer=answer,
            steps=steps,
            total_steps=len(steps),
            total_tokens=total_tokens,
            total_cost_usd=total_cost,
            total_duration_ms=(time.monotonic() - t0) * 1000,
            error=error,
        )

    # ── Checkpoint / Resume ───────────────────────────────────

    def resume(self) -> AgentResult:
        """Continue a run from the checkpoint in ``checkpoint_dir``.

        The loop restarts at the step after the checkpointed one; the
        returned result only contains steps executed after resuming.

        Raises:
            ValueError: ``checkpoint_dir`` is not configured.
            FileNotFoundError: No checkpoint file exists.
        """
        if not self._config.checkpoint_dir:
            raise ValueError("checkpoint_dir not configured")
        ckpt_path = os.path.join(self._config.checkpoint_dir, self._CHECKPOINT_FILE)
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"No checkpoint found at {ckpt_path}")
        # Explicit UTF-8: the file is written with ensure_ascii=False, so the
        # platform default encoding must not be relied on.
        with open(ckpt_path, encoding="utf-8") as f:
            data = json.load(f)
        task = data["task"]
        start_step = data["step"] + 1
        messages = [
            Message(
                role=MessageRole(m["role"]),
                content=m["content"],
                tool_call_id=m.get("tool_call_id"),
                tool_calls=[ToolCall(**tc) for tc in m["tool_calls"]] if m.get("tool_calls") else None,
            )
            for m in data["messages"]
        ]
        t0 = time.monotonic()
        tools = self._executor.get_schemas()
        steps: list[AgentStep] = []
        return self._run_loop(messages, task, tools, steps, start_step, t0)

    def _checkpoint(self, messages: list[Message], task: str, step: int) -> None:
        """Persist the conversation after ``step``; no-op without a directory."""
        ckpt_dir = self._config.checkpoint_dir
        if not ckpt_dir:
            return
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, self._CHECKPOINT_FILE)
        data = {
            "task": task,
            "step": step,
            "messages": [
                {
                    "role": m.role.value,
                    "content": m.content,
                    "tool_call_id": m.tool_call_id,
                    "tool_calls": [
                        {"id": tc.id, "name": tc.name, "arguments": tc.arguments}
                        for tc in m.tool_calls
                    ] if m.tool_calls else None,
                }
                for m in messages
            ],
        }
        # Write to a sibling temp file and swap it in, so an interrupted
        # write never leaves a truncated checkpoint behind.
        tmp_path = ckpt_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
        os.replace(tmp_path, ckpt_path)

    # ── Internals ─────────────────────────────────────────────

    def _chat_kwargs(self, tools: list[Tool]) -> dict[str, Any]:
        """Keyword arguments shared by the sync and async provider calls."""
        return {
            "temperature": self._config.temperature,
            "max_tokens": self._config.max_tokens,
            # An empty tool list is sent as None.
            "tools": tools if tools else None,
        }

    def _call_with_retry(self, messages: list[Message], tools: list[Tool]) -> CompletionResult:
        """Call the provider, retrying up to ``max_retries`` extra times.

        At least one attempt is always made, even if ``max_retries`` is
        negative. The last exception is re-raised once attempts run out.
        """
        attempts = max(self._config.max_retries, 0) + 1
        for attempt in range(attempts):
            try:
                return self._provider.chat(messages, **self._chat_kwargs(tools))
            except Exception:
                if attempt == attempts - 1:
                    raise
                time.sleep(self._config.retry_delay)
        raise RuntimeError("retry loop exited without a result")  # unreachable

    async def _acall_with_retry(self, messages: list[Message], tools: list[Tool]) -> CompletionResult:
        """Async counterpart of ``_call_with_retry`` using ``achat``."""
        import asyncio  # local import: only the async path needs it

        attempts = max(self._config.max_retries, 0) + 1
        for attempt in range(attempts):
            try:
                return await self._provider.achat(messages, **self._chat_kwargs(tools))
            except Exception:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(self._config.retry_delay)
        raise RuntimeError("retry loop exited without a result")  # unreachable

    def _log_step(self, step: AgentStep) -> None:
        """Print a human-readable summary of ``step`` to stdout."""
        print(f"\n── Step {step.step} ({step.duration_ms:.0f}ms, {step.tokens_used}t, ${step.cost_usd:.6f}) ──")
        if step.thought:
            print(f" Thought: {step.thought}")
        for tc in step.tool_calls:
            # Tool output can be long, so only the first 100 characters are shown.
            result_preview = step.tool_results.get(tc.id, "")[:100]
            print(f" Tool: {tc.name}({tc.arguments}) → {result_preview}")
        print(f" Finish: {step.finish_reason}")

552 print(f" Finish: {step.finish_reason}")