Coverage for agentos/agent/tool_agent.py: 27%

244 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Tool-Using Agent — 基于 LLM Function Calling 的自主 Agent 循环。 

3 

4核心模式: 

5 用户任务 → LLM 推理(tool_calls) → 工具执行 → 结果回传 → 循环直到完成 

6 

7v1.3.38: +streaming, retry, checkpoint/resume, tool error handling, mock provider. 

8""" 

9 

10from __future__ import annotations 

11 

12import json 

13import os 

14import time 

15from dataclasses import dataclass, field 

16from typing import Any, Callable, Generator 

17 

18from agentos.llm.base import ( 

19 LLMProvider, 

20 Message, 

21 MessageRole, 

22 CompletionResult, 

23 CompletionChoice, 

24 CompletionUsage, 

25 Tool, 

26 ToolCall, 

27) 

28 

29__all__ = [ 

30 "ToolAgent", 

31 "AgentConfig", 

32 "AgentStep", 

33 "AgentResult", 

34 "ToolExecutor", 

35 "MockLLMProvider", 

36] 

37 

38 

39# ── 数据类型 ───────────────────────────────────────────────────── 

40 

41@dataclass 

42class AgentConfig: 

43 max_steps: int = 10 

44 temperature: float = 0.0 

45 max_tokens: int = 4096 

46 verbose: bool = False 

47 stop_on_error: bool = True 

48 max_retries: int = 2 

49 retry_delay: float = 0.5 

50 checkpoint_dir: str = "" 

51 

52 

53@dataclass 

54class AgentStep: 

55 step: int 

56 thought: str = "" 

57 tool_calls: list[ToolCall] = field(default_factory=list) 

58 tool_results: dict[str, str] = field(default_factory=dict) 

59 finish_reason: str = "" 

60 tokens_used: int = 0 

61 cost_usd: float = 0.0 

62 duration_ms: float = 0.0 

63 

64 

65@dataclass 

66class AgentResult: 

67 success: bool = True 

68 final_answer: str = "" 

69 steps: list[AgentStep] = field(default_factory=list) 

70 total_steps: int = 0 

71 total_tokens: int = 0 

72 total_cost_usd: float = 0.0 

73 total_duration_ms: float = 0.0 

74 error: str | None = None 

75 

76 

77# ── 工具执行器 ─────────────────────────────────────────────────── 

78 

79class ToolExecutor: 

80 def __init__(self): 

81 self._tools: dict[str, Callable[..., str]] = {} 

82 self._schemas: dict[str, Tool] = {} 

83 

84 def register(self, tool: Tool, handler: Callable[..., str]) -> None: 

85 self._tools[tool.function.name] = handler 

86 self._schemas[tool.function.name] = tool 

87 

88 def get_schemas(self) -> list[Tool]: 

89 return list(self._schemas.values()) 

90 

91 def execute(self, tool_call: ToolCall) -> str: 

92 handler = self._tools.get(tool_call.name) 

93 if handler is None: 

94 return json.dumps({"error": f"Unknown tool: {tool_call.name}"}) 

95 try: 

96 return str(handler(**tool_call.parsed_arguments)) 

97 except Exception as e: 

98 return json.dumps({"error": str(e)}) 

99 

100 

101# ── MockLLMProvider ────────────────────────────────────────────── 

102 

103class MockLLMProvider(LLMProvider): 

104 """可编程响应的 Mock Provider,供集成测试使用。""" 

105 

106 def __init__(self, responses: list[dict]): 

107 super().__init__(model="mock", api_key="mock") 

108 self._responses = responses 

109 self._cursor = 0 

110 self.calls: list[dict] = [] 

111 

112 def chat(self, messages=None, *, temperature=0, max_tokens=4096, tools=None, **kwargs): 

113 if self._cursor >= len(self._responses): 

114 return self._build_result({"content": "done", "finish_reason": "stop"}) 

115 resp = self._responses[self._cursor] 

116 self._cursor += 1 

117 self.calls.append({ 

118 "tools": [t.function.name for t in (tools or [])], 

119 "cursor": self._cursor - 1, 

120 }) 

121 return self._build_result(resp) 

122 

123 async def achat(self, *args, **kwargs): 

124 return self.chat(*args, **kwargs) 

125 

126 @property 

127 def provider_name(self) -> str: 

128 return "mock" 

129 

130 @staticmethod 

131 def text_response(content: str, finish_reason: str = "stop") -> dict: 

132 return {"content": content, "finish_reason": finish_reason} 

133 

134 @staticmethod 

135 def tool_response(name: str, arguments: dict, tool_call_id: str = "") -> dict: 

136 tid = tool_call_id or f"tc_{name}" 

137 return { 

138 "content": "", 

139 "tool_calls": [ToolCall(id=tid, name=name, arguments=json.dumps(arguments))], 

140 "finish_reason": "tool_calls", 

141 } 

142 

143 def _build_result(self, resp: dict) -> CompletionResult: 

144 msg = Message( 

145 role=MessageRole.ASSISTANT, 

146 content=resp.get("content", ""), 

147 tool_calls=resp.get("tool_calls"), 

148 ) 

149 choice = CompletionChoice( 

150 index=0, 

151 message=msg, 

152 finish_reason=resp.get("finish_reason", "stop"), 

153 ) 

154 return CompletionResult( 

155 id=f"mock_{self._cursor}", 

156 model="mock-model", 

157 choices=[choice], 

158 usage=CompletionUsage(prompt_tokens=5, completion_tokens=len(resp.get("content", "")) + 3, total_tokens=len(resp.get("content", "")) + 8), 

159 ) 

160 

161 

162# ── Tool-Using Agent ───────────────────────────────────────────── 

163 

164class ToolAgent: 

165 """基于 LLM Function Calling 的自主 Agent。 

