Coverage for agentos/tools/generator.py: 79%

145 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 17:38 +0800

1""" 

2OpenAPI工具自动生成器 — 从OpenAPI/Swagger spec自动生成Agent工具包装器。 

3v0.50: 新增模块。将REST API端点自动转换为Agent可调用的ToolCall格式。 

4""" 

5 

6from __future__ import annotations 

7 

import copy
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from urllib.parse import quote, urljoin, urlparse

import httpx
import yaml

16 

17 

@dataclass
class GeneratedTool:
    """Description of a single tool generated from one API operation.

    Holds plain data only; the ``to_*`` methods render it either as an
    OpenAI-style function-calling entry or as a generic tool dictionary.
    """

    # Function name exposed to the model.
    name: str
    # Human-readable description shown to the model.
    description: str
    # operationId of the source operation (may be synthesized by the generator).
    operation_id: str = ""
    # HTTP method, e.g. "GET".
    method: str = "GET"
    # Path template, possibly containing {placeholders}.
    path: str = ""
    # JSON Schema describing the arguments; an empty dict means "no arguments".
    parameters_schema: dict = field(default_factory=dict)
    # Name of the HTTP header that carries the credential when invoking.
    auth_header: str = ""
    # Base URL that the path template is appended to.
    base_url: str = ""

    def to_openai_function(self) -> dict:
        """Render as an OpenAI function-calling tool entry.

        The ``parameters`` key is included only when a non-empty schema exists.
        """
        function_spec: dict = {"name": self.name, "description": self.description}
        if self.parameters_schema:
            function_spec["parameters"] = self.parameters_schema
        return {"type": "function", "function": function_spec}

    def to_tool_dict(self) -> dict:
        """Render as a generic tool-description dictionary.

        Note that ``path`` is exported under the key ``path_template`` and
        ``parameters_schema`` under ``parameters``.
        """
        return dict(
            name=self.name,
            description=self.description,
            operation_id=self.operation_id,
            method=self.method,
            path_template=self.path,
            parameters=self.parameters_schema,
            base_url=self.base_url,
            auth_header=self.auth_header,
        )

55 

56 

