Coverage for agentos/models/backends/gemini.py: 25%
154 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2AgentOS v0.70 — Google Gemini Provider 全集成。
3基因来源: Google AI Studio SDK + Vertex AI
4支持: Gemini 2.5 Pro/Flash、Vision、System Instruction、Streaming、Token Counting、Safety Settings。
5"""
7from __future__ import annotations
9import asyncio
10import json
11import os
12from dataclasses import dataclass, field
13from typing import Any, AsyncIterator
15import httpx
17from agentos.models.router import ModelResponse, ModelSpec
18from agentos.core.context import AgentContext
19from agentos.tools.base import ToolCall
# ── Gemini Public API Endpoint ──────────────────
# Root of the public Gemini REST surface; every request URL in this module is
# built as f"{GEMINI_API_BASE}/models/<model_id>:<method>".
GEMINI_API_BASE: str = "https://generativelanguage.googleapis.com/v1beta"
# Prebuilt Gemini model specs, keyed by the short alias callers use.
#
# The 2.5 aliases point at the stable (GA) model ids. The experimental/preview
# ids used previously ("gemini-2.5-pro-exp-03-25",
# "gemini-2.5-flash-preview-04-17") were time-limited releases that Google has
# since shut down. Costs are USD per 1M tokens on the paid tier (2.5 Pro: the
# <=200k-token prompt tier) — re-check against the official pricing page when
# updating.
GEMINI_MODELS: dict[str, ModelSpec] = {
    "gemini-2.5-pro": ModelSpec(
        provider="gemini",
        model_id="gemini-2.5-pro",
        context_window=1_048_576,
        cost_per_1m_input=1.25,
        cost_per_1m_output=10.00,
    ),
    "gemini-2.5-flash": ModelSpec(
        provider="gemini",
        model_id="gemini-2.5-flash",
        context_window=1_048_576,
        cost_per_1m_input=0.30,
        cost_per_1m_output=2.50,
    ),
    "gemini-2.0-flash": ModelSpec(
        provider="gemini",
        model_id="gemini-2.0-flash",
        context_window=1_048_576,
        cost_per_1m_input=0.10,
        cost_per_1m_output=0.40,
    ),
}
@dataclass
class GeminiSafetySetting:
    """One Gemini safety filter: a harm category plus its blocking threshold.

    Instances are serialized as ``{"category": ..., "threshold": ...}`` entries
    of a request's ``safetySettings`` list.
    """

    # e.g. HARM_CATEGORY_HARASSMENT | HARM_CATEGORY_HATE_SPEECH |
    #      HARM_CATEGORY_SEXUALLY_EXPLICIT | HARM_CATEGORY_DANGEROUS_CONTENT
    category: str
    # BLOCK_NONE | BLOCK_ONLY_HIGH | BLOCK_MEDIUM_AND_ABOVE | BLOCK_LOW_AND_ABOVE
    threshold: str = field(default="BLOCK_ONLY_HIGH")
@dataclass
class GeminiConfig:
    """Generation and safety settings applied to Gemini requests.

    An empty ``api_key`` is allowed: the client then reads the
    ``GEMINI_API_KEY`` environment variable instead.
    """

    api_key: str = ""
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    max_output_tokens: int = 8192
    # Each instance gets its own fresh list: all four harm categories filtered
    # at BLOCK_ONLY_HIGH.
    safety_settings: list[GeminiSafetySetting] = field(
        default_factory=lambda: [
            GeminiSafetySetting(category, "BLOCK_ONLY_HIGH")
            for category in (
                "HARM_CATEGORY_HARASSMENT",
                "HARM_CATEGORY_HATE_SPEECH",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "HARM_CATEGORY_DANGEROUS_CONTENT",
            )
        ]
    )
76# ── Tool Declaration Helpers ─────────────────────
78def _convert_tools_to_gemini(openai_tools: list[dict]) -> list[dict]:
79 """将OpenAI格式的tools转换为Gemini functionDeclarations。"""
80 declarations = []
81 for tool in openai_tools:
82 if tool.get("type") != "function":
83 continue
84 func = tool.get("function", {})
85 declarations.append({
86 "name": func.get("name", ""),
87 "description": func.get("description", ""),
88 "parameters": func.get("parameters", {}),
89 })
90 return [{"function_declarations": declarations}] if declarations else []
93def _convert_gemini_tool_calls(parts: list[dict]) -> list[ToolCall]:
94 """将Gemini functionCall parts转为ToolCall列表。"""
95 tool_calls = []
96 for part in parts:
97 fc = part.get("functionCall")
98 if not fc:
99 continue
100 args = fc.get("args", {})
101 if isinstance(args, str):
102 try:
103 args = json.loads(args)
104 except json.JSONDecodeError:
105 args = {}
106 tool_calls.append(ToolCall(
107 id=fc.get("name", "unknown"),
108 name=fc.get("name", "unknown"),
109 arguments=args,
110 ))
111 return tool_calls
# ── Core Gemini Client ───────────────────────────

class GeminiClient:
    """Async client for the Google Gemini REST API.

    Supports plain and tool-calling generation, SSE streaming, single-image
    vision prompts, system instructions and token counting.

    The API key is sent in the ``x-goog-api-key`` header rather than a
    ``?key=`` query parameter, so it never appears in logged request URLs or in
    the ``httpx.HTTPStatusError`` message raised by ``raise_for_status()``
    (which quotes the full URL).
    """

    # Candidate finish reasons meaning the output was withheld by a filter.
    _BLOCKED_FINISH_REASONS = frozenset(
        {"SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII"}
    )

    def __init__(
        self,
        config: GeminiConfig | None = None,
        http_client: httpx.AsyncClient | None = None,
    ):
        """Create a client.

        Args:
            config: Generation/safety settings; defaults to ``GeminiConfig()``.
            http_client: Optional shared ``httpx.AsyncClient``. When omitted the
                client creates its own (180 s timeout) and :meth:`close` will
                close it.
        """
        self.config = config or GeminiConfig()
        self._http = http_client or httpx.AsyncClient(timeout=180)
        # Only close the HTTP client in close() if we created it here.
        self._owned_http = http_client is None

    @property
    def api_key(self) -> str:
        """API key from the config, falling back to ``$GEMINI_API_KEY``."""
        return self.config.api_key or os.environ.get("GEMINI_API_KEY", "")

    async def close(self):
        """Close the underlying HTTP client if this instance owns it."""
        if self._owned_http:
            await self._http.aclose()

    async def call(
        self,
        spec: ModelSpec,
        context: AgentContext,
    ) -> ModelResponse:
        """Run one non-streaming ``generateContent`` request.

        Raises:
            httpx.HTTPStatusError: if the API answers with a non-2xx status.
        """
        contents, system_instruction = self._build_gemini_contents(context)
        body = self._build_request_body(contents, system_instruction, context.tools)
        url = f"{GEMINI_API_BASE}/models/{spec.model_id}:generateContent"
        resp = await self._http.post(url, json=body, headers=self._auth_headers())
        resp.raise_for_status()
        return self._parse_response(resp.json())

    async def call_stream(
        self,
        spec: ModelSpec,
        context: AgentContext,
    ) -> AsyncIterator[dict]:
        """Stream ``streamGenerateContent`` over SSE, yielding decoded chunks.

        Sends the same request body as :meth:`call`, including tool
        declarations. Non-``data:`` lines, lines that are not valid JSON and
        chunks without ``candidates`` (e.g. prompt-feedback-only events) are
        skipped.

        Raises:
            httpx.HTTPStatusError: if the API answers with a non-2xx status.
        """
        contents, system_instruction = self._build_gemini_contents(context)
        body = self._build_request_body(contents, system_instruction, context.tools)
        url = f"{GEMINI_API_BASE}/models/{spec.model_id}:streamGenerateContent?alt=sse"
        async with self._http.stream(
            "POST", url, json=body, headers=self._auth_headers()
        ) as resp:
            resp.raise_for_status()
            async for raw_line in resp.aiter_lines():
                line = raw_line.strip()
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                # Stop on an OpenAI-style sentinel should one ever be sent.
                if data_str == "[DONE]":
                    break
                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                if "candidates" not in chunk:
                    continue
                yield chunk

    async def call_with_image(
        self,
        spec: ModelSpec,
        prompt: str,
        image_data: bytes,
        mime_type: str = "image/jpeg",
    ) -> ModelResponse:
        """Send a single-turn text + image prompt.

        Args:
            spec: Target model.
            prompt: Text part of the prompt.
            image_data: Raw image bytes (base64-encoded here, not by the caller).
            mime_type: MIME type of ``image_data``.

        Raises:
            httpx.HTTPStatusError: if the API answers with a non-2xx status.
        """
        import base64

        b64 = base64.b64encode(image_data).decode()
        contents = [{
            "role": "user",
            "parts": [
                {"text": prompt},
                {"inlineData": {"mimeType": mime_type, "data": b64}},
            ],
        }]
        # Same generation and safety settings as every other request.
        body = self._build_request_body(contents)
        url = f"{GEMINI_API_BASE}/models/{spec.model_id}:generateContent"
        resp = await self._http.post(url, json=body, headers=self._auth_headers())
        resp.raise_for_status()
        return self._parse_response(resp.json())

    async def count_tokens(self, spec: ModelSpec, context: AgentContext) -> dict:
        """Count the prompt tokens of ``context`` with the ``countTokens`` API.

        Only the conversation ``contents`` are counted; the system instruction
        is not sent. ``prompt_tokens`` mirrors ``total_tokens`` because the
        endpoint reports a single total.

        Raises:
            httpx.HTTPStatusError: if the API answers with a non-2xx status.
        """
        contents, _ = self._build_gemini_contents(context)
        url = f"{GEMINI_API_BASE}/models/{spec.model_id}:countTokens"
        resp = await self._http.post(
            url, json={"contents": contents}, headers=self._auth_headers()
        )
        resp.raise_for_status()
        data = resp.json()
        return {
            "total_tokens": data.get("totalTokens", 0),
            "prompt_tokens": data.get("totalTokens", 0),
            "model": spec.model_id,
        }

    # ── Internal helpers ──────────────────────────

    def _auth_headers(self) -> dict[str, str]:
        """HTTP headers that authenticate a request with :attr:`api_key`."""
        return {"x-goog-api-key": self.api_key}

    def _build_request_body(
        self,
        contents: list[dict],
        system_instruction: dict | None = None,
        tools: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Assemble a ``generateContent``-style body from ``self.config``.

        Args:
            contents: Gemini ``contents`` turns.
            system_instruction: Optional ``systemInstruction`` object.
            tools: Optional OpenAI-style tool specs. They are converted with
                ``_convert_tools_to_gemini`` and sent only if any remain.
        """
        cfg = self.config
        body: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "temperature": cfg.temperature,
                "topP": cfg.top_p,
                "topK": cfg.top_k,
                "maxOutputTokens": cfg.max_output_tokens,
            },
            "safetySettings": [
                {"category": s.category, "threshold": s.threshold}
                for s in cfg.safety_settings
            ],
        }
        if system_instruction:
            body["systemInstruction"] = system_instruction
        if tools:
            gemini_tools = _convert_tools_to_gemini(tools)
            if gemini_tools:
                body["tools"] = gemini_tools
        return body

    def _build_gemini_contents(self, context: AgentContext) -> tuple[list[dict], dict | None]:
        """Convert the agent's message history into Gemini ``contents``.

        * ``system`` messages are collected, in order, into a separate
          ``systemInstruction`` instead of becoming turns.
        * Assistant ``tool_calls`` become ``functionCall`` parts.
        * ``tool`` messages with a ``tool_call_id`` become a single
          ``functionResponse`` part whose name is that id.
        * Consecutive messages mapping to the same Gemini role are merged into
          one turn. This lets parallel function responses answer their
          function-call turn together.

        If nothing remains, a single user turn holding ``context.current_task``
        (or ``""``) is returned so the request is never empty.

        Returns:
            ``(contents, system_instruction)``; ``system_instruction`` is
            ``{"parts": [...]}`` or ``None``.
        """
        contents: list[dict] = []
        system_parts: list[dict] = []

        for msg in context.messages:
            if msg.role == "system":
                if msg.content:
                    system_parts.append({"text": msg.content})
                continue

            parts: list[dict] = []
            if msg.role == "tool" and msg.tool_call_id:
                # Tool output goes only into the functionResponse; also adding
                # it as a text part would send the result twice.
                parts.append({
                    "functionResponse": {
                        "name": msg.tool_call_id,
                        "response": {"content": msg.content},
                    }
                })
            else:
                if msg.content:
                    parts.append({"text": msg.content})
                for tc in msg.tool_calls or []:
                    parts.append({
                        "functionCall": {
                            "name": tc.name,
                            "args": tc.arguments,
                        }
                    })

            if not parts:
                continue
            role = self._map_role(msg.role)
            if contents and contents[-1]["role"] == role:
                contents[-1]["parts"].extend(parts)
            else:
                contents.append({"role": role, "parts": parts})

        if not contents:
            contents = [{"role": "user", "parts": [{"text": context.current_task or ""}]}]

        system_instruction = {"parts": system_parts} if system_parts else None
        return contents, system_instruction

    def _map_role(self, role: str) -> str:
        """Map an agent message role to a Gemini turn role (unknown → ``user``)."""
        mapping = {
            "user": "user",
            "assistant": "model",
            "system": "user",  # normally lifted into systemInstruction instead
            "tool": "user",  # functionResponse parts belong to a user turn
        }
        return mapping.get(role, "user")

    def _parse_response(self, data: dict) -> ModelResponse:
        """Convert a ``generateContent`` response into a ModelResponse.

        Only the first candidate is used; its text parts are joined with
        newlines and its ``functionCall`` parts become tool calls.

        A blocked prompt (no candidates) or a candidate withheld by a filter
        with no output yields ``"[SAFETY_BLOCKED] <reason>"`` as the content.
        """
        candidates = data.get("candidates") or []
        if not candidates:
            # The prompt itself was blocked; the reason is in promptFeedback.
            block_reason = (data.get("promptFeedback") or {}).get("blockReason", "unknown")
            return ModelResponse(content=f"[SAFETY_BLOCKED] {block_reason}")

        candidate = candidates[0]
        parts = (candidate.get("content") or {}).get("parts") or []
        text = "\n".join(part["text"] for part in parts if "text" in part)
        tool_calls = _convert_gemini_tool_calls(parts)

        finish_reason = candidate.get("finishReason", "")
        if not text and not tool_calls and finish_reason in self._BLOCKED_FINISH_REASONS:
            # Report a filtered candidate like a blocked prompt rather than
            # returning an empty answer.
            return ModelResponse(content=f"[SAFETY_BLOCKED] {finish_reason}")

        return ModelResponse(
            content=text,
            tool_calls=tool_calls,
        )