166 

167 用法: 

168 from agentos.agent import ToolAgent, ToolExecutor 

169 from agentos.llm import create_provider, Tool 

170 

171 provider = create_provider("openai") 

172 executor = ToolExecutor() 

173 executor.register( 

174 Tool.from_function("get_weather", "获取天气", {"city": ...}), 

175 lambda city: f"{city}: 22°C sunny" 

176 ) 

177 agent = ToolAgent(provider, executor) 

178 result = agent.run("北京天气怎么样?") 

179 print(result.final_answer) 

180 """ 

181 

182 def __init__( 

183 self, 

184 provider: LLMProvider, 

185 tool_executor: ToolExecutor, 

186 *, 

187 config: AgentConfig | None = None, 

188 system_prompt: str = "", 

189 ): 

190 self._provider = provider 

191 self._executor = tool_executor 

192 self._config = config or AgentConfig() 

193 self._system_prompt = system_prompt or ( 

194 "你是一个智能助手。你可以使用工具来获取信息。" 

195 "当你可以给出最终答案时,直接回答,不要再调用工具。" 

196 "用中文回答。" 

197 ) 

198 

199 # ── 同步 ────────────────────────────────────────────────── 

200 

201 def run(self, task: str) -> AgentResult: 

202 t0 = time.monotonic() 

203 steps: list[AgentStep] = [] 

204 tools = self._executor.get_schemas() 

205 messages: list[Message] = [ 

206 Message(role=MessageRole.SYSTEM, content=self._system_prompt), 

207 Message(role=MessageRole.USER, content=task), 

208 ] 

209 return self._run_loop(messages, task, tools, steps, 1, t0) 

210 

211 def _run_loop( 

212 self, messages, task, tools, steps, start_step, t0, 

213 ) -> AgentResult: 

214 final_answer = "" 

215 total_tokens = 0 

216 total_cost = 0.0 

217 step_num = start_step 

218 

219 try: 

220 for step_num in range(start_step, self._config.max_steps + 1): 

221 result = self._call_with_retry(messages, tools) 

222 step, done, final = self._process_step(result, step_num) 

223 total_tokens += step.tokens_used 

224 total_cost += step.cost_usd 

225 steps.append(step) 

226 if done: 

227 final_answer = final 

228 break 

229 messages.append(result.choices[0].message) 

230 for tc in step.tool_calls: 

231 messages.append(Message( 

232 role=MessageRole.TOOL, 

233 content=step.tool_results.get(tc.id, ""), 

234 tool_call_id=tc.id, 

235 )) 

236 self._checkpoint(messages, task, step_num) 

237 else: 

238 return self._make_result( 

239 False, "", steps, total_tokens, total_cost, t0, 

240 f"Reached max steps ({self._config.max_steps}) without final answer", 

241 ) 

242 except Exception as e: 

243 return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e)) 

244 

245 return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0) 

246 

247 # ── 流式 ────────────────────────────────────────────────── 

248 

249 def run_stream(self, task: str) -> Generator[AgentStep, None, AgentResult]: 

250 t0 = time.monotonic() 

251 steps: list[AgentStep] = [] 

252 tools = self._executor.get_schemas() 

253 messages: list[Message] = [ 

254 Message(role=MessageRole.SYSTEM, content=self._system_prompt), 

255 Message(role=MessageRole.USER, content=task), 

256 ] 

257 total_tokens = 0 

258 total_cost = 0.0 

259 final_answer = "" 

260 

261 try: 

262 for step_num in range(1, self._config.max_steps + 1): 

263 result = self._call_with_retry(messages, tools) 

264 step, done, final = self._process_step(result, step_num) 

265 total_tokens += step.tokens_used 

266 total_cost += step.cost_usd 

267 yield step 

268 steps.append(step) 

269 if done: 

270 final_answer = final 

271 break 

272 messages.append(result.choices[0].message) 

273 for tc in step.tool_calls: 

274 messages.append(Message( 

275 role=MessageRole.TOOL, 

276 content=step.tool_results.get(tc.id, ""), 

277 tool_call_id=tc.id, 

278 )) 

279 self._checkpoint(messages, task, step_num) 

280 else: 

281 return self._make_result( 

282 False, "", steps, total_tokens, total_cost, t0, 

283 f"Reached max steps ({self._config.max_steps}) without final answer", 

284 ) 

285 except Exception as e: 

286 return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e)) 

287 

288 return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0) 

289 

290 # ── 异步 ────────────────────────────────────────────────── 

291 

292 async def arun(self, task: str) -> AgentResult: 

293 t0 = time.monotonic() 

294 steps: list[AgentStep] = [] 

295 tools = self._executor.get_schemas() 

296 messages: list[Message] = [ 

297 Message(role=MessageRole.SYSTEM, content=self._system_prompt), 

298 Message(role=MessageRole.USER, content=task), 

299 ] 

300 final_answer = "" 

301 total_tokens = 0 

302 total_cost = 0.0 

303 

304 try: 

305 for step_num in range(1, self._config.max_steps + 1): 

306 result = await self._acall_with_retry(messages, tools) 

307 step, done, final = self._process_step(result, step_num) 

308 total_tokens += step.tokens_used 

309 total_cost += step.cost_usd 

310 steps.append(step) 

311 if done: 

312 final_answer = final 

313 break 

314 messages.append(result.choices[0].message) 

315 for tc in step.tool_calls: 

316 messages.append(Message( 

317 role=MessageRole.TOOL, 

318 content=step.tool_results.get(tc.id, ""), 

319 tool_call_id=tc.id, 

320 )) 

321 self._checkpoint(messages, task, step_num) 

322 else: 

323 return self._make_result( 

324 False, "", steps, total_tokens, total_cost, t0, 

325 f"Reached max steps ({self._config.max_steps}) without final answer", 

326 ) 

327 except Exception as e: 

328 return self._make_result(False, "", steps, total_tokens, total_cost, t0, str(e)) 

329 

330 return self._make_result(True, final_answer, steps, total_tokens, total_cost, t0) 

331 

332 # ── 共享步骤逻辑 ─────────────────────────────────────────────── 

333 

334 def _process_step( 

335 self, result: CompletionResult, step_num: int, 

336 ) -> tuple[AgentStep, bool, str]: 

337 """处理单步 LLM 结果:构建 AgentStep、执行工具、判断终止。 