class OpenAPIToolGenerator:
    """
    Generate agent tools from an OpenAPI 3.x / Swagger 2.0 spec.

    Usage:
        gen = OpenAPIToolGenerator("https://api.example.com/openapi.json")
        tools = await gen.generate()
        # `tools` is a list of GeneratedTool that can be injected directly into
        # the agent context; `await gen.invoke(tool, args)` executes one.
        await gen.close()

    The instance owns an ``httpx.AsyncClient``; release it with :meth:`close`,
    or use the generator as ``async with OpenAPIToolGenerator(...) as gen:``.
    """

    # OpenAPI primitive type -> minimal JSON Schema fragment. This is shared
    # class state: it is always deep-copied before use so that per-parameter
    # data (descriptions, enums) never leaks between parameters or tools.
    PARAM_TYPE_MAP = {
        "string": {"type": "string"},
        "integer": {"type": "integer"},
        "number": {"type": "number"},
        "boolean": {"type": "boolean"},
        "array": {"type": "array", "items": {"type": "string"}},
        "object": {"type": "object"},
    }

    # Path-item keys that are turned into tools; other keys such as
    # "parameters", "summary" or "servers" are not operations.
    _HTTP_METHODS = ("GET", "POST", "PUT", "DELETE", "PATCH")

    # Maximum generated tool-name length (function-calling APIs such as
    # OpenAI's limit function names to 64 characters).
    _MAX_NAME_LEN = 64

    def __init__(self, spec_url: str = "", spec_path: str = "", api_base: str = "",
                 auth_header: str = "Authorization", auth_value: str = ""):
        """
        Args:
            spec_url: URL of the spec (JSON or YAML); takes precedence over spec_path.
            spec_path: Local path of the spec file (JSON or YAML).
            api_base: Explicit API base URL; overrides servers/host from the spec.
            auth_header: Header name used to send ``auth_value`` on invoke.
            auth_value: Credential value; when empty, no auth header is sent.
        """
        self.spec_url = spec_url
        self.spec_path = spec_path
        self.api_base = api_base
        self.auth_header = auth_header
        self.auth_value = auth_value
        self._http = httpx.AsyncClient(timeout=30)

    async def __aenter__(self) -> "OpenAPIToolGenerator":
        return self

    async def __aexit__(self, *exc_info) -> None:
        await self.close()

    @staticmethod
    def _is_yaml_name(name: str) -> bool:
        """True when a file name / URL path has a .yaml or .yml extension (any case)."""
        return name.lower().endswith((".yaml", ".yml"))

    async def load_spec(self) -> dict:
        """Load the OpenAPI spec from ``spec_url`` (preferred) or ``spec_path``.

        YAML is chosen by extension (case-insensitive; any URL query string is
        ignored) or, for URLs, by a YAML ``Content-Type``. Other remote
        documents are parsed as JSON, falling back to YAML (a JSON superset).

        Raises:
            ValueError: no source configured, or the document is not a mapping.
            httpx.HTTPStatusError: the spec URL returned an error status.
        """
        if self.spec_url:
            resp = await self._http.get(self.spec_url)
            resp.raise_for_status()
            content_type = resp.headers.get("content-type", "").lower()
            if self._is_yaml_name(urlparse(self.spec_url).path) or "yaml" in content_type:
                spec = yaml.safe_load(resp.text)
            else:
                try:
                    spec = resp.json()
                except ValueError:
                    spec = yaml.safe_load(resp.text)
        elif self.spec_path:
            path = Path(self.spec_path)
            text = path.read_text(encoding="utf-8")
            spec = yaml.safe_load(text) if self._is_yaml_name(path.name) else json.loads(text)
        else:
            raise ValueError("spec_url or spec_path required")

        if not isinstance(spec, dict):
            raise ValueError("OpenAPI spec must be an object at the top level")
        return spec

    async def generate(self, filter_tag: str = "", max_tools: int = 100) -> "list[GeneratedTool]":
        """Parse the spec and build one tool per supported operation.

        Args:
            filter_tag: When non-empty, keep only operations carrying this tag.
            max_tools: Maximum number of tools returned; values <= 0 yield [].

        Returns:
            Tools in spec order. Path-level parameters are merged into every
            operation under that path, and duplicate names receive a numeric
            suffix (``_2``, ``_3``, ...) so each can be registered separately.
        """
        spec = await self.load_spec()
        if max_tools <= 0:
            return []

        base_url = self.api_base or self._extract_base_url(spec)
        tools: list[GeneratedTool] = []
        seen_names: set[str] = set()

        for path_url, path_item in (spec.get("paths") or {}).items():
            if not isinstance(path_item, dict):
                continue
            # Parameters declared on the path item apply to all of its operations.
            shared_params = path_item.get("parameters") or []
            for method, operation in path_item.items():
                if str(method).upper() not in self._HTTP_METHODS or not isinstance(operation, dict):
                    continue
                if filter_tag and filter_tag not in (operation.get("tags") or []):
                    continue

                if shared_params:
                    # Shallow copy so the loaded spec itself is never modified.
                    operation = {
                        **operation,
                        "parameters": self._merge_parameters(
                            shared_params, operation.get("parameters") or []
                        ),
                    }
                tool = self._build_tool(path_url, method, operation, base_url)
                tool.name = self._unique_name(tool.name, seen_names)
                tools.append(tool)
                if len(tools) >= max_tools:
                    return tools

        return tools

    @staticmethod
    def _merge_parameters(shared: list, own: list) -> list:
        """Combine path-level and operation-level parameters.

        An operation parameter overrides a path-level parameter with the same
        ``(name, in)`` pair. Entries without a name (e.g. unresolved ``$ref``)
        are never treated as overriding each other.
        """
        def key(param):
            if isinstance(param, dict) and "name" in param:
                return (param.get("name"), param.get("in"))
            return None

        overridden = {key(p) for p in own} - {None}
        return [p for p in shared if key(p) not in overridden] + list(own)

    @classmethod
    def _unique_name(cls, name: str, seen: set[str]) -> str:
        """Return ``name`` (or ``name_<n>``) not yet in ``seen`` and record it there."""
        candidate = name
        counter = 2
        while candidate in seen:
            suffix = f"_{counter}"
            candidate = name[: cls._MAX_NAME_LEN - len(suffix)] + suffix
            counter += 1
        seen.add(candidate)
        return candidate

    def _build_tool(self, path_url: str, method: str, operation: dict,
                    base_url: str) -> "GeneratedTool":
        """Build a GeneratedTool from a single endpoint.

        Falls back to a synthesized operationId and to ``summary`` /
        ``"<METHOD> <path>"`` when the operation lacks (or has empty)
        ``operationId`` / ``description``.
        """
        raw_id = operation.get("operationId")
        operation_id = str(raw_id) if raw_id else self._generate_operation_id(method, path_url)
        summary = operation.get("summary") or ""
        description = operation.get("description") or summary or f"{method.upper()} {path_url}"

        return GeneratedTool(
            name=self._sanitize_name(operation_id),
            description=description,
            operation_id=operation_id,
            method=method.upper(),
            path=path_url,
            parameters_schema=self._build_parameters_schema(operation),
            base_url=base_url,
            auth_header=self.auth_header,
        )

    def _extract_base_url(self, spec: dict) -> str:
        """Derive the API base URL from the spec.

        OpenAPI 3: the first ``servers`` entry, with ``{variable}`` placeholders
        replaced by their declared defaults and, when relative, resolved against
        ``spec_url``. Swagger 2: ``<scheme>://<host><basePath>`` with the scheme
        defaulting to https. Returns "" when nothing usable is declared.
        """
        servers = spec.get("servers") or []
        if servers and isinstance(servers[0], dict):
            server = servers[0]
            url = server.get("url") or ""
            for var_name, var in (server.get("variables") or {}).items():
                if isinstance(var, dict) and "default" in var:
                    url = url.replace("{" + str(var_name) + "}", str(var["default"]))
            if url and self.spec_url and not urlparse(url).scheme:
                url = urljoin(self.spec_url, url)
            return url

        host = spec.get("host") or ""
        if not host:
            return ""
        # An explicitly empty "schemes" list must not cause an IndexError.
        scheme = (spec.get("schemes") or ["https"])[0]
        return f"{scheme}://{host}{spec.get('basePath') or ''}"

    def _param_schema(self, param: dict) -> dict:
        """Return a fresh JSON Schema fragment for one (non-body) parameter."""
        raw = param.get("schema")
        if not isinstance(raw, dict):
            raw = {}
        param_type = raw.get("type") or param.get("type") or "string"
        # Deep copy: PARAM_TYPE_MAP entries are shared and must never be mutated.
        schema = copy.deepcopy(self.PARAM_TYPE_MAP.get(param_type, {"type": "string"}))

        if param_type == "array":
            items = raw.get("items") or param.get("items")
            item_type = items.get("type") if isinstance(items, dict) else None
            if item_type in ("string", "integer", "number", "boolean"):
                schema["items"] = {"type": item_type}

        enum = raw.get("enum") or param.get("enum")
        if isinstance(enum, list) and enum:
            schema["enum"] = list(enum)

        description = param.get("description")
        if description:
            schema["description"] = description
        return schema

    @staticmethod
    def _json_body_schema(request_body) -> dict:
        """Return the schema of the JSON media type in a requestBody, or {}."""
        if not isinstance(request_body, dict):
            return {}
        content = request_body.get("content") or {}
        if not isinstance(content, dict):
            return {}
        media = content.get("application/json")
        if media is None:
            # Accept other JSON media types, e.g. "application/vnd.api+json".
            media = next((v for k, v in content.items() if "json" in str(k).lower()), None)
        if not isinstance(media, dict):
            return {}
        schema = media.get("schema")
        return schema if isinstance(schema, dict) else {}

    def _build_parameters_schema(self, operation: dict) -> dict:
        """Build an object JSON Schema describing the operation's arguments.

        Named path/query/header parameters and the top-level properties of a
        JSON request body are flattened into one object schema. Returns {}
        when the operation takes no arguments.
        """
        properties: dict[str, Any] = {}
        required: list[str] = []

        for param in operation.get("parameters") or []:
            # Unresolved $ref entries and malformed items carry no name; skip
            # them instead of failing the whole spec with a KeyError.
            if not isinstance(param, dict) or "name" not in param:
                continue
            name = param["name"]
            properties[name] = self._param_schema(param)
            if param.get("required") and name not in required:
                required.append(name)

        # requestBody (POST/PUT/PATCH)
        body_schema = self._json_body_schema(operation.get("requestBody"))
        body_props = body_schema.get("properties")
        if isinstance(body_props, dict) and body_props:
            properties.update(body_props)
            for prop_name in body_schema.get("required") or []:
                # JSON Schema requires unique "required" entries.
                if prop_name in body_props and prop_name not in required:
                    required.append(prop_name)

        if not properties:
            return {}

        schema: dict[str, Any] = {"type": "object", "properties": properties}
        if required:
            schema["required"] = required
        return schema

    @classmethod
    def _sanitize_name(cls, operation_id: str) -> str:
        """Turn an operationId into a lowercase identifier of bounded length.

        Characters outside [A-Za-z0-9_] become "_", runs of "_" collapse and
        leading/trailing "_" are stripped. Names starting with a digit get a
        "tool_" prefix; ids with no usable characters (e.g. fully non-ASCII
        ids) become "tool" instead of raising IndexError.
        """
        name = re.sub(r"[^a-zA-Z0-9_]", "_", str(operation_id))
        name = re.sub(r"_{2,}", "_", name).strip("_").lower()
        if not name:
            return "tool"
        if not name[0].isalpha():
            name = "tool_" + name
        return name[: cls._MAX_NAME_LEN].rstrip("_")

    @staticmethod
    def _generate_operation_id(method: str, path: str) -> str:
        """Synthesize an id from method + path when the spec has no operationId."""
        clean = re.sub(r"[{}]", "", path).replace("/", "_").strip("_")
        clean = re.sub(r"[^a-zA-Z0-9_]", "_", clean)
        return f"{method.lower()}_{clean}" if clean else method.lower()

    @staticmethod
    def _build_url(base_url: str, path: str, params: dict) -> tuple[str, dict]:
        """Fill ``{name}`` placeholders and return (url, params not used in the path).

        Values are percent-encoded (including "/") so model-supplied arguments
        cannot change the URL structure or inject further placeholders.
        """
        url = base_url.rstrip("/") + path
        remaining: dict = {}
        for key, value in params.items():
            placeholder = "{" + str(key) + "}"
            if placeholder in url:
                url = url.replace(placeholder, quote(str(value), safe=""))
            else:
                remaining[key] = value
        return url, remaining

    async def invoke(self, tool: "GeneratedTool", params: dict) -> dict:
        """Execute the HTTP call behind ``tool`` with the given arguments.

        Path placeholders are filled from ``params``; the rest is sent as the
        query string for GET and as a JSON body for other methods.

        Returns:
            The decoded JSON response, {} for an empty body, or
            {"text": <body>} when the body is not valid JSON.

        Raises:
            httpx.HTTPStatusError: the server answered with an error status.
        """
        url, remaining = self._build_url(tool.base_url, tool.path, params or {})

        headers: dict[str, str] = {}
        # Avoid sending an empty credential header when no value is configured.
        if tool.auth_header and self.auth_value:
            headers[tool.auth_header] = self.auth_value

        if tool.method == "GET":
            resp = await self._http.get(url, params=remaining, headers=headers)
        else:
            headers.setdefault("Content-Type", "application/json")
            resp = await self._http.request(
                tool.method, url, json=remaining, headers=headers
            )

        resp.raise_for_status()
        if not resp.content:
            return {}
        try:
            return resp.json()
        except ValueError:
            return {"text": resp.text}

    async def close(self):
        """Close the underlying HTTP client."""
        await self._http.aclose()