Coverage for agentos/models/backends/openai.py: 29%

91 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""OpenAI backend for AgentOS. 

2 

3Supports OpenAI, Azure OpenAI, and any OpenAI-compatible API (DeepSeek, Groq, etc.). 

4""" 

5 

6from dataclasses import dataclass, field 

7from typing import Any, AsyncIterator, Dict, List, Optional 

8 

9import os 

10 

11 

@dataclass
class OpenAIConfig:
    """Configuration for OpenAI backend.

    ``api_key`` and ``base_url`` are read from the ``OPENAI_API_KEY`` and
    ``OPENAI_BASE_URL`` environment variables at instantiation time when
    they are not passed explicitly.
    """

    # Secret bearer token. Excluded from the generated ``__repr__`` so that
    # printing or logging a config object does not leak the credential.
    api_key: str = field(
        default_factory=lambda: os.environ.get("OPENAI_API_KEY", ""),
        repr=False,
    )
    # Root of the API; "/chat/completions" is appended by the client.
    base_url: str = field(
        default_factory=lambda: os.environ.get(
            "OPENAI_BASE_URL", "https://api.openai.com/v1"
        )
    )
    # Model identifier sent in the "model" field of each request.
    model: str = "gpt-4o"
    # Default sampling parameters sent with each request.
    temperature: float = 0.7
    max_tokens: int = 4096
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    # Value passed to ``httpx.Timeout`` for a request.
    timeout: int = 60
    # Upper bound on the number of attempts made for one request.
    max_retries: int = 3
    # When non-empty, sent as the "OpenAI-Organization" header.
    organization: str = ""

26 

27 

class OpenAIClient:
    """OpenAI-compatible LLM client.

    Works with:
    - OpenAI (GPT-4o, GPT-4, GPT-3.5)
    - Azure OpenAI
    - DeepSeek
    - Groq
    - Together AI
    - Any OpenAI-compatible endpoint

    Requests are sent with ``httpx`` (imported lazily inside the methods that
    need it) to ``<base_url>/chat/completions``.
    """

    def __init__(self, config: Optional["OpenAIConfig"] = None):
        """Create a client.

        Args:
            config: Backend configuration. A default ``OpenAIConfig`` is
                built when omitted.
        """
        self.config = config or OpenAIConfig()
        self._client = None
        self._async_client = None

    @property
    def headers(self) -> Dict[str, str]:
        """HTTP headers sent with every request."""
        h = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }
        if self.config.organization:
            h["OpenAI-Organization"] = self.config.organization
        return h

    @property
    def _chat_url(self) -> str:
        """Full URL of the chat-completions endpoint."""
        # Strip trailing slashes so a base_url like ".../v1/" does not
        # produce ".../v1//chat/completions".
        return f"{self.config.base_url.rstrip('/')}/chat/completions"

    @staticmethod
    def _with_system(
        messages: List[Dict[str, str]], system: Optional[str]
    ) -> List[Dict[str, str]]:
        """Return a copy of *messages* with an optional system prompt first."""
        msgs = list(messages)
        if system:
            msgs.insert(0, {"role": "system", "content": system})
        return msgs

    def _build_payload(
        self,
        messages: List[Dict],
        overrides: Dict[str, Any],
        stream: bool = False,
    ) -> Dict[str, Any]:
        """Build the JSON request body shared by ``chat`` and ``chat_stream``.

        Args:
            messages: Messages to send, system prompt already included.
            overrides: Per-call keyword overrides of the config values.
            stream: Whether to request a streamed response.
        """
        cfg = self.config
        payload: Dict[str, Any] = {
            "model": cfg.model,
            "messages": messages,
            "temperature": overrides.get("temperature", cfg.temperature),
            "max_tokens": overrides.get("max_tokens", cfg.max_tokens),
            "top_p": overrides.get("top_p", cfg.top_p),
            "frequency_penalty": overrides.get(
                "frequency_penalty", cfg.frequency_penalty
            ),
            "presence_penalty": overrides.get(
                "presence_penalty", cfg.presence_penalty
            ),
        }
        if stream:
            payload["stream"] = True

        if overrides.get("tools"):
            payload["tools"] = overrides["tools"]
            payload["tool_choice"] = overrides.get("tool_choice", "auto")

        if overrides.get("response_format"):
            payload["response_format"] = overrides["response_format"]

        return payload

    async def _async_request(self, messages: List[Dict], **kwargs) -> Dict:
        """POST a non-streaming chat completion and return the decoded JSON.

        Server errors (HTTP status >= 500) and transport-level failures are
        retried with exponential backoff (1s, 2s, 4s, ...). Other HTTP errors
        are raised immediately.

        Args:
            messages: Messages to send, system prompt already included.
            **kwargs: Per-call overrides (see ``_build_payload``).

        Raises:
            httpx.HTTPStatusError: On a non-retryable status, or when the
                last attempt still fails with a server error.
            httpx.TransportError: When the last attempt fails at the
                transport level.
        """
        import asyncio

        import httpx

        timeout = httpx.Timeout(self.config.timeout)
        payload = self._build_payload(messages, kwargs)
        # Always make at least one attempt; otherwise max_retries <= 0 would
        # skip the loop entirely and silently return None.
        attempts = max(1, self.config.max_retries)

        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                async with httpx.AsyncClient() as client:
                    resp = await client.post(
                        self._chat_url,
                        json=payload,
                        headers=self.headers,
                        timeout=timeout,
                    )
                    resp.raise_for_status()
                    return resp.json()
            except httpx.HTTPStatusError as e:
                # Only server-side failures are worth retrying.
                if is_last or e.response.status_code < 500:
                    raise
            except httpx.TransportError:
                # Connection problems and timeouts are usually transient.
                if is_last:
                    raise
            await asyncio.sleep(2 ** attempt)

        # The final loop iteration always returns or raises.
        raise AssertionError("unreachable")

    async def chat(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Send a chat completion request.

        Args:
            messages: List of message dicts with 'role' and 'content'.
                The list is not modified.
            system: Optional system prompt, prepended as the first message.
            **kwargs: Override config parameters (temperature, max_tokens,
                top_p, frequency_penalty, presence_penalty) or pass
                'tools', 'tool_choice', 'response_format'.

        Returns:
            Response dict with 'content', 'role', 'usage', 'model',
            'finish_reason' and 'tool_calls' (a list of dicts with 'id',
            'name' and 'arguments').
        """
        msgs = self._with_system(messages, system)

        result = await self._async_request(msgs, **kwargs)
        choice = result["choices"][0]
        # Use `or` rather than a .get() default: a key that is present with
        # an explicit JSON null (e.g. "content": null alongside tool calls)
        # would otherwise come back as None.
        message = choice.get("message") or {}
        tool_calls = message.get("tool_calls") or []

        return {
            "content": message.get("content") or "",
            "role": message.get("role", "assistant"),
            "usage": result.get("usage", {}),
            "model": result.get("model", self.config.model),
            "finish_reason": choice.get("finish_reason", ""),
            "tool_calls": [
                {
                    "id": tc.get("id", ""),
                    "name": (tc.get("function") or {}).get("name", ""),
                    "arguments": (tc.get("function") or {}).get(
                        "arguments", "{}"
                    ),
                }
                for tc in tool_calls
            ],
        }

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream chat completion tokens.

        Accepts the same arguments as ``chat``. The request timeout is twice
        the configured one. Lines that are not ``data:`` events, and chunks
        that are not valid JSON or carry no choice, are skipped.

        Yields dicts with 'delta', 'finish_reason', 'tool_call_delta'.
        """
        import json

        import httpx

        msgs = self._with_system(messages, system)
        payload = self._build_payload(msgs, kwargs, stream=True)

        timeout = httpx.Timeout(self.config.timeout * 2)
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                self._chat_url,
                json=payload,
                headers=self.headers,
                timeout=timeout,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # Accept both "data: {...}" and "data:{...}".
                    if not line.startswith("data:"):
                        continue
                    data = line[5:].strip()
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        choice = chunk["choices"][0]
                    except (
                        json.JSONDecodeError,
                        KeyError,
                        IndexError,
                        TypeError,
                    ):
                        # Malformed chunk or one with an empty/missing
                        # "choices" list: nothing to yield.
                        continue
                    delta = choice.get("delta") or {}
                    yield {
                        "delta": delta.get("content") or "",
                        "finish_reason": choice.get("finish_reason"),
                        "tool_call_delta": delta.get("tool_calls"),
                    }

    def sync_chat(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Synchronous chat completion.

        Runs ``chat`` to completion on a private, short-lived event loop and
        returns its result.

        Raises:
            RuntimeError: If called from a thread that is already running an
                event loop; ``await chat(...)`` there instead.
        """
        import asyncio

        # Check up front so that no coroutine object is created and then
        # left un-awaited when a loop is already running in this thread.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            raise RuntimeError(
                "sync_chat() cannot be called from a running event loop; "
                "use 'await chat(...)' instead"
            )

        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(self.chat(messages, system, **kwargs))
        finally:
            loop.close()