Coverage for providers/vertex.py: 0%
102 statements

import os
import json
import base64
from google import genai
from google.genai import types
from google.genai.types import HttpOptions
from qrclaw.providers.base import LLMProvider, LLMResponse, ToolCall
from qrclaw.config import OPENAI_MODEL, VERTEX_API_KEY
from qrclaw.logger import get_logger

logger = get_logger("qrclaw.providers.vertex")

# Storage key for a tool_call's thought_signature.
_TS_KEY = "__thought_signature__"
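# Why persist it: Gemini thinking models can attach an opaque
# thought_signature (raw bytes) to a function-call part, and those bytes
# must be echoed back unchanged when the call is replayed in a later turn;
# base64 keeps them JSON-safe inside the OpenAI-style tool_call dict.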


def _build_vertex_tools(schemas: list[dict] | None) -> list[types.Tool] | None:
    """Convert OpenAI-style tool schemas into the Vertex AI Tool format."""
    if not schemas:
        return None
    declarations = []
    for s in schemas:
        fn = s["function"]
        params = fn.get("parameters")
        declarations.append(types.FunctionDeclaration(
            name=fn["name"],
            description=fn.get("description", ""),
            parameters=params,
        ))
    return [types.Tool(function_declarations=declarations)]
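
# Shape of the conversion, using a hypothetical "get_weather" tool:
#   [{"type": "function",
#     "function": {"name": "get_weather",
#                  "description": "Look up current weather",
#                  "parameters": {"type": "object",
#                                 "properties": {"city": {"type": "string"}},
#                                 "required": ["city"]}}}]
# maps to one types.Tool holding a FunctionDeclaration with the same name,
# description, and JSON-schema parameters.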


class VertexProvider(LLMProvider):

    def __init__(self):
        os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"
        self._client = genai.Client(
            http_options=HttpOptions(api_version="v1"),
            api_key=VERTEX_API_KEY,
        )
        logger.info("Vertex AI channel initialized")
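
    # chat() adapts an OpenAI-style transcript (system/user/assistant/tool
    # dicts) to Vertex AI Content objects, makes one generate_content call,
    # and maps the reply back onto the provider-neutral LLMResponse.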
    def chat(self, messages: list[dict], tools: list[dict] | None = None) -> LLMResponse:
        contents = []
        system_parts = []

        i = 0
        while i < len(messages):
            msg = messages[i]
            role = msg.get("role")
            content = msg.get("content", "")

            if role == "system":
                if content:
                    system_parts.append(content)
                i += 1

            elif role == "user":
                contents.append(types.Content(
                    role="user",
                    parts=[types.Part(text=content or "")]
                ))
                i += 1

            elif role == "assistant":
                parts = []
                if content:
                    parts.append(types.Part(text=content))
                # Replay tool_calls, restoring each saved thought_signature.
                for tc in msg.get("tool_calls", []):
                    fn = tc["function"]
                    try:
                        args = json.loads(fn["arguments"])
                    except Exception:
                        args = {}
                    ts_b64 = tc.get(_TS_KEY)
                    ts_bytes = base64.b64decode(ts_b64) if ts_b64 else None

                    part = types.Part(
                        function_call=types.FunctionCall(
                            name=fn["name"],
                            args=args,
                        )
                    )
                    if ts_bytes:
                        part.thought_signature = ts_bytes
                    parts.append(part)
                if parts:
                    contents.append(types.Content(role="model", parts=parts))
                i += 1

            elif role == "tool":
                # Merge consecutive tool messages into a single Content:
                # Vertex AI requires the number of function_responses to
                # match the function_calls of the preceding model turn.
                tool_parts = []
                while i < len(messages) and messages[i].get("role") == "tool":
                    m = messages[i]
                    # tool_call_id carries the function name here, because
                    # ToolCall.id is set to fc.name when responses are parsed.
                    tool_call_id = m.get("tool_call_id", "")
                    c = m.get("content", "")
                    try:
                        result = json.loads(c) if c else {}
                        if not isinstance(result, dict):
                            result = {"result": c}
                    except Exception:
                        result = {"result": c}
                    tool_parts.append(types.Part(
                        function_response=types.FunctionResponse(
                            name=tool_call_id,
                            response=result,
                        )
                    ))
                    i += 1
                contents.append(types.Content(role="user", parts=tool_parts))

            else:
                i += 1
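
        # Unlike the OpenAI API, Vertex AI takes the system prompt through
        # GenerateContentConfig.system_instruction rather than as a message,
        # so the collected system parts are joined into one string below.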
        config_kwargs = {}
        if system_parts:
            config_kwargs["system_instruction"] = "\n".join(system_parts)
        vertex_tools = _build_vertex_tools(tools)
        if vertex_tools:
            config_kwargs["tools"] = vertex_tools

        config = types.GenerateContentConfig(**config_kwargs) if config_kwargs else None

        response = self._client.models.generate_content(
            model=OPENAI_MODEL,
            contents=contents,
            config=config,
        )
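
        # A single synchronous, non-streaming call; a streaming variant
        # would use self._client.models.generate_content_stream instead.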

        # Parse the response, stashing each thought_signature into its
        # ToolCall as base64.
        tool_calls = []
        text_content = ""
        finish_reason = "stop"

        candidate = response.candidates[0]
        for part in candidate.content.parts:
            if part.text:
                text_content += part.text
            elif part.function_call:
                fc = part.function_call
                ts = getattr(part, "thought_signature", None)
                ts_b64 = base64.b64encode(ts).decode() if ts else None
                tool_calls.append(ToolCall(
                    # The function name doubles as the call id, matching the
                    # FunctionResponse naming in the tool branch above.
                    id=fc.name,
                    name=fc.name,
                    arguments=json.dumps(dict(fc.args), ensure_ascii=False),
                    thought_signature=ts_b64,
                ))

        if tool_calls:
            finish_reason = "tool_calls"

        usage = response.usage_metadata
        prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
        completion_tokens = getattr(usage, "candidates_token_count", 0) or 0

        logger.info(f"Vertex AI response succeeded, tokens: {prompt_tokens + completion_tokens}")
        return LLMResponse(
            content=text_content,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            raw=response,
        )
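
# Minimal usage sketch (hypothetical transcript; assumes VERTEX_API_KEY is
# set and OPENAI_MODEL names a Gemini model reachable through Vertex AI):
#   provider = VertexProvider()
#   resp = provider.chat([
#       {"role": "system", "content": "Be terse."},
#       {"role": "user", "content": "What is 2 + 2?"},
#   ])
#   print(resp.content, resp.finish_reason, resp.total_tokens)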