Coverage for session_buddy/llm/providers/gemini_provider.py: 15.24% (79 statements)
1"""Google Gemini API provider implementation.
3This module provides the Gemini provider implementation using the Google
4Generative AI SDK for chat completions and streaming.
5"""
7from __future__ import annotations
9from datetime import datetime
10from typing import TYPE_CHECKING, Any
12from session_buddy.llm.base import LLMProvider
13from session_buddy.llm.models import LLMMessage, LLMResponse
15if TYPE_CHECKING:
16 from collections.abc import AsyncGenerator
19class GeminiProvider(LLMProvider):
20 """Google Gemini API provider."""
22 def __init__(self, config: dict[str, Any]) -> None:
23 super().__init__(config)
24 self.api_key = config.get("api_key")
25 self.default_model = config.get("default_model", "gemini-pro")
26 self._client = None
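        # Example config (sketch; values are illustrative placeholders):
        #   {"api_key": "<GEMINI_API_KEY>", "default_model": "gemini-1.5-pro"}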

    async def _get_client(self) -> Any:
        """Get or create Gemini client."""
        if self._client is None:
            try:
                import google.generativeai as genai

                genai.configure(api_key=self.api_key)
                self._client = genai
            except ImportError:
                msg = (
                    "Google Generative AI package not installed. "
                    "Install with: pip install google-generativeai"
                )
                raise ImportError(msg)
        return self._client

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
        """Convert LLMMessage objects to Gemini format using pattern matching."""
        converted: list[dict[str, Any]] = []

        for msg in messages:
            match msg.role:
                case "system":
                    # Gemini has no system role: merge into the preceding user
                    # message when one exists, otherwise emit a user-role entry.
                    if converted and converted[-1]["role"] == "user":
                        converted[-1]["parts"] = [
                            f"System: {msg.content}\n\nUser: {converted[-1]['parts'][0]}",
                        ]
                    else:
                        converted.append(
                            {"role": "user", "parts": [f"System: {msg.content}"]},
                        )
                case "user":
                    converted.append({"role": "user", "parts": [msg.content]})
                case "assistant":
                    converted.append({"role": "model", "parts": [msg.content]})
                case _:
                    # Unknown role - default to user for safety
                    converted.append({"role": "user", "parts": [msg.content]})

        return converted
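
    # Conversion example (sketch; the LLMMessage construction shown is an
    # assumption about session_buddy.llm.models):
    #   [LLMMessage(role="system", content="Be terse."),
    #    LLMMessage(role="user", content="Hi")]
    # becomes two user-role entries, since a leading system message has no
    # preceding user message to merge into:
    #   [{"role": "user", "parts": ["System: Be terse."]},
    #    {"role": "user", "parts": ["Hi"]}]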

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate response using Gemini API."""
        if not await self.is_available():
            msg = "Gemini provider not available"
            raise RuntimeError(msg)

        genai = await self._get_client()
        model_name = model or self.default_model

        try:
            model_instance = genai.GenerativeModel(model_name)

            # Convert messages to Gemini chat format
            chat_messages = self._convert_messages(messages)

            # Create chat or generate single response
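            # (multi-turn input replays all but the final message as chat
            # history and sends only the last turn; a single message uses the
            # one-shot generation call instead)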
            if len(chat_messages) > 1:
                chat = model_instance.start_chat(history=chat_messages[:-1])
                response = await chat.send_message_async(
                    chat_messages[-1]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                )
            else:
                response = await model_instance.generate_content_async(
                    chat_messages[0]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                )

            return LLMResponse(
                content=response.text,
                model=model_name,
                provider="gemini",
                usage={
                    "prompt_tokens": response.usage_metadata.prompt_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                    "completion_tokens": response.usage_metadata.candidates_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                    "total_tokens": response.usage_metadata.total_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                },
                finish_reason="stop",  # Gemini doesn't provide detailed finish reasons
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            self.logger.exception(f"Gemini generation failed: {e}")
            raise

    async def stream_generate(  # type: ignore[override]
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Stream response using Gemini API."""
        if not await self.is_available():
            msg = "Gemini provider not available"
            raise RuntimeError(msg)

        genai = await self._get_client()
        model_name = model or self.default_model

        try:
            model_instance = genai.GenerativeModel(model_name)
            chat_messages = self._convert_messages(messages)

            # Use the async SDK calls so streaming does not block the event loop.
            if len(chat_messages) > 1:
                chat = model_instance.start_chat(history=chat_messages[:-1])
                response = await chat.send_message_async(
                    chat_messages[-1]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                    stream=True,
                )
            else:
                response = await model_instance.generate_content_async(
                    chat_messages[0]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                    stream=True,
                )

            async for chunk in response:
                if chunk.text:
                    yield chunk.text

        except Exception as e:
            self.logger.exception(f"Gemini streaming failed: {e}")
            raise
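
    # Streaming usage sketch (hypothetical caller, not part of this module):
    #   async for chunk in provider.stream_generate(
    #       [LLMMessage(role="user", content="Hello")],
    #   ):
    #       print(chunk, end="", flush=True)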

    async def is_available(self) -> bool:
        """Check if Gemini API is available."""
        if not self.api_key:
            return False

        try:
            genai = await self._get_client()
            # Test with a simple model list request
            list(genai.list_models())
            return True
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available Gemini models."""
        return [
            "gemini-pro",
            "gemini-pro-vision",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-1.0-pro",
        ]
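

# Usage sketch (hypothetical caller; assumes an event loop is running and the
# api_key value is a placeholder):
#   provider = GeminiProvider({"api_key": "<GEMINI_API_KEY>"})
#   if await provider.is_available():
#       response = await provider.generate(
#           [LLMMessage(role="user", content="Hello")],
#           model="gemini-1.5-flash",
#       )
#       print(response.content)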