Coverage for providers / vertex.py: 0%

102 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-29 02:55 +0800

1import os 

2import json 

3import base64 

4from google import genai 

5from google.genai import types 

6from google.genai.types import HttpOptions 

7from qrclaw.providers.base import LLMProvider, LLMResponse, ToolCall 

8from qrclaw.config import OPENAI_MODEL, VERTEX_API_KEY 

9from qrclaw.logger import get_logger 

10 

11logger = get_logger("qrclaw.providers.vertex") 

12 

13# Key under which a tool_call dict stores its base64-encoded thought_signature 

14_TS_KEY = "__thought_signature__" 

15 

16 

def _build_vertex_tools(schemas: list[dict]) -> list[types.Tool] | None:
    """Convert OpenAI-style tool schemas into the Vertex AI ``Tool`` format.

    Each schema is expected to carry a ``"function"`` mapping with a required
    ``"name"`` and optional ``"description"`` / ``"parameters"`` entries.

    Args:
        schemas: OpenAI tool schema dicts; may be empty.

    Returns:
        ``None`` when ``schemas`` is falsy, otherwise a one-element list whose
        single ``types.Tool`` holds one function declaration per schema.

    Raises:
        KeyError: If a schema lacks ``"function"`` or the function lacks ``"name"``.
    """
    if not schemas:
        return None

    def declare(schema):
        # Build the declaration for a single schema entry.
        spec = schema["function"]
        parameters = spec.get("parameters")
        return types.FunctionDeclaration(
            name=spec["name"],
            description=spec.get("description", ""),
            parameters=parameters,
        )

    # All declarations are bundled into one Tool object.
    bundled = types.Tool(
        function_declarations=[declare(schema) for schema in schemas]
    )
    return [bundled]

31 

32 

