Coverage for agentos/models/backends/gemini.py: 25%

154 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v0.70 — Google Gemini Provider 全集成。 

3基因来源: Google AI Studio SDK + Vertex AI 

4支持: Gemini 2.5 Pro/Flash、Vision、System Instruction、Streaming、Token Counting、Safety Settings。 

5""" 

6 

7from __future__ import annotations 

8 

9import asyncio 

10import json 

11import os 

12from dataclasses import dataclass, field 

13from typing import Any, AsyncIterator 

14 

15import httpx 

16 

17from agentos.models.router import ModelResponse, ModelSpec 

18from agentos.core.context import AgentContext 

19from agentos.tools.base import ToolCall 

20 

21 

# ── Gemini Public API Endpoint ──────────────────
# Base URL of the Generative Language REST API; model methods are appended as
# "/models/<model_id>:<method>".
GEMINI_API_BASE = "https://generativelanguage.googleapis.com/v1beta"

# Prebuilt Gemini model specs, keyed by the short alias used for routing.
# model_id uses the stable model codes: the dated preview/experimental IDs
# previously used here ("gemini-2.5-pro-exp-03-25",
# "gemini-2.5-flash-preview-04-17") were deprecated and shut down by Google.
# Costs are USD per 1M tokens at standard pricing (Pro: prompts <= 200k tokens).
GEMINI_MODELS: dict[str, ModelSpec] = {
    "gemini-2.5-pro": ModelSpec(
        provider="gemini",
        model_id="gemini-2.5-pro",
        context_window=1_048_576,
        cost_per_1m_input=1.25,
        cost_per_1m_output=10.00,
    ),
    "gemini-2.5-flash": ModelSpec(
        provider="gemini",
        model_id="gemini-2.5-flash",
        context_window=1_048_576,
        cost_per_1m_input=0.30,
        cost_per_1m_output=2.50,
    ),
    "gemini-2.0-flash": ModelSpec(
        provider="gemini",
        model_id="gemini-2.0-flash",
        context_window=1_048_576,
        cost_per_1m_input=0.10,
        cost_per_1m_output=0.40,
    ),
}

49 

50 

@dataclass
class GeminiSafetySetting:
    """A single Gemini safety-filter rule: one harm category plus its threshold."""

    # Harm category, one of HARM_CATEGORY_HARASSMENT, HARM_CATEGORY_HATE_SPEECH,
    # HARM_CATEGORY_SEXUALLY_EXPLICIT, HARM_CATEGORY_DANGEROUS_CONTENT.
    category: str
    # Blocking threshold, one of BLOCK_NONE, BLOCK_ONLY_HIGH,
    # BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE.
    threshold: str = field(default="BLOCK_ONLY_HIGH")

57 

58 

@dataclass
class GeminiConfig:
    """Settings a GeminiClient applies when building its requests."""

    # API key; when empty, GeminiClient falls back to $GEMINI_API_KEY.
    # Excluded from repr() so printing or logging a config never leaks it.
    api_key: str = field(default="", repr=False)
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    max_output_tokens: int = 8192
    # Default policy: block only high-probability harm in all four categories.
    safety_settings: list[GeminiSafetySetting] = field(default_factory=lambda: [
        GeminiSafetySetting("HARM_CATEGORY_HARASSMENT", "BLOCK_ONLY_HIGH"),
        GeminiSafetySetting("HARM_CATEGORY_HATE_SPEECH", "BLOCK_ONLY_HIGH"),
        GeminiSafetySetting("HARM_CATEGORY_SEXUALLY_EXPLICIT", "BLOCK_ONLY_HIGH"),
        GeminiSafetySetting("HARM_CATEGORY_DANGEROUS_CONTENT", "BLOCK_ONLY_HIGH"),
    ])

74 

75 

76# ── Tool Declaration Helpers ───────────────────── 

77 

78def _convert_tools_to_gemini(openai_tools: list[dict]) -> list[dict]: 

79 """将OpenAI格式的tools转换为Gemini functionDeclarations。""" 

80 declarations = [] 

81 for tool in openai_tools: 

82 if tool.get("type") != "function": 

83 continue 

84 func = tool.get("function", {}) 

85 declarations.append({ 

86 "name": func.get("name", ""), 

87 "description": func.get("description", ""), 

88 "parameters": func.get("parameters", {}), 

89 }) 

90 return [{"function_declarations": declarations}] if declarations else [] 

91 

92 

def _convert_gemini_tool_calls(parts: list[dict]) -> list[ToolCall]:
    """Turn the ``functionCall`` parts of a Gemini reply into ToolCall objects.

    Parts without a (truthy) ``functionCall`` are ignored. ``args`` given as a
    JSON string is decoded, and falls back to ``{}`` if it is not valid JSON.
    The function name is used as the ToolCall id as well; the tool-result path
    sends ``tool_call_id`` back as the functionResponse name, so the two must
    stay in sync.
    """
    calls = []
    for fc in (part.get("functionCall") for part in parts):
        if not fc:
            continue
        raw_args = fc.get("args", {})
        if isinstance(raw_args, str):
            try:
                raw_args = json.loads(raw_args)
            except json.JSONDecodeError:
                raw_args = {}
        fn_name = fc.get("name", "unknown")
        calls.append(ToolCall(id=fn_name, name=fn_name, arguments=raw_args))
    return calls

112 

113 

114# ── Core Gemini Client ─────────────────────────── 

115 

class GeminiClient:
    """Async client for the Google Gemini REST API (``v1beta``).

    Supports non-streaming generation, SSE streaming, single-image vision
    prompts, token counting, and ``systemInstruction`` built from system
    messages.

    The API key is sent in the ``x-goog-api-key`` header, not in the ``?key=``
    query parameter. httpx includes the full request URL in ``HTTPStatusError``
    messages and request logs, so a key in the URL would leak.
    """

    def __init__(
        self,
        config: GeminiConfig | None = None,
        http_client: httpx.AsyncClient | None = None,
    ):
        """Create a client.

        Args:
            config: Generation and safety settings; ``GeminiConfig()`` if omitted.
            http_client: Optional shared ``httpx.AsyncClient``. If omitted, a
                private client with a 180 s timeout is created and is closed by
                :meth:`close`. A caller-supplied client is never closed here.
        """
        self.config = config or GeminiConfig()
        self._http = http_client or httpx.AsyncClient(timeout=180)
        self._owned_http = http_client is None

    @property
    def api_key(self) -> str:
        """API key from the config, falling back to the ``GEMINI_API_KEY`` env var."""
        return self.config.api_key or os.environ.get("GEMINI_API_KEY", "")

    async def close(self):
        """Close the HTTP client, but only if this instance created it."""
        if self._owned_http:
            await self._http.aclose()

    async def call(
        self,
        spec: ModelSpec,
        context: AgentContext,
    ) -> ModelResponse:
        """Send the conversation to ``generateContent`` and wait for the full reply.

        Args:
            spec: Model to call; ``spec.model_id`` selects the endpoint.
            context: Conversation messages and optional OpenAI-format tools.

        Returns:
            The parsed reply (joined text plus any function calls).

        Raises:
            httpx.HTTPStatusError: If the API answers with a non-2xx status.
        """
        contents, system_instruction = self._build_gemini_contents(context)
        body = self._build_request_body(contents, system_instruction, context.tools)
        resp = await self._http.post(
            self._endpoint(spec, "generateContent"),
            json=body,
            headers=self._auth_headers(),
        )
        resp.raise_for_status()
        return self._parse_response(resp.json())

    async def call_stream(
        self,
        spec: ModelSpec,
        context: AgentContext,
    ) -> AsyncIterator[dict]:
        """Stream a reply from ``streamGenerateContent`` (SSE), yielding raw JSON chunks.

        Builds the same request as :meth:`call`, including tool declarations.
        Lines that are not ``data:`` events, payloads that are not valid JSON,
        and chunks without ``candidates`` are all skipped. A response made up
        only of ``promptFeedback`` (e.g. a blocked prompt) therefore yields
        nothing.

        Raises:
            httpx.HTTPStatusError: If the API answers with a non-2xx status.
        """
        contents, system_instruction = self._build_gemini_contents(context)
        body = self._build_request_body(contents, system_instruction, context.tools)
        url = self._endpoint(spec, "streamGenerateContent") + "?alt=sse"
        async with self._http.stream(
            "POST", url, json=body, headers=self._auth_headers()
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                line = line.strip()
                if not line.startswith("data: "):
                    continue
                data_str = line[len("data: "):]
                if data_str == "[DONE]":  # tolerate an OpenAI-style end sentinel
                    break
                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                if "candidates" not in chunk:  # safety / promptFeedback-only event
                    continue
                yield chunk

    async def call_with_image(
        self,
        spec: ModelSpec,
        prompt: str,
        image_data: bytes,
        mime_type: str = "image/jpeg",
    ) -> ModelResponse:
        """Ask the model about one image in a single user turn.

        Args:
            spec: Model to call.
            prompt: Text placed before the image in the same turn.
            image_data: Raw image bytes; they are base64-encoded here.
            mime_type: MIME type of *image_data*.

        Returns:
            The parsed reply.

        Raises:
            httpx.HTTPStatusError: If the API answers with a non-2xx status.
        """
        import base64

        encoded = base64.b64encode(image_data).decode("ascii")
        contents = [{
            "role": "user",
            "parts": [
                {"text": prompt},
                {"inlineData": {"mimeType": mime_type, "data": encoded}},
            ],
        }]
        # Same generationConfig/safetySettings as call(), so the configured
        # top_p, top_k and safety rules also apply to vision requests.
        body = self._build_request_body(contents)
        resp = await self._http.post(
            self._endpoint(spec, "generateContent"),
            json=body,
            headers=self._auth_headers(),
        )
        resp.raise_for_status()
        return self._parse_response(resp.json())

    async def count_tokens(self, spec: ModelSpec, context: AgentContext) -> dict:
        """Count the conversation's tokens with the ``countTokens`` endpoint.

        Only the message ``contents`` are sent. System messages and tools are
        not included, so they are not reflected in the count.

        Returns:
            ``{"total_tokens", "prompt_tokens", "model"}``. ``prompt_tokens``
            mirrors ``total_tokens`` because only input is being counted.

        Raises:
            httpx.HTTPStatusError: If the API answers with a non-2xx status.
        """
        contents, _ = self._build_gemini_contents(context)
        resp = await self._http.post(
            self._endpoint(spec, "countTokens"),
            json={"contents": contents},
            headers=self._auth_headers(),
        )
        resp.raise_for_status()
        total = resp.json().get("totalTokens", 0)
        return {
            "total_tokens": total,
            "prompt_tokens": total,
            "model": spec.model_id,
        }

    # ── Internal helpers ──────────────────────────

    def _endpoint(self, spec: ModelSpec, method: str) -> str:
        """Return the REST URL of *method* (e.g. ``generateContent``) for ``spec``."""
        return f"{GEMINI_API_BASE}/models/{spec.model_id}:{method}"

    def _auth_headers(self) -> dict[str, str]:
        """Return headers that authenticate with the API key, keeping it out of the URL."""
        return {"x-goog-api-key": self.api_key}

    def _build_request_body(
        self,
        contents: list[dict],
        system_instruction: dict | None = None,
        tools: list[dict] | None = None,
    ) -> dict:
        """Assemble a generate request body from *contents* and ``self.config``.

        ``systemInstruction`` and ``tools`` are added only when non-empty.
        *tools* are in OpenAI format and are converted here.
        """
        cfg = self.config
        body: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "temperature": cfg.temperature,
                "topP": cfg.top_p,
                "topK": cfg.top_k,
                "maxOutputTokens": cfg.max_output_tokens,
            },
            "safetySettings": [
                {"category": s.category, "threshold": s.threshold}
                for s in cfg.safety_settings
            ],
        }
        if system_instruction:
            body["systemInstruction"] = system_instruction
        if tools:
            gemini_tools = _convert_tools_to_gemini(tools)
            if gemini_tools:
                body["tools"] = gemini_tools
        return body

    def _build_gemini_contents(self, context: AgentContext) -> tuple[list[dict], dict | None]:
        """Convert ``context.messages`` into Gemini ``contents`` plus a ``systemInstruction``.

        - System messages are collected, in order, into ``systemInstruction``
          rather than sent as turns.
        - Tool results (role ``"tool"`` with a ``tool_call_id``) become a
          single ``functionResponse`` part, not an extra text part. Consecutive
          tool results share one user turn, because Gemini expects all
          responses to a parallel function-call turn together.
        - Other messages contribute their text and any ``functionCall`` parts.

        Returns:
            ``(contents, system_instruction)``. ``system_instruction`` is
            ``None`` when there were no non-empty system messages. If no turns
            were produced, ``contents`` falls back to one user turn holding
            ``context.current_task``.
        """
        contents: list[dict] = []
        system_parts: list[dict] = []
        prev_was_tool_result = False

        for msg in context.messages:
            if msg.role == "system":
                if msg.content:
                    system_parts.append({"text": msg.content})
                continue

            is_tool_result = msg.role == "tool" and bool(msg.tool_call_id)
            parts: list[dict] = []
            if is_tool_result:
                # functionResponse.name must match the functionCall name; the
                # ToolCall id is the function name (see _convert_gemini_tool_calls).
                parts.append({
                    "functionResponse": {
                        "name": msg.tool_call_id,
                        "response": {"content": msg.content},
                    }
                })
            else:
                if msg.content:
                    parts.append({"text": msg.content})
                for tc in msg.tool_calls or []:
                    parts.append({
                        "functionCall": {
                            "name": tc.name,
                            "args": tc.arguments,
                        }
                    })

            if not parts:
                continue
            if is_tool_result and prev_was_tool_result:
                contents[-1]["parts"].extend(parts)
            else:
                contents.append({"role": self._map_role(msg.role), "parts": parts})
            prev_was_tool_result = is_tool_result

        # Gemini needs at least one turn.
        if not contents:
            contents = [{"role": "user", "parts": [{"text": context.current_task or ""}]}]

        system_instruction = {"parts": system_parts} if system_parts else None
        return contents, system_instruction

    def _map_role(self, role: str) -> str:
        """Map an AgentOS role to a Gemini role (``"user"`` or ``"model"``); unknown roles map to ``"user"``."""
        mapping = {
            "user": "user",
            "assistant": "model",
            "system": "user",  # handled separately via systemInstruction
            "tool": "user",    # functionResponse must be in a user turn
        }
        return mapping.get(role, "user")

    def _parse_response(self, data: dict) -> ModelResponse:
        """Convert a ``generateContent`` JSON response into a ModelResponse.

        Only the first candidate is used. Its text parts are joined with
        newlines, and its ``functionCall`` parts become ``tool_calls``. A
        response without candidates is treated as a safety block and yields
        ``"[SAFETY_BLOCKED] <blockReason>"`` as the content.
        """
        candidates = data.get("candidates") or []
        if not candidates:
            feedback = data.get("promptFeedback") or {}
            block_reason = feedback.get("blockReason", "unknown")
            return ModelResponse(content=f"[SAFETY_BLOCKED] {block_reason}")

        parts = (candidates[0].get("content") or {}).get("parts") or []
        text = "\n".join(part["text"] for part in parts if "text" in part)
        return ModelResponse(
            content=text,
            tool_calls=_convert_gemini_tool_calls(parts),
        )