Coverage for agentos/models/backends/openai.py: 29%
91 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""OpenAI backend for AgentOS.
3Supports OpenAI, Azure OpenAI, and any OpenAI-compatible API (DeepSeek, Groq, etc.).
4"""
6from dataclasses import dataclass, field
7from typing import Any, AsyncIterator, Dict, List, Optional
9import os
@dataclass
class OpenAIConfig:
    """Settings consumed by ``OpenAIClient``.

    ``api_key`` and ``base_url`` fall back to the ``OPENAI_API_KEY`` and
    ``OPENAI_BASE_URL`` environment variables, looked up each time an
    instance is created without explicit values. Every other field has a
    fixed default.
    """

    # Endpoint credentials, resolved from the environment per instance.
    api_key: str = field(default_factory=lambda: os.getenv("OPENAI_API_KEY", ""))
    base_url: str = field(
        default_factory=lambda: os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
    )

    # Model name and sampling parameters placed in each request body.
    model: str = "gpt-4o"
    temperature: float = 0.7
    max_tokens: int = 4096
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # Passed to httpx.Timeout by the client.
    timeout: int = 60
    # Number of attempts the client makes for a non-streaming request.
    max_retries: int = 3

    # Sent as the "OpenAI-Organization" header when non-empty.
    organization: str = ""
class OpenAIClient:
    """OpenAI-compatible LLM client.

    Works with:
    - OpenAI (GPT-4o, GPT-4, GPT-3.5)
    - Azure OpenAI
    - DeepSeek
    - Groq
    - Together AI
    - Any OpenAI-compatible endpoint

    Requests are POSTed to ``{config.base_url}/chat/completions`` with
    ``httpx``, which is imported lazily inside the request methods.
    """

    def __init__(self, config: "Optional[OpenAIConfig]" = None):
        self.config = config or OpenAIConfig()
        # Not used by the methods of this class; retained for backward
        # compatibility with code that may read or assign these attributes.
        self._client = None
        self._async_client = None

    @property
    def headers(self) -> Dict[str, str]:
        """HTTP headers for every request: bearer auth, JSON content type,
        and ``OpenAI-Organization`` when ``config.organization`` is set."""
        h = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }
        if self.config.organization:
            h["OpenAI-Organization"] = self.config.organization
        return h

    @property
    def _chat_url(self) -> str:
        """Chat-completions endpoint URL.

        A trailing slash on ``base_url`` is stripped so it never yields
        ``.../v1//chat/completions``.
        """
        return f"{self.config.base_url.rstrip('/')}/chat/completions"

    @staticmethod
    def _with_system(messages: List[Dict], system: Optional[str]) -> List[Dict]:
        """Return a new message list, prefixed with a system message when
        *system* is non-empty. The caller's list is never mutated."""
        msgs = list(messages)
        if system:
            msgs.insert(0, {"role": "system", "content": system})
        return msgs

    def _build_payload(self, messages: List[Dict], **kwargs) -> Dict[str, Any]:
        """Build the JSON body shared by streaming and non-streaming requests.

        ``model`` and the sampling parameters come from ``self.config`` unless
        a same-named keyword argument overrides them. ``tools`` (plus
        ``tool_choice``, default ``"auto"``) and ``response_format`` are only
        included when a truthy value is passed.
        """
        cfg = self.config
        payload: Dict[str, Any] = {
            "model": kwargs.get("model", cfg.model),
            "messages": messages,
            "temperature": kwargs.get("temperature", cfg.temperature),
            "max_tokens": kwargs.get("max_tokens", cfg.max_tokens),
            "top_p": kwargs.get("top_p", cfg.top_p),
            "frequency_penalty": kwargs.get("frequency_penalty", cfg.frequency_penalty),
            "presence_penalty": kwargs.get("presence_penalty", cfg.presence_penalty),
        }
        if kwargs.get("tools"):
            payload["tools"] = kwargs["tools"]
            payload["tool_choice"] = kwargs.get("tool_choice", "auto")
        if kwargs.get("response_format"):
            payload["response_format"] = kwargs["response_format"]
        return payload

    async def _async_request(self, messages: List[Dict], **kwargs) -> Dict:
        """POST a non-streaming chat completion and return the decoded JSON body.

        Makes up to ``max(1, config.max_retries)`` attempts. Rate-limit (429)
        and server-error (5xx) responses, as well as transport failures
        (connection errors, timeouts), are retried with exponential backoff
        of 1s, 2s, 4s, ... Any other error status is raised immediately.

        Raises:
            httpx.HTTPStatusError: non-retryable status, or retries exhausted.
            httpx.TransportError: network failure on the final attempt.
        """
        import asyncio

        import httpx

        payload = self._build_payload(messages, **kwargs)
        timeout = httpx.Timeout(self.config.timeout)
        # max_retries counts total attempts; clamp so at least one request is
        # always made and the method can never fall through returning None.
        attempts = max(1, self.config.max_retries)

        async with httpx.AsyncClient() as client:
            for attempt in range(attempts):
                last_attempt = attempt == attempts - 1
                try:
                    resp = await client.post(
                        self._chat_url,
                        json=payload,
                        headers=self.headers,
                        timeout=timeout,
                    )
                    resp.raise_for_status()
                    return resp.json()
                except httpx.HTTPStatusError as exc:
                    status = exc.response.status_code
                    retryable = status == 429 or status >= 500
                    if last_attempt or not retryable:
                        raise
                except httpx.TransportError:
                    if last_attempt:
                        raise
                await asyncio.sleep(2 ** attempt)
        # Unreachable: the final attempt always returns or re-raises.
        raise RuntimeError("chat completion retry loop exited without a result")

    @staticmethod
    def _parse_response(result: Dict[str, Any], default_model: str) -> Dict[str, Any]:
        """Normalize a chat-completion JSON body into this client's result dict.

        JSON ``null`` values are replaced by empty defaults so callers can
        rely on the types: ``content`` is ``""`` for tool-call-only replies and
        ``tool_calls`` is always a list.

        Raises:
            ValueError: if the response has no ``choices``.
        """
        choices = result.get("choices") or []
        if not choices:
            raise ValueError("chat completion response contained no choices")
        choice = choices[0]
        message = choice.get("message") or {}
        calls = []
        for tc in message.get("tool_calls") or []:
            fn = tc.get("function") or {}
            calls.append({
                "id": tc.get("id") or "",
                "name": fn.get("name") or "",
                "arguments": fn.get("arguments") or "{}",
            })
        return {
            "content": message.get("content") or "",
            "role": message.get("role") or "assistant",
            "usage": result.get("usage") or {},
            "model": result.get("model") or default_model,
            "finish_reason": choice.get("finish_reason") or "",
            "tool_calls": calls,
        }

    async def chat(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Send a chat completion request.

        Args:
            messages: List of message dicts with 'role' and 'content'. Not mutated.
            system: Optional system prompt, prepended as the first message.
            **kwargs: Override config parameters (``model``, ``temperature``,
                ``max_tokens``, ``top_p``, ``frequency_penalty``,
                ``presence_penalty``), or pass ``tools``/``tool_choice``/
                ``response_format``.

        Returns:
            Response dict with 'content', 'role', 'usage', 'model',
            'finish_reason' and 'tool_calls' (each with 'id', 'name',
            'arguments').

        Raises:
            ValueError: if the response has no choices.
            httpx.HTTPStatusError / httpx.TransportError: see ``_async_request``.
        """
        msgs = self._with_system(messages, system)
        result = await self._async_request(msgs, **kwargs)
        return self._parse_response(result, kwargs.get("model", self.config.model))

    @staticmethod
    def _sse_data(line: str) -> Optional[str]:
        """Return the stripped payload of an SSE ``data:`` line, else None.

        Accepts both ``data: x`` and ``data:x``; comments, blank lines and
        other SSE fields return None.
        """
        prefix = "data:"
        if not line.startswith(prefix):
            return None
        return line[len(prefix):].strip()

    @staticmethod
    def _parse_stream_chunk(data: str) -> Optional[Dict[str, Any]]:
        """Convert one streamed JSON chunk into an event dict, or None to skip it.

        Malformed JSON and chunks without any choices are skipped. Chunks
        without choices include usage-only chunks and, presumably,
        provider-specific metadata chunks.
        """
        import json

        try:
            chunk = json.loads(data)
        except json.JSONDecodeError:
            return None
        choices = chunk.get("choices") if isinstance(chunk, dict) else None
        if not choices:
            return None
        choice = choices[0]
        delta = choice.get("delta") or {}
        return {
            # Tool-call chunks carry content=null; normalize to "" so callers
            # can concatenate deltas without a None check.
            "delta": delta.get("content") or "",
            "finish_reason": choice.get("finish_reason"),
            "tool_call_delta": delta.get("tool_calls"),
        }

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream chat completion tokens.

        Accepts the same arguments as :meth:`chat`, and builds the same
        request body with ``"stream": True`` added.

        Yields dicts with:
        - 'delta': text content of the chunk ("" when the chunk has none).
        - 'finish_reason': the chunk's finish_reason, or None.
        - 'tool_call_delta': the chunk's raw ``tool_calls`` list, or None.

        Stops at the ``[DONE]`` sentinel or when the server closes the stream.

        Raises:
            httpx.HTTPStatusError: if the response status is an error.
        """
        import httpx

        payload = self._build_payload(self._with_system(messages, system), **kwargs)
        payload["stream"] = True

        # Streams can run long, so allow twice the configured timeout.
        timeout = httpx.Timeout(self.config.timeout * 2)
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                self._chat_url,
                json=payload,
                headers=self.headers,
                timeout=timeout,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    data = self._sse_data(line)
                    if data is None:
                        continue
                    if data == "[DONE]":
                        break
                    event = self._parse_stream_chunk(data)
                    if event is not None:
                        yield event

    def sync_chat(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Synchronous chat completion.

        Runs :meth:`chat` on a fresh event loop via ``asyncio.run``, which
        also shuts down async generators and closes the loop afterwards.
        Raises ``RuntimeError`` if called from a thread that already has a
        running event loop.
        """
        import asyncio

        return asyncio.run(self.chat(messages, system, **kwargs))