class VertexProvider(LLMProvider):
    """Chat provider backed by the google-genai client in Vertex AI mode.

    OpenAI-style message dicts are converted into ``types.Content`` objects,
    sent through ``generate_content``, and the first response candidate is
    converted back into an ``LLMResponse``.
    """

    def __init__(self):
        """Create the genai client authenticated with ``VERTEX_API_KEY``."""
        # NOTE: process-wide side effect. The google-genai SDK reads this
        # variable to select the Vertex AI backend.
        os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"
        self._client = genai.Client(
            http_options=HttpOptions(api_version="v1"),
            api_key=VERTEX_API_KEY,
        )
        logger.info("Vertex AI 渠道已初始化")

    def chat(self, messages: list[dict], tools: list[dict] | None = None) -> LLMResponse:
        """Send a conversation to the model and return its reply.

        Args:
            messages: OpenAI-style message dicts. ``system`` messages are
                joined with newlines into the system instruction; ``user``,
                ``assistant`` and ``tool`` messages become conversation
                contents; messages with any other role are ignored.
            tools: Optional OpenAI-style tool schemas offered to the model.

        Returns:
            An ``LLMResponse`` with the concatenated text, any tool calls
            (each carrying its base64-encoded thought signature when the
            model supplied one), token counts, and the raw SDK response.
            ``finish_reason`` is ``"tool_calls"`` when at least one tool call
            was returned and ``"stop"`` otherwise. A response without any
            candidate or parts yields empty text and no tool calls.
        """
        system_parts, contents = self._to_vertex_contents(messages)

        config_kwargs = {}
        if system_parts:
            config_kwargs["system_instruction"] = "\n".join(system_parts)
        vertex_tools = _build_vertex_tools(tools)
        if vertex_tools:
            config_kwargs["tools"] = vertex_tools
        config = types.GenerateContentConfig(**config_kwargs) if config_kwargs else None

        # The model name comes from the shared OPENAI_MODEL config value.
        response = self._client.models.generate_content(
            model=OPENAI_MODEL,
            contents=contents,
            config=config,
        )

        text_content, tool_calls = self._parse_response(response)
        finish_reason = "tool_calls" if tool_calls else "stop"

        # usage_metadata may be None; getattr's default covers that case.
        usage = response.usage_metadata
        prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
        completion_tokens = getattr(usage, "candidates_token_count", 0) or 0

        logger.info(f"Vertex AI 响应成功,tokens: {prompt_tokens + completion_tokens}")
        return LLMResponse(
            content=text_content,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            raw=response,
        )

    @staticmethod
    def _to_vertex_contents(messages: list[dict]) -> tuple[list, list]:
        """Split messages into system instruction strings and Vertex contents."""
        system_parts = []
        contents = []
        # tool_call id -> function name, filled from assistant tool calls so
        # that later tool results can be labelled with the function name.
        call_names = {}

        total = len(messages)
        i = 0
        while i < total:
            msg = messages[i]
            role = msg.get("role")

            if role == "tool":
                # Merge consecutive tool messages into one Content: Vertex AI
                # requires the number of function_response parts to match the
                # number of function_call parts of the preceding model turn.
                tool_parts = []
                while i < total and messages[i].get("role") == "tool":
                    tool_parts.append(
                        VertexProvider._tool_response_part(messages[i], call_names)
                    )
                    i += 1
                contents.append(types.Content(role="user", parts=tool_parts))
                continue

            content = msg.get("content", "")
            if role == "system":
                if content:
                    system_parts.append(content)
            elif role == "user":
                contents.append(types.Content(
                    role="user",
                    parts=[types.Part(text=content or "")],
                ))
            elif role == "assistant":
                parts = VertexProvider._assistant_parts(msg, call_names)
                # An assistant message with neither text nor tool calls is dropped.
                if parts:
                    contents.append(types.Content(role="model", parts=parts))
            # Messages with any other role are skipped.
            i += 1

        return system_parts, contents

    @staticmethod
    def _assistant_parts(msg: dict, call_names: dict) -> list:
        """Build the text and function_call parts of one assistant message."""
        parts = []
        content = msg.get("content", "")
        if content:
            parts.append(types.Part(text=content))

        # "tool_calls" may be present with a None value, so normalise to a list.
        for tc in msg.get("tool_calls") or []:
            fn = tc["function"]
            try:
                raw_args = fn["arguments"]
                args = raw_args if isinstance(raw_args, dict) else json.loads(raw_args)
            except (KeyError, TypeError, ValueError):
                # Best effort: missing or unparsable arguments become empty args.
                args = {}
            if not isinstance(args, dict):
                # Function call args must be a mapping (e.g. "null" or "[]" are not).
                args = {}

            call_id = tc.get("id")
            if call_id:
                call_names[call_id] = fn["name"]

            part = types.Part(
                function_call=types.FunctionCall(name=fn["name"], args=args)
            )
            # Restore the thought signature saved from the earlier model response.
            ts_b64 = tc.get(_TS_KEY)
            ts_bytes = base64.b64decode(ts_b64) if ts_b64 else None
            if ts_bytes:
                part.thought_signature = ts_bytes
            parts.append(part)

        return parts

    @staticmethod
    def _tool_response_part(msg: dict, call_names: dict):
        """Build the function_response part for one tool message."""
        tool_call_id = msg.get("tool_call_id", "")
        raw = msg.get("content", "")

        if not raw:
            result = {}
        else:
            try:
                result = json.loads(raw)
            except (TypeError, ValueError):
                result = None
            # The response payload must be a dict; wrap anything else verbatim.
            if not isinstance(result, dict):
                result = {"result": raw}

        # Prefer the function name recorded for this call id; fall back to the
        # id itself (this provider issues ids that equal the function name).
        name = call_names.get(tool_call_id, tool_call_id)
        return types.Part(
            function_response=types.FunctionResponse(name=name, response=result)
        )

    @staticmethod
    def _parse_response(response) -> tuple[str, list]:
        """Extract text and tool calls from the first candidate of a response."""
        candidates = response.candidates or []
        if not candidates:
            # candidates can be missing or empty; return an empty reply
            # instead of failing on the index access.
            logger.warning("Vertex AI 响应中没有候选内容")
            return "", []

        content = candidates[0].content
        parts = (content.parts if content is not None else None) or []

        texts = []
        tool_calls = []
        for part in parts:
            if part.text:
                texts.append(part.text)
            elif part.function_call:
                fc = part.function_call
                # Keep the thought signature as base64 text on the ToolCall.
                ts = getattr(part, "thought_signature", None)
                ts_b64 = base64.b64encode(ts).decode() if ts else None
                tool_calls.append(ToolCall(
                    id=fc.name,
                    name=fc.name,
                    arguments=json.dumps(dict(fc.args or {}), ensure_ascii=False),
                    thought_signature=ts_b64,
                ))

        return "".join(texts), tool_calls