Coverage for agentos/llm/openai_provider.py: 25%

103 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2OpenAI Provider 实现 — 基于官方 openai SDK 的对话补全。 

3v1.3.36: +Function Calling / Tool Use 支持。 

4""" 

5 

6from __future__ import annotations 

7 

8import json 

9from typing import Any, Iterator 

10 

11try: 

12 from openai import AsyncOpenAI, OpenAI 

13 from openai.types.chat import ChatCompletionMessageParam 

14except ImportError as e: 

15 raise ImportError( 

16 "openai SDK not installed. Run: pip install 'nexus-agentos[openai]'" 

17 ) from e 

18 

19from agentos.llm.base import ( 

20 CompletionChoice, 

21 CompletionResult, 

22 CompletionUsage, 

23 LLMProvider, 

24 Message, 

25 MessageRole, 

26 StreamChunk, 

27 Tool, 

28 ToolCall, 

29) 

30 

__all__ = ["OpenAIProvider"]


# Internal MessageRole -> role string used in OpenAI chat-completion payloads.
_ROLE_MAP: dict[MessageRole, str] = {
    MessageRole.SYSTEM: "system",
    MessageRole.USER: "user",
    MessageRole.ASSISTANT: "assistant",
    MessageRole.TOOL: "tool",
}

# Inverse lookup (OpenAI role string -> MessageRole). Built from _ROLE_MAP so the
# two tables can never disagree.
_REVERSE_ROLE_MAP: dict[str, MessageRole] = dict(
    zip(_ROLE_MAP.values(), _ROLE_MAP.keys())
)

42 

# List prices in USD per 1K tokens, stored as (input, output), captured 2025-06.
# Models are looked up by exact name; anything missing here gets no cost estimate.
# NOTE(review): OpenAI reprices models over time (o3 may have changed during
# 2025-06) - re-check against the official pricing page before trusting cost_usd.
_PRICING: dict[str, tuple[float, float]] = {
    # GPT-4o family
    "gpt-4o": (0.0025, 0.0100),
    "gpt-4o-mini": (0.00015, 0.0006),
    # GPT-4.1 family
    "gpt-4.1": (0.0020, 0.0080),
    "gpt-4.1-mini": (0.0004, 0.0016),
    "gpt-4.1-nano": (0.0001, 0.0004),
    # o-series models
    "o3": (0.0100, 0.0400),
    "o3-mini": (0.0011, 0.0044),
    "o4-mini": (0.0011, 0.0044),
}

54 

55 

def _messages_to_openai(messages: list[Message]) -> list[ChatCompletionMessageParam]:
    """Convert internal ``Message`` objects into OpenAI chat-completion message dicts.

    Every output dict carries ``role`` (translated through ``_ROLE_MAP``) and
    ``content``. ``tool_call_id`` is added only when the message has a truthy
    one. ``tool_calls`` is added only when the message has a non-empty list of
    tool calls, each rendered as
    ``{"id": ..., "type": "function", "function": {"name": ..., "arguments": ...}}``.

    Raises:
        KeyError: if a message's role has no entry in ``_ROLE_MAP``.
    """

    def _convert(msg: Message) -> dict[str, Any]:
        payload: dict[str, Any] = {"role": _ROLE_MAP[msg.role], "content": msg.content}
        if msg.tool_call_id:
            payload["tool_call_id"] = msg.tool_call_id
        if msg.tool_calls:
            payload["tool_calls"] = [
                {
                    "id": call.id,
                    "type": "function",
                    "function": {"name": call.name, "arguments": call.arguments},
                }
                for call in msg.tool_calls
            ]
        return payload

    return [_convert(msg) for msg in messages]

74 

75 

def _tools_to_openai(tools: list[Tool] | None) -> list[dict[str, Any]] | None:
    """Render tools as OpenAI ``tools`` schemas via ``Tool.as_schema()``.

    Both ``None`` and an empty list yield ``None`` rather than an empty list.
    """
    if tools:
        return [tool.as_schema() for tool in tools]
    return None

80 

81 

def _extract_tool_calls(message_obj: Any) -> list[ToolCall]:
    """Extract ``ToolCall`` objects from an OpenAI SDK message object.

    Reads ``message_obj.tool_calls``. A missing attribute or ``None`` is
    treated as "no tool calls". For each entry, the name and the JSON argument
    string are taken from ``entry.function``.

    Fallbacks:
        - A missing function payload or a missing/empty name becomes ``""``.
        - Missing or empty arguments become ``"{}"``, so the arguments string
          is always a valid JSON object literal.

    Returns:
        A list of ``ToolCall`` objects, one per entry, in response order
        (empty when there are none).
    """
    raw_calls = getattr(message_obj, "tool_calls", None) or []
    calls: list[ToolCall] = []
    for tc in raw_calls:
        fn = getattr(tc, "function", None)
        name = getattr(fn, "name", None) or ""
        # OpenAI-compatible endpoints may send arguments="" (or None) for tools
        # without parameters; normalize to an empty JSON object.
        arguments = getattr(fn, "arguments", None) or "{}"
        calls.append(ToolCall(id=tc.id, name=name, arguments=arguments))
    return calls

94 

95 

def _build_result(raw, model: str | None = None) -> CompletionResult:
    """Build a ``CompletionResult`` from an OpenAI SDK chat-completion response.

    Args:
        raw: The SDK response object. It must expose ``choices``, ``usage``,
            ``model``, ``id`` and ``created``.
        model: Model name to report and to price against. Falls back to
            ``raw.model``, then to ``""``.

    Returns:
        A ``CompletionResult`` containing:
            - one ``CompletionChoice`` per entry in ``raw.choices``;
            - token usage, with zeros when usage or individual counts are
              missing;
            - ``usage.cost_usd`` filled in when the resolved model is listed
              in ``_PRICING``.

    Raises:
        IndexError: if the response contains no choices.
    """
    if not raw.choices:
        # Same exception type the former ``raw.choices[0]`` access raised,
        # but with an explanatory message.
        raise IndexError("OpenAI response contained no choices")

    # Convert every choice: keeping only choices[0] silently dropped the extra
    # completions when a caller requested n > 1 through **kwargs.
    choices = [_build_choice(c) for c in raw.choices]
    tokens = _build_usage(raw.usage)
    resolved_model = model or raw.model or ""
    cost = _estimate_cost(resolved_model, tokens)
    if cost is not None:
        tokens.cost_usd = cost
    return CompletionResult(
        id=raw.id, model=resolved_model, choices=choices, usage=tokens, created=raw.created
    )


def _build_choice(raw_choice) -> CompletionChoice:
    """Convert one SDK choice into a ``CompletionChoice``.

    Behavior:
        - Unknown roles map to ASSISTANT.
        - ``None`` content becomes ``""``.
        - An empty tool-call list becomes ``None``.
        - A missing finish_reason becomes ``"stop"``.
    """
    msg = raw_choice.message
    tool_calls = _extract_tool_calls(msg)
    return CompletionChoice(
        index=raw_choice.index,
        message=Message(
            role=_REVERSE_ROLE_MAP.get(msg.role, MessageRole.ASSISTANT),
            content=msg.content or "",
            tool_calls=tool_calls or None,
        ),
        finish_reason=raw_choice.finish_reason or "stop",
    )


def _build_usage(usage) -> CompletionUsage:
    """Convert SDK usage (possibly ``None``) into ``CompletionUsage``.

    Missing or ``None`` counts become 0. This keeps the cost arithmetic safe
    for OpenAI-compatible endpoints that omit fields.
    """
    if usage is None:
        return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    return CompletionUsage(
        prompt_tokens=usage.prompt_tokens or 0,
        completion_tokens=usage.completion_tokens or 0,
        total_tokens=usage.total_tokens or 0,
    )


def _estimate_cost(model: str, usage: CompletionUsage) -> float | None:
    """Return the USD cost of *usage* at ``_PRICING`` rates (rounded to 6 places).

    Returns ``None`` when *model* is not an exact key of ``_PRICING``.
    """
    prices = _PRICING.get(model)
    if prices is None:
        return None
    in_price, out_price = prices
    return round(
        usage.prompt_tokens / 1000 * in_price + usage.completion_tokens / 1000 * out_price, 6
    )

124 

125 

class OpenAIProvider(LLMProvider):
    """Chat provider built on the official ``openai`` SDK.

    Supports OpenAI, Azure, and OpenAI-compatible third-party endpoints
    (selected via ``base_url``). The sync and async SDK clients are created
    lazily on first use and then reused by this instance.
    """

    # Class-level defaults only. ``_get_client`` / ``_get_async_client`` assign
    # the real clients on the instance, so each provider keeps its own.
    _sync_client: OpenAI | None = None
    _async_client: AsyncOpenAI | None = None

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str = "",
        base_url: str = "",
        organization: str = "",
        timeout: float = 60.0,
    ):
        """Store configuration; no SDK client is created here.

        Args:
            model: Model (or deployment) name sent as ``model`` on every request.
            api_key: API key for the SDK client. It is not passed when empty,
                so the SDK's own default applies.
            base_url: Alternative endpoint URL. It is not passed when empty.
            organization: OpenAI organization ID. It is not passed when empty.
            timeout: Timeout handed to the SDK client constructor.
        """
        super().__init__(model=model, api_key=api_key, base_url=base_url)
        self._organization = organization
        self._timeout = timeout

    @property
    def provider_name(self) -> str:
        return "openai"

    def _client_kwargs(self) -> dict[str, Any]:
        """Constructor kwargs shared by the sync and async SDK clients.

        Empty ``api_key`` / ``base_url`` / ``organization`` values are omitted.
        """
        kwargs: dict[str, Any] = {"timeout": self._timeout, "max_retries": 2}
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.base_url:
            kwargs["base_url"] = self.base_url
        if self._organization:
            kwargs["organization"] = self._organization
        return kwargs

    def _get_client(self) -> OpenAI:
        """Return the cached sync client, creating it on first call."""
        if self._sync_client is None:
            self._sync_client = OpenAI(**self._client_kwargs())
        return self._sync_client

    def _get_async_client(self) -> AsyncOpenAI:
        """Return the cached async client, creating it on first call."""
        if self._async_client is None:
            self._async_client = AsyncOpenAI(**self._client_kwargs())
        return self._async_client

    @staticmethod
    def _sampling_params(
        temperature: float, max_tokens: int, top_p: float, stop: list[str] | None
    ) -> dict[str, Any]:
        """Sampling fields shared by ``chat`` and ``achat``.

        ``stop`` is only included when provided, rather than being sent as an
        explicit null.
        """
        # NOTE(review): o-series models (listed in _PRICING) are documented to
        # reject ``max_tokens`` in favour of ``max_completion_tokens``, and may
        # reject non-default temperature/top_p. Confirm, and consider a
        # per-model mapping here.
        params: dict[str, Any] = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }
        if stop is not None:
            params["stop"] = stop
        return params

    def _request_params(
        self,
        messages: list[Message],
        base: dict[str, Any],
        extra: dict[str, Any],
        tools: list[Tool] | None,
        tool_choice: str | None,
    ) -> dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``.

        Precedence, with later entries overriding earlier ones:
            1. ``model`` and ``messages``;
            2. *base*;
            3. the caller's *extra* kwargs;
            4. ``tools`` and ``tool_choice``.

        ``tools`` and ``tool_choice`` are added only when *tools* is non-empty,
        and ``tool_choice`` is skipped when it is ``None``.
        """
        params: dict[str, Any] = {
            "model": self.model,
            "messages": _messages_to_openai(messages),
            **base,
            **extra,
        }
        if tools:
            params["tools"] = _tools_to_openai(tools)
            if tool_choice is not None:
                params["tool_choice"] = tool_choice
        return params

    def chat(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        top_p: float = 1.0,
        stop: list[str] | None = None,
        tools: list[Tool] | None = None,
        tool_choice: str = "auto",
        **kwargs: Any,
    ) -> CompletionResult:
        """Send a blocking chat completion request and convert the response.

        Extra ``kwargs`` are forwarded to ``chat.completions.create`` and
        override the sampling fields. ``tool_choice`` is sent only together
        with a non-empty ``tools`` list.
        """
        client = self._get_client()
        params = self._request_params(
            messages,
            self._sampling_params(temperature, max_tokens, top_p, stop),
            kwargs,
            tools,
            tool_choice,
        )
        resp = client.chat.completions.create(**params)
        return _build_result(resp, model=self.model)

    async def achat(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        top_p: float = 1.0,
        stop: list[str] | None = None,
        tools: list[Tool] | None = None,
        tool_choice: str = "auto",
        **kwargs: Any,
    ) -> CompletionResult:
        """Async counterpart of ``chat``, using the ``AsyncOpenAI`` client."""
        client = self._get_async_client()
        params = self._request_params(
            messages,
            self._sampling_params(temperature, max_tokens, top_p, stop),
            kwargs,
            tools,
            tool_choice,
        )
        resp = await client.chat.completions.create(**params)
        return _build_result(resp, model=self.model)

    def stream(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        tools: list[Tool] | None = None,
        **kwargs: Any,
    ) -> Iterator[StreamChunk]:
        """Stream a chat completion and yield ``StreamChunk`` objects.

        A chunk is yielded for every delta that has text content, and for the
        chunk that carries a ``finish_reason``. Chunks without choices are
        skipped. ``content`` is ``""`` when a delta has no text.

        The SDK stream is closed when iteration ends, including when the
        consumer stops early.
        """
        # NOTE(review): tool-call fragments (``delta.tool_calls``) are not
        # forwarded - StreamChunk as used here has no field for them. The
        # final chunk's finish_reason is the only tool-call signal.
        client = self._get_client()
        params = self._request_params(
            messages,
            {"temperature": temperature, "max_tokens": max_tokens, "stream": True},
            kwargs,
            tools,
            None,
        )
        stream_resp = client.chat.completions.create(**params)
        try:
            for chunk in stream_resp:
                if not chunk.choices:
                    # e.g. a usage-only chunk, which has an empty choices list.
                    continue
                choice = chunk.choices[0]
                delta = choice.delta
                content = delta.content if delta is not None else None
                finish_reason = choice.finish_reason or None
                # The terminating chunk has an empty delta but carries
                # finish_reason. A content-only filter would drop it, and
                # callers would never learn why the stream ended.
                if not content and finish_reason is None:
                    continue
                yield StreamChunk(
                    content=content or "",
                    finish_reason=finish_reason,
                    index=choice.index,
                )
        finally:
            # Release the HTTP response even if the consumer stops iterating
            # early (generator close -> GeneratorExit at the yield).
            close = getattr(stream_resp, "close", None)
            if callable(close):
                close()