338 

339 Returns: 

340 (step, done, final_answer) 

341 """ 

342 step_t0 = time.monotonic() 

343 choice = result.choices[0] 

344 assistant_msg = choice.message 

345 

346 step = AgentStep( 

347 step=step_num, 

348 thought=assistant_msg.content, 

349 tool_calls=assistant_msg.tool_calls or [], 

350 finish_reason=choice.finish_reason, 

351 tokens_used=result.usage.total_tokens, 

352 cost_usd=result.usage.cost_usd, 

353 duration_ms=(time.monotonic() - step_t0) * 1000, 

354 ) 

355 

356 if self._config.verbose: 

357 self._log_step(step) 

358 

359 # 无工具调用 → 终止,内容即为答案 

360 if not assistant_msg.tool_calls: 

361 return step, True, assistant_msg.content 

362 

363 # 执行工具调用 

364 for tc in assistant_msg.tool_calls: 

365 tool_result = self._executor.execute(tc) 

366 step.tool_results[tc.id] = tool_result 

367 if "error" in tool_result and self._config.stop_on_error: 

368 raise RuntimeError(f"Tool '{tc.name}' error: {tool_result}") 

369 

370 # finish_reason == "stop" → 提前终止 

371 if choice.finish_reason == "stop": 

372 return step, True, assistant_msg.content 

373 

374 return step, False, "" 

375 

376 def _make_result( 

377 self, success: bool, answer: str, steps: list[AgentStep], 

378 total_tokens: int, total_cost: float, t0: float, error: str = None, 

379 ) -> AgentResult: 

380 """统一构造 AgentResult。""" 

381 return AgentResult( 

382 success=success, 

383 final_answer=answer, 

384 steps=steps, 

385 total_steps=len(steps), 

386 total_tokens=total_tokens, 

387 total_cost_usd=total_cost, 

388 total_duration_ms=(time.monotonic() - t0) * 1000, 

389 error=error, 

390 ) 

391 

392 # ── Checkpoint / Resume ─────────────────────────────────── 

393 

394 def resume(self) -> AgentResult: 

395 if not self._config.checkpoint_dir: 

396 raise ValueError("checkpoint_dir not configured") 

397 ckpt_path = os.path.join(self._config.checkpoint_dir, "agent_checkpoint.json") 

398 if not os.path.exists(ckpt_path): 

399 raise FileNotFoundError(f"No checkpoint found at {ckpt_path}") 

400 with open(ckpt_path) as f: 

401 data = json.load(f) 

402 task = data["task"] 

403 start_step = data["step"] + 1 

404 messages_raw = data["messages"] 

405 messages = [ 

406 Message( 

407 role=MessageRole(m["role"]), 

408 content=m["content"], 

409 tool_call_id=m.get("tool_call_id"), 

410 tool_calls=[ToolCall(**tc) for tc in m["tool_calls"]] if m.get("tool_calls") else None, 

411 ) 

412 for m in messages_raw 

413 ] 

414 t0 = time.monotonic() 

415 tools = self._executor.get_schemas() 

416 steps: list[AgentStep] = [] 

417 return self._run_loop(messages, task, tools, steps, start_step, t0) 

418 

419 def _checkpoint(self, messages: list[Message], task: str, step: int) -> None: 

420 if not self._config.checkpoint_dir: 

421 return 

422 ckpt_path = os.path.join(self._config.checkpoint_dir, "agent_checkpoint.json") 

423 data = { 

424 "task": task, 

425 "step": step, 

426 "messages": [ 

427 { 

428 "role": m.role.value, 

429 "content": m.content, 

430 "tool_call_id": m.tool_call_id, 

431 "tool_calls": [ 

432 {"id": tc.id, "name": tc.name, "arguments": tc.arguments} 

433 for tc in m.tool_calls 

434 ] if m.tool_calls else None, 

435 } 

436 for m in messages 

437 ], 

438 } 

439 with open(ckpt_path, "w") as f: 

440 json.dump(data, f, ensure_ascii=False) 

441 

442 # ── 内部方法 ────────────────────────────────────────────── 

443 

444 def _call_with_retry(self, messages: list[Message], tools: list[Tool]) -> CompletionResult: 

445 last_error = None 

446 for attempt in range(self._config.max_retries + 1): 

447 try: 

448 return self._provider.chat( 

449 messages, 

450 temperature=self._config.temperature, 

451 max_tokens=self._config.max_tokens, 

452 tools=tools if tools else None, 

453 ) 

454 except Exception as e: 

455 last_error = e 

456 if attempt < self._config.max_retries: 

457 time.sleep(self._config.retry_delay) 

458 raise last_error # type: ignore 

459 

460 async def _acall_with_retry(self, messages: list[Message], tools: list[Tool]) -> CompletionResult: 

461 last_error = None 

462 for attempt in range(self._config.max_retries + 1): 

463 try: 

464 return await self._provider.achat( 

465 messages, 

466 temperature=self._config.temperature, 

467 max_tokens=self._config.max_tokens, 

468 tools=tools if tools else None, 

469 ) 

470 except Exception as e: 

471 last_error = e 

472 if attempt < self._config.max_retries: 

473 import asyncio 

474 await asyncio.sleep(self._config.retry_delay) 

475 raise last_error # type: ignore 

476 

477 def _log_step(self, step: AgentStep) -> None: 

478 print(f"\n── Step {step.step} ({step.duration_ms:.0f}ms, {step.tokens_used}t, ${step.cost_usd:.6f}) ──") 

479 if step.thought: 

480 print(f" Thought: {step.thought}") 

481 if step.tool_calls: 

482 for tc in step.tool_calls: 

483 result_preview = step.tool_results.get(tc.id, "")[:100] 

484 print(f" Tool: {tc.name}({tc.arguments}) → {result_preview}") 

485 print(f" Finish: {step.finish_reason}")