Coverage for agentos/tools/registry.py: 100%
41 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 17:38 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 17:38 +0800
1"""
2统一工具注册表 — 核心循环不关心具体实现。
3"""
5from __future__ import annotations
7import asyncio
8import uuid
9from typing import Any
11from agentos.tools.base import BaseTool, ToolCall, ToolResult
class ToolRegistry:
    """Unified tool registry.

    Every tool is registered here by name, so the core agent loop can list,
    describe and execute tools without knowing their concrete implementations.
    """

    def __init__(self):
        # Tool name -> tool instance; insertion-ordered (plain dict).
        self._tools: dict[str, BaseTool] = {}

    def register(self, tool: BaseTool) -> None:
        """Register *tool* under ``tool.name``.

        A tool registered under an existing name replaces the previous one.
        """
        self._tools[tool.name] = tool

    def register_many(self, tools: list[BaseTool]) -> None:
        """Register each tool in *tools*, in order (later duplicates win)."""
        for tool in tools:
            self.register(tool)

    def get(self, name: str) -> BaseTool | None:
        """Return the tool registered as *name*, or ``None`` if absent."""
        return self._tools.get(name)

    def list_names(self) -> list[str]:
        """Return the registered tool names in insertion order."""
        return list(self._tools.keys())

    def get_schemas_for_model(self, model_type: str) -> list[dict]:
        """Build the tool-schema list for the given model family.

        Args:
            model_type: Provider/model family identifier, e.g. ``"openai"``
                or ``"anthropic"``.

        Returns:
            One schema dict per registered tool, in insertion order.
            ``"anthropic"`` gets Anthropic-format schemas. Every other value
            gets OpenAI-format schemas. That includes the OpenAI-compatible
            families ``"openai"``, ``"deepseek"``, ``"kimi"``, ``"qwen"``,
            ``"glm"`` and ``"minimax"``, plus any unrecognised name.
        """
        tools = self._tools.values()
        if model_type == "anthropic":
            return [t.to_anthropic_schema() for t in tools]
        # OpenAI format is both the explicit choice for OpenAI-compatible
        # providers and the fallback for unknown ones, so one branch covers both.
        return [t.to_openai_schema() for t in tools]

    async def execute_batch(
        self, calls: list[ToolCall], sandbox=None
    ) -> list[ToolResult]:
        """Execute a batch of tool calls concurrently.

        Args:
            calls: Tool calls to run. Calls naming an unregistered tool produce
                an error result instead of raising.
            sandbox: Passed through unchanged to each tool's ``execute``.

        Returns:
            One ``ToolResult`` per call, in the same order as *calls*
            (``asyncio.gather`` preserves argument order). An empty *calls*
            yields ``[]``.
        """
        tasks = []
        for call in calls:
            tool = self._tools.get(call.name)
            if not tool:
                tasks.append(self._unknown_tool_result(call))
            else:
                tasks.append(self._execute_one(tool, call, sandbox))
        return await asyncio.gather(*tasks)

    async def _execute_one(
        self, tool: BaseTool, call: ToolCall, sandbox=None
    ) -> ToolResult:
        """Run a single call, converting any ``Exception`` into an error result."""
        try:
            # NOTE(review): the tool's result is returned as-is. The tool only
            # receives call.arguments, so presumably it sets call_id itself.
            # Confirm against BaseTool.execute.
            return await tool.execute(call.arguments, sandbox=sandbox)
        except Exception as e:
            return ToolResult(call_id=call.id, error=self._format_error(e))

    @staticmethod
    def _format_error(exc: Exception) -> str:
        """Render *exc* as a non-empty error message for a ``ToolResult``."""
        # str() of an exception raised without arguments (e.g. `raise
        # TimeoutError()`) is "". An empty error would be falsy and could make
        # the failure look like success, so fall back to the class name.
        return str(exc) or type(exc).__name__

    async def _unknown_tool_result(self, call: ToolCall) -> ToolResult:
        """Build the error result for a call naming an unregistered tool."""
        return ToolResult(
            call_id=call.id,
            error=f"Unknown tool: {call.name}. Available: {self.list_names()}",
        )

    @staticmethod
    def make_call_id() -> str:
        """Return a new call id: ``"call_"`` plus 12 hex chars of a uuid4."""
        return f"call_{uuid.uuid4().hex[:12]}"