Coverage for agentos/llm/openai_provider.py: 25%
103 statements
coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2OpenAI Provider 实现 — 基于官方 openai SDK 的对话补全。
3v1.3.36: +Function Calling / Tool Use 支持。
4"""
6from __future__ import annotations
8import json
9from typing import Any, Iterator
11try:
12 from openai import AsyncOpenAI, OpenAI
13 from openai.types.chat import ChatCompletionMessageParam
14except ImportError as e:
15 raise ImportError(
16 "openai SDK not installed. Run: pip install 'nexus-agentos[openai]'"
17 ) from e
19from agentos.llm.base import (
20 CompletionChoice,
21 CompletionResult,
22 CompletionUsage,
23 LLMProvider,
24 Message,
25 MessageRole,
26 StreamChunk,
27 Tool,
28 ToolCall,
29)
__all__ = ["OpenAIProvider"]

# Internal MessageRole -> role string used in OpenAI chat-completion messages.
_ROLE_MAP: dict[MessageRole, str] = {
    MessageRole.SYSTEM: "system",
    MessageRole.USER: "user",
    MessageRole.ASSISTANT: "assistant",
    MessageRole.TOOL: "tool",
}

# Inverse of _ROLE_MAP, used when turning SDK response roles back into MessageRole.
_REVERSE_ROLE_MAP: dict[str, MessageRole] = dict(zip(_ROLE_MAP.values(), _ROLE_MAP.keys()))

# Model name -> (input price, output price) in USD per 1K tokens (as of 2025-06).
# Looked up by exact model name when computing CompletionUsage.cost_usd.
_PRICING: dict[str, tuple[float, float]] = {
    "gpt-4o": (0.0025, 0.0100),
    "gpt-4o-mini": (0.00015, 0.0006),
    "gpt-4.1": (0.0020, 0.0080),
    "gpt-4.1-mini": (0.0004, 0.0016),
    "gpt-4.1-nano": (0.0001, 0.0004),
    "o3": (0.0100, 0.0400),
    "o3-mini": (0.0011, 0.0044),
    "o4-mini": (0.0011, 0.0044),
}
def _encode_arguments(arguments: Any) -> str:
    """Return tool-call arguments as the JSON string the Chat Completions API expects."""
    if isinstance(arguments, str):
        return arguments
    # Non-string arguments (e.g. an already-parsed dict) must be serialized;
    # None maps to "{}", mirroring the default _extract_tool_calls uses.
    return json.dumps({} if arguments is None else arguments)


def _messages_to_openai(messages: list[Message]) -> list[ChatCompletionMessageParam]:
    """Convert internal ``Message`` objects into OpenAI chat-completion message dicts.

    Each message becomes ``{"role": ..., "content": ...}``. ``tool_call_id`` and
    ``tool_calls`` are attached only when they are set on the source message.
    Tool calls are emitted as ``{"id", "type": "function", "function": {"name",
    "arguments"}}`` with ``arguments`` always a JSON string.

    Args:
        messages: messages in conversation order.

    Returns:
        A new list of message dicts, in the same order.

    Raises:
        KeyError: if a message's role has no entry in ``_ROLE_MAP``.
    """
    result: list[ChatCompletionMessageParam] = []
    for m in messages:
        entry: dict[str, Any] = {"role": _ROLE_MAP[m.role], "content": m.content}
        if m.tool_call_id:
            entry["tool_call_id"] = m.tool_call_id
        if m.tool_calls:
            entry["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {"name": tc.name, "arguments": _encode_arguments(tc.arguments)},
                }
                for tc in m.tool_calls
            ]
        result.append(entry)
    return result
def _tools_to_openai(tools: list[Tool] | None) -> list[dict[str, Any]] | None:
    """Map each Tool to its ``as_schema()`` dict, or return None when no tools are given.

    Returning None (rather than an empty list) lets callers skip the ``tools``
    request field entirely.
    """
    return [tool.as_schema() for tool in tools] if tools else None
def _extract_tool_calls(message_obj) -> list[ToolCall]:
    """Collect ``ToolCall`` objects from an OpenAI SDK response message.

    Args:
        message_obj: the ``choices[i].message`` object of an SDK response. A missing
            or empty ``tool_calls`` attribute yields an empty list.

    Returns:
        One ``ToolCall`` per SDK tool call, in order. When ``function`` is absent, or
        its ``name`` / ``arguments`` are None or empty (as some OpenAI-compatible
        servers return), the name falls back to ``""`` and the arguments to ``"{}"``,
        so callers always receive strings.
    """
    calls: list[ToolCall] = []
    for tc in getattr(message_obj, "tool_calls", None) or []:
        fn = getattr(tc, "function", None)
        calls.append(
            ToolCall(
                id=tc.id,
                # getattr on a None ``fn`` also yields None, so one fallback covers both cases.
                name=getattr(fn, "name", None) or "",
                arguments=getattr(fn, "arguments", None) or "{}",
            )
        )
    return calls
def _build_result(raw, model: str | None = None) -> CompletionResult:
    """Convert an OpenAI SDK chat-completion response into a ``CompletionResult``.

    Only ``raw.choices[0]`` is converted. Unknown roles map to
    ``MessageRole.ASSISTANT``, a missing finish reason becomes ``"stop"``, and
    missing usage data (or individual None counters) counts as 0 tokens.

    Args:
        raw: response object returned by ``chat.completions.create``.
        model: model name to report and price; falls back to ``raw.model``.

    Returns:
        A ``CompletionResult``. ``usage.cost_usd`` is set only when the resolved
        model name has an exact entry in ``_PRICING``.

    Raises:
        IndexError: if the response has no choices.
    """
    first = raw.choices[0]
    msg = first.message
    role = _REVERSE_ROLE_MAP.get(msg.role, MessageRole.ASSISTANT)
    tool_calls = _extract_tool_calls(msg)
    choice = CompletionChoice(
        index=first.index,
        message=Message(
            role=role,
            content=msg.content or "",
            tool_calls=tool_calls or None,
        ),
        finish_reason=first.finish_reason or "stop",
    )

    # OpenAI-compatible servers may omit ``usage`` entirely or report None counters;
    # coerce those to 0 so the cost arithmetic below never sees None.
    usage = raw.usage
    prompt_tokens = getattr(usage, "prompt_tokens", None) or 0
    completion_tokens = getattr(usage, "completion_tokens", None) or 0
    total_tokens = getattr(usage, "total_tokens", None) or prompt_tokens + completion_tokens
    tokens = CompletionUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
    )

    resolved_model = model or raw.model or ""
    pricing = _PRICING.get(resolved_model)
    if pricing is not None:
        in_price, out_price = pricing
        tokens.cost_usd = round(
            tokens.prompt_tokens / 1000 * in_price + tokens.completion_tokens / 1000 * out_price, 6
        )
    return CompletionResult(
        id=raw.id, model=resolved_model, choices=[choice], usage=tokens, created=raw.created
    )
class OpenAIProvider(LLMProvider):
    """LLM provider backed by the official ``openai`` Python SDK.

    Requests go through the plain ``OpenAI`` / ``AsyncOpenAI`` clients. Setting
    ``base_url`` points them at an OpenAI-compatible server instead of the SDK
    default. The original notes list Azure among the intended targets, but this
    class does not use ``AzureOpenAI``. SDK clients are built lazily on first use
    and then reused by later calls on the same instance.
    """

    # Lazily built SDK clients. The class-level None is the "not built yet" marker;
    # the first _get_client()/_get_async_client() call stores a per-instance client.
    _sync_client: OpenAI | None = None
    _async_client: AsyncOpenAI | None = None

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str = "",
        base_url: str = "",
        organization: str = "",
        timeout: float = 60.0,
        max_retries: int = 2,
    ):
        """Configure the provider; no network or client setup happens here.

        Args:
            model: model name sent with every request.
            api_key: API key. When empty it is not passed to the SDK client, which
                then applies its own default lookup (SDK behavior, not shown here).
            base_url: alternate endpoint; empty means the SDK default.
            organization: optional OpenAI organization id; omitted when empty.
            timeout: timeout passed to the SDK client constructor.
            max_retries: SDK-level retry count (2 matches the previous hard-coded value).
        """
        super().__init__(model=model, api_key=api_key, base_url=base_url)
        self._organization = organization
        self._timeout = timeout
        self._max_retries = max_retries

    @property
    def provider_name(self) -> str:
        """Identifier of this provider: ``"openai"``."""
        return "openai"

    def _client_kwargs(self) -> dict[str, Any]:
        """Return the constructor kwargs shared by the sync and async SDK clients.

        Empty ``api_key`` / ``base_url`` / ``organization`` values are left out so
        the SDK falls back to its own defaults for them.
        """
        kwargs: dict[str, Any] = {"timeout": self._timeout, "max_retries": self._max_retries}
        optional = {
            "api_key": self.api_key,
            "base_url": self.base_url,
            "organization": self._organization,
        }
        kwargs.update({key: value for key, value in optional.items() if value})
        return kwargs

    def _get_client(self) -> OpenAI:
        """Return the synchronous SDK client, creating it on first use."""
        if self._sync_client is None:
            self._sync_client = OpenAI(**self._client_kwargs())
        return self._sync_client

    def _get_async_client(self) -> AsyncOpenAI:
        """Return the asynchronous SDK client, creating it on first use."""
        if self._async_client is None:
            self._async_client = AsyncOpenAI(**self._client_kwargs())
        return self._async_client

    def _chat_params(
        self,
        messages: list[Message],
        *,
        temperature: float,
        max_tokens: int,
        top_p: float,
        stop: list[str] | None,
        tools: list[Tool] | None,
        tool_choice: str,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the ``chat.completions.create`` payload shared by chat() and achat().

        ``extra`` (the caller's ``**kwargs``) is applied after the standard fields,
        so it may override ``model``. ``tools`` / ``tool_choice`` are set last and
        only when tools were supplied. ``stop`` is omitted when None instead of
        being sent as an explicit null.

        NOTE(review): OpenAI documents that its o-series reasoning models (several
        appear in _PRICING) expect ``max_completion_tokens`` and reject
        non-default sampling parameters. This payload always sends ``max_tokens`` /
        ``temperature`` / ``top_p`` — confirm behavior for those models.
        """
        params: dict[str, Any] = {
            "model": self.model,
            "messages": _messages_to_openai(messages),
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }
        if stop is not None:
            params["stop"] = stop
        params.update(extra)
        if tools:
            params["tools"] = _tools_to_openai(tools)
            params["tool_choice"] = tool_choice
        return params

    def chat(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        top_p: float = 1.0,
        stop: list[str] | None = None,
        tools: list[Tool] | None = None,
        tool_choice: str = "auto",
        **kwargs: Any,
    ) -> CompletionResult:
        """Run a blocking chat completion and convert the first choice.

        Extra keyword arguments are forwarded to ``client.chat.completions.create``.
        """
        client = self._get_client()
        params = self._chat_params(
            messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop=stop,
            tools=tools,
            tool_choice=tool_choice,
            extra=kwargs,
        )
        resp = client.chat.completions.create(**params)
        return _build_result(resp, model=self.model)

    async def achat(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        top_p: float = 1.0,
        stop: list[str] | None = None,
        tools: list[Tool] | None = None,
        tool_choice: str = "auto",
        **kwargs: Any,
    ) -> CompletionResult:
        """Async counterpart of chat(), using the ``AsyncOpenAI`` client."""
        client = self._get_async_client()
        params = self._chat_params(
            messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop=stop,
            tools=tools,
            tool_choice=tool_choice,
            extra=kwargs,
        )
        resp = await client.chat.completions.create(**params)
        return _build_result(resp, model=self.model)

    def stream(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        tools: list[Tool] | None = None,
        **kwargs: Any,
    ) -> Iterator[StreamChunk]:
        """Stream a chat completion, yielding a StreamChunk per meaningful delta.

        A chunk is yielded when its first choice carries text content or a
        ``finish_reason``. The terminating chunk (empty delta plus finish_reason)
        is therefore yielded with ``content=""``, so consumers can see why the
        stream ended. Chunks without choices (such as a usage-only chunk when
        ``stream_options.include_usage`` is requested) are skipped.

        NOTE(review): tool-call deltas are not forwarded. Only ``content``,
        ``finish_reason`` and ``index`` are mapped onto StreamChunk here.
        """
        client = self._get_client()
        params: dict[str, Any] = {
            "model": self.model,
            "messages": _messages_to_openai(messages),
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
            **kwargs,
        }
        if tools:
            params["tools"] = _tools_to_openai(tools)
        for chunk in client.chat.completions.create(**params):
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            content = getattr(choice.delta, "content", None) or ""
            finish_reason = choice.finish_reason or None
            # Skip bookkeeping deltas that carry neither text nor a stop reason.
            if not content and finish_reason is None:
                continue
            yield StreamChunk(content=content, finish_reason=finish_reason, index=choice.index)