Coverage for agentos/tools/generator.py: 23%
145 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 17:49 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 17:49 +0800
1"""
2OpenAPI工具自动生成器 — 从OpenAPI/Swagger spec自动生成Agent工具包装器。
3v0.50: 新增模块。将REST API端点自动转换为Agent可调用的ToolCall格式。
4"""
from __future__ import annotations

import copy
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from urllib.parse import quote, urljoin, urlparse

import httpx
import yaml
@dataclass
class GeneratedTool:
    """Description of a single tool generated from one OpenAPI operation."""

    name: str  # tool/function name exposed to the model
    description: str
    operation_id: str = ""  # operationId the tool was generated from
    method: str = "GET"  # HTTP method
    path: str = ""  # path template; may contain {placeholders}
    parameters_schema: dict = field(default_factory=dict)  # JSON Schema of inputs; {} when none
    auth_header: str = ""  # name of the auth header (not its value)
    base_url: str = ""

    def to_openai_function(self) -> dict:
        """Return this tool in OpenAI function-calling format.

        ``parameters`` is included only when the schema is non-empty; the
        schema dict itself is returned by reference, not copied.
        """
        function_spec: dict = {"name": self.name, "description": self.description}
        if self.parameters_schema:
            function_spec["parameters"] = self.parameters_schema
        return {"type": "function", "function": function_spec}

    def to_tool_dict(self) -> dict:
        """Return a generic, framework-neutral description of this tool.

        ``path`` is exposed under the key ``path_template`` and
        ``parameters_schema`` under ``parameters``.
        """
        return dict(
            name=self.name,
            description=self.description,
            operation_id=self.operation_id,
            method=self.method,
            path_template=self.path,
            parameters=self.parameters_schema,
            base_url=self.base_url,
            auth_header=self.auth_header,
        )
class OpenAPIToolGenerator:
    """
    Generate agent tools from an OpenAPI 3.x / Swagger 2.0 spec.

    Usage:
        gen = OpenAPIToolGenerator("https://api.example.com/openapi.json")
        tools = await gen.generate()
        # ``tools`` is a list of GeneratedTool, ready to inject into an agent context
        await gen.close()

    The generator owns an ``httpx.AsyncClient``; call ``close()`` or use it as
    an async context manager (``async with OpenAPIToolGenerator(...) as gen:``)
    so the client's connections are released.
    """

    # Template JSON-Schema fragments for simple parameter types. These dicts are
    # shared class state: always deep-copy before handing out or mutating.
    PARAM_TYPE_MAP = {
        "string": {"type": "string"},
        "integer": {"type": "integer"},
        "number": {"type": "number"},
        "boolean": {"type": "boolean"},
        "array": {"type": "array", "items": {"type": "string"}},
        "object": {"type": "object"},
    }

    # HTTP methods that become tools; other path-item keys (parameters, summary,
    # servers, head, options, trace, ...) are skipped.
    _HTTP_METHODS = ("GET", "POST", "PUT", "DELETE", "PATCH")

    def __init__(self, spec_url: str = "", spec_path: str = "", api_base: str = "",
                 auth_header: str = "Authorization", auth_value: str = ""):
        """
        Args:
            spec_url: URL of the spec (takes precedence over ``spec_path``).
            spec_path: local path of the spec (.json / .yaml / .yml).
            api_base: explicit API base URL; when empty it is derived from the spec.
            auth_header: name of the header that carries ``auth_value``.
            auth_value: credential sent with every ``invoke`` (omitted when empty).
        """
        self.spec_url = spec_url
        self.spec_path = spec_path
        self.api_base = api_base
        self.auth_header = auth_header
        self.auth_value = auth_value
        self._http = httpx.AsyncClient(timeout=30)

    async def __aenter__(self) -> OpenAPIToolGenerator:
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def load_spec(self) -> dict:
        """Load the OpenAPI spec from ``spec_url`` (preferred) or ``spec_path``.

        YAML is chosen by the URL path / file suffix (case-insensitive) or a YAML
        ``Content-Type``; otherwise the body is parsed as JSON, falling back to
        YAML (a superset of JSON) for servers that give neither hint.

        Raises:
            httpx.HTTPStatusError: the spec URL answered with an error status.
            ValueError: no source configured, or the document is not a mapping
                (e.g. an HTML error page).
        """
        if self.spec_url:
            resp = await self._http.get(self.spec_url)
            resp.raise_for_status()
            # Inspect only the URL path so a query string ("spec.yaml?ref=main")
            # does not hide the suffix.
            url_path = urlparse(self.spec_url).path.lower()
            content_type = resp.headers.get("content-type", "").lower()
            if url_path.endswith((".yaml", ".yml")) or "yaml" in content_type:
                spec = yaml.safe_load(resp.text)
            else:
                try:
                    spec = resp.json()
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    spec = yaml.safe_load(resp.text)
        elif self.spec_path:
            path = Path(self.spec_path)
            text = path.read_text(encoding="utf-8")
            if path.suffix.lower() in (".yaml", ".yml"):
                spec = yaml.safe_load(text)
            else:
                spec = json.loads(text)
        else:
            raise ValueError("spec_url or spec_path required")

        if not isinstance(spec, dict):
            raise ValueError("OpenAPI spec must be a JSON/YAML object")
        return spec

    async def generate(self, filter_tag: str = "", max_tools: int = 100) -> list[GeneratedTool]:
        """Parse the spec and build one tool per supported operation.

        Args:
            filter_tag: when non-empty, keep only operations whose ``tags`` contain it.
            max_tools: upper bound on the number of tools; ``<= 0`` yields ``[]``.

        Path-level ``parameters`` are merged into every operation of that path
        (operation-level entries with the same name win), local ``$ref``s are
        resolved, and colliding tool names get a ``_2``, ``_3``... suffix so
        every returned name is unique.
        """
        spec = await self.load_spec()
        tools: list[GeneratedTool] = []
        seen_names: set[str] = set()
        base_url = self.api_base or self._extract_base_url(spec)
        paths = spec.get("paths")
        if not isinstance(paths, dict):
            paths = {}

        for path_url, path_item in paths.items():
            if not isinstance(path_item, dict):
                continue
            shared_params = path_item.get("parameters")
            if not isinstance(shared_params, list):
                shared_params = []
            for method, operation in path_item.items():
                method = str(method)
                if method.upper() not in self._HTTP_METHODS:
                    continue
                if not isinstance(operation, dict):
                    continue
                if filter_tag and filter_tag not in (operation.get("tags") or []):
                    continue
                # Checked before building so max_tools <= 0 really returns nothing.
                if len(tools) >= max_tools:
                    return tools
                if shared_params:
                    # Path-level entries first: later entries override earlier ones
                    # in _build_parameters_schema, so operation-level params win.
                    operation = {
                        **operation,
                        "parameters": [*shared_params, *(operation.get("parameters") or [])],
                    }
                tool = self._build_tool(path_url, method, operation, base_url, spec=spec)
                tool.name = self._unique_name(tool.name, seen_names)
                tools.append(tool)

        return tools

    def _build_tool(self, path_url: str, method: str, operation: dict,
                    base_url: str, spec: dict | None = None) -> GeneratedTool:
        """Build a GeneratedTool for a single endpoint.

        ``spec`` (optional) is the whole document, used only to resolve local
        ``$ref`` parameters and request bodies.
        """
        # ``or`` rather than a .get() default so an explicit null/empty value
        # still falls back.
        operation_id = str(operation.get("operationId") or self._generate_operation_id(method, path_url))
        summary = operation.get("summary") or ""
        description = operation.get("description") or summary or f"{method.upper()} {path_url}"
        return GeneratedTool(
            name=self._sanitize_name(operation_id),
            description=description,
            operation_id=operation_id,
            method=method.upper(),
            path=path_url,
            parameters_schema=self._build_parameters_schema(operation, spec),
            base_url=base_url,
            auth_header=self.auth_header,
        )

    def _extract_base_url(self, spec: dict) -> str:
        """Derive the API base URL from the spec.

        OpenAPI 3: the first ``servers`` entry, with ``{variables}`` replaced by
        their defaults and a relative URL (e.g. ``/api/v3``) resolved against
        ``spec_url`` when the spec was fetched over HTTP.
        Swagger 2: ``{schemes[0]}://{host}{basePath}``; scheme defaults to https.
        Returns "" when nothing usable is declared.
        """
        servers = spec.get("servers")
        if isinstance(servers, list) and servers:
            server = servers[0] if isinstance(servers[0], dict) else {}
            url = str(server.get("url") or "")
            variables = server.get("variables")
            if isinstance(variables, dict):
                for var_name, var_def in variables.items():
                    if isinstance(var_def, dict) and "default" in var_def:
                        url = url.replace("{" + var_name + "}", str(var_def["default"]))
            if url and self.spec_url:
                # No-op for absolute URLs; resolves relative ones per the spec.
                url = urljoin(self.spec_url, url)
            return url

        host = spec.get("host") or ""
        if host:
            schemes = spec.get("schemes") or ["https"]
            return f"{schemes[0]}://{host}{spec.get('basePath') or ''}"
        return ""

    def _build_parameters_schema(self, operation: dict, spec: dict | None = None) -> dict:
        """Build the JSON Schema ("type": "object") describing an operation's inputs.

        Merges into one flat ``properties`` map:
        * path/query/header/cookie/formData parameters, keyed by name (a later
          entry with the same name replaces an earlier one);
        * the top-level properties of the JSON request body (OpenAPI 3
          ``requestBody`` or a Swagger 2 ``in: body`` parameter).

        When ``spec`` is given, local ``$ref``s on parameters, parameter schemas,
        the request body and the body schema are resolved; refs nested inside
        body properties are left as-is. Returns {} when there are no inputs.
        """
        def resolve(node: Any) -> Any:
            return node if spec is None else self._resolve_ref(spec, node)

        properties: dict[str, Any] = {}
        required: list[str] = []
        body_schema: Any = None

        params_by_name: dict[str, dict] = {}
        for raw_param in operation.get("parameters") or []:
            param = resolve(raw_param)
            if not isinstance(param, dict) or not param.get("name"):
                continue  # unresolved $ref or malformed entry
            if param.get("in") == "body":
                # Swagger 2 request body: flatten it like an OpenAPI 3 requestBody.
                body_schema = resolve(param.get("schema"))
                continue
            params_by_name[param["name"]] = param

        for name, param in params_by_name.items():
            schema = resolve(param.get("schema"))
            if not isinstance(schema, dict):
                schema = {}
            # OpenAPI 3 keeps type info under ``schema``; Swagger 2 on the parameter.
            param_type = schema.get("type") or param.get("type", "string")
            if isinstance(param_type, list):  # OpenAPI 3.1, e.g. ["string", "null"]
                param_type = next((t for t in param_type if t != "null"), "string")
            if not isinstance(param_type, str):
                param_type = "string"
            # Deep copy: PARAM_TYPE_MAP entries are shared and get mutated below.
            prop = copy.deepcopy(self.PARAM_TYPE_MAP.get(param_type, {"type": "string"}))
            items = resolve(schema.get("items") or param.get("items"))
            if prop["type"] == "array" and isinstance(items, dict):
                prop["items"] = copy.deepcopy(items)
            enum = schema.get("enum") or param.get("enum")
            if isinstance(enum, list):
                prop["enum"] = list(enum)
            description = param.get("description", "")
            if description:
                prop["description"] = description
            properties[name] = prop
            if param.get("required"):
                required.append(name)

        # requestBody (POST/PUT/PATCH)
        request_body = resolve(operation.get("requestBody"))
        if isinstance(request_body, dict):
            content = request_body.get("content")
            if isinstance(content, dict):
                media = content.get("application/json")
                if media is None:
                    # Fall back to other JSON media types, e.g. "application/vnd.api+json".
                    media = next((v for k, v in content.items() if "json" in str(k).lower()), None)
                if isinstance(media, dict) and media.get("schema") is not None:
                    body_schema = resolve(media["schema"])

        if isinstance(body_schema, dict) and isinstance(body_schema.get("properties"), dict):
            for prop_name, prop_schema in body_schema["properties"].items():
                # Copy so later edits to the tool schema cannot corrupt the loaded spec.
                properties[prop_name] = copy.deepcopy(prop_schema)
            body_required = body_schema.get("required")
            if isinstance(body_required, list):
                required.extend(body_required)

        if not properties:
            return {}

        result: dict[str, Any] = {"type": "object", "properties": properties}
        if required:
            # A name may be required both as a parameter and as a body field.
            result["required"] = list(dict.fromkeys(required))
        return result

    @staticmethod
    def _resolve_ref(spec: dict, node: Any) -> Any:
        """Follow local JSON-Pointer ``$ref``s (``"#/..."``) inside ``spec``.

        Returns the referenced object, or ``node`` itself when it is not a
        reference, or the reference is external, missing, or circular.
        """
        seen: set[str] = set()
        while isinstance(node, dict) and isinstance(node.get("$ref"), str):
            ref = node["$ref"]
            if not ref.startswith("#/") or ref in seen:
                return node
            seen.add(ref)
            target: Any = spec
            for token in ref[2:].split("/"):
                # RFC 6901 unescaping: "~1" -> "/" first, then "~0" -> "~".
                token = token.replace("~1", "/").replace("~0", "~")
                if not isinstance(target, dict) or token not in target:
                    return node
                target = target[token]
            node = target
        return node

    @staticmethod
    def _sanitize_name(operation_id: str) -> str:
        """Turn an operationId into a valid tool/function name.

        Characters outside ``[A-Za-z0-9_]`` become ``_`` (runs collapsed, ends
        trimmed), the result is lower-cased, prefixed with ``tool_`` when it does
        not start with a letter, and truncated to 64 characters. An id with no
        usable characters (e.g. purely non-ASCII) becomes ``"tool"``.
        """
        name = re.sub(r"[^a-zA-Z0-9_]", "_", operation_id)
        name = re.sub(r"_{2,}", "_", name).strip("_").lower()
        if not name:
            return "tool"
        if not name[0].isalpha():
            name = "tool_" + name
        return name[:64]

    @staticmethod
    def _unique_name(name: str, seen: set[str]) -> str:
        """Return ``name`` or ``name_N`` (kept within 64 chars) not yet in ``seen``; record it."""
        candidate = name
        counter = 2
        while candidate in seen:
            suffix = f"_{counter}"
            candidate = name[: 64 - len(suffix)] + suffix
            counter += 1
        seen.add(candidate)
        return candidate

    @staticmethod
    def _generate_operation_id(method: str, path: str) -> str:
        """Synthesize an operationId from method + path when the spec has none."""
        clean = re.sub(r"[{}]", "", path).replace("/", "_").strip("_")
        clean = re.sub(r"[^a-zA-Z0-9_]", "_", clean)
        return f"{method.lower()}_{clean}"

    async def invoke(self, tool: GeneratedTool, params: dict) -> dict:
        """Call the HTTP endpoint behind ``tool`` with ``params``.

        Values whose key matches a ``{placeholder}`` in the path are URL-encoded
        (``/`` included) and substituted there. The remaining values go to the
        query string for GET, or to a JSON body for every other method. The auth
        header is sent only when ``auth_value`` is non-empty.

        Returns the decoded JSON response (whatever JSON value the API returns),
        or ``{}`` for an empty body such as 204 No Content.

        Raises:
            httpx.HTTPStatusError: the API answered with a 4xx/5xx status.
        """
        url = tool.base_url.rstrip("/") + tool.path
        remaining: dict[str, Any] = {}
        for key, val in (params or {}).items():
            placeholder = "{" + key + "}"
            if placeholder in url:
                # Encode so values like "a/b" or "x y" cannot alter the path.
                url = url.replace(placeholder, quote(str(val), safe=""))
            else:
                remaining[key] = val

        headers: dict[str, str] = {}
        if tool.auth_header and self.auth_value:
            headers[tool.auth_header] = self.auth_value

        if tool.method == "GET":
            resp = await self._http.get(url, params=remaining, headers=headers)
        else:
            headers.setdefault("Content-Type", "application/json")
            resp = await self._http.request(
                tool.method, url, json=remaining, headers=headers
            )

        resp.raise_for_status()
        if not resp.content:
            return {}
        return resp.json()

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._http.aclose()