Coverage for session_buddy/llm/providers/gemini_provider.py: 15.24%

79 statements

coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1"""Google Gemini API provider implementation. 

2 

3This module provides the Gemini provider implementation using the Google 

4Generative AI SDK for chat completions and streaming. 

5""" 

6 

7from __future__ import annotations 

8 

9from datetime import datetime 

10from typing import TYPE_CHECKING, Any 

11 

12from session_buddy.llm.base import LLMProvider 

13from session_buddy.llm.models import LLMMessage, LLMResponse 

14 

15if TYPE_CHECKING: 

16 from collections.abc import AsyncGenerator 

17 


class GeminiProvider(LLMProvider):
    """Google Gemini API provider."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.api_key = config.get("api_key")
        self.default_model = config.get("default_model", "gemini-pro")
        self._client = None

    async def _get_client(self) -> Any:
        """Get or create the Gemini client, importing the SDK lazily."""
        if self._client is None:
            try:
                import google.generativeai as genai

                genai.configure(api_key=self.api_key)
                self._client = genai
            except ImportError:
                msg = (
                    "Google Generative AI package not installed. "
                    "Install with: pip install google-generativeai"
                )
                raise ImportError(msg)
        return self._client

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
        """Convert LLMMessage objects to Gemini format using structural pattern matching."""
        converted: list[dict[str, Any]] = []

        for msg in messages:
            match msg.role:
                case "system":
                    # Gemini has no system role; fold the system prompt into
                    # the preceding user message when one exists, otherwise
                    # emit it as its own user message.
                    if converted and converted[-1]["role"] == "user":
                        converted[-1]["parts"] = [
                            f"System: {msg.content}\n\nUser: {converted[-1]['parts'][0]}",
                        ]
                    else:
                        converted.append(
                            {"role": "user", "parts": [f"System: {msg.content}"]},
                        )
                case "user":
                    converted.append({"role": "user", "parts": [msg.content]})
                case "assistant":
                    converted.append({"role": "model", "parts": [msg.content]})
                case _:
                    # Unknown role - default to user for safety
                    converted.append({"role": "user", "parts": [msg.content]})

        return converted
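    # Worked example (hypothetical contents): the input
    #     [LLMMessage(role="user", content="Hi"),
    #      LLMMessage(role="system", content="Be terse")]
    # merges the system prompt into the preceding user turn:
    #     [{"role": "user", "parts": ["System: Be terse\n\nUser: Hi"]}]
    # whereas a system message with no preceding user turn becomes its own
    # {"role": "user", "parts": ["System: ..."]} entry.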

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate a response using the Gemini API."""
        if not await self.is_available():
            msg = "Gemini provider not available"
            raise RuntimeError(msg)

        genai = await self._get_client()
        model_name = model or self.default_model

        try:
            model_instance = genai.GenerativeModel(model_name)

            # Convert messages to Gemini chat format
            chat_messages = self._convert_messages(messages)

            # Multi-turn conversations go through a chat session with all but
            # the last message as history; a single message is sent as a
            # one-shot generation.
            if len(chat_messages) > 1:
                chat = model_instance.start_chat(history=chat_messages[:-1])
                response = await chat.send_message_async(
                    chat_messages[-1]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                )
            else:
                response = await model_instance.generate_content_async(
                    chat_messages[0]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                )

            usage = getattr(response, "usage_metadata", None)
            return LLMResponse(
                content=response.text,
                model=model_name,
                provider="gemini",
                usage={
                    "prompt_tokens": usage.prompt_token_count if usage else 0,
                    "completion_tokens": usage.candidates_token_count if usage else 0,
                    "total_tokens": usage.total_token_count if usage else 0,
                },
                finish_reason="stop",  # Gemini doesn't provide detailed finish reasons
                timestamp=datetime.now().isoformat(),
            )

        except Exception:
            self.logger.exception("Gemini generation failed")
            raise
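
    # Shape of the returned value (illustrative numbers and model name):
    #     LLMResponse(content="...", model="gemini-1.5-flash", provider="gemini",
    #                 usage={"prompt_tokens": 12, "completion_tokens": 34,
    #                        "total_tokens": 46},
    #                 finish_reason="stop", timestamp="2026-01-04T00:43:00")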

    async def stream_generate(  # type: ignore[override]
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        """Stream a response using the Gemini API."""
        if not await self.is_available():
            msg = "Gemini provider not available"
            raise RuntimeError(msg)

        genai = await self._get_client()
        model_name = model or self.default_model

        try:
            model_instance = genai.GenerativeModel(model_name)
            chat_messages = self._convert_messages(messages)

            # Use the SDK's async calls with stream=True so that iterating
            # over the chunks does not block the event loop.
            if len(chat_messages) > 1:
                chat = model_instance.start_chat(history=chat_messages[:-1])
                response = await chat.send_message_async(
                    chat_messages[-1]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                    stream=True,
                )
            else:
                response = await model_instance.generate_content_async(
                    chat_messages[0]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                    stream=True,
                )

            async for chunk in response:
                if chunk.text:
                    yield chunk.text

        except Exception:
            self.logger.exception("Gemini streaming failed")
            raise

    async def is_available(self) -> bool:
        """Check if Gemini API is available."""
        if not self.api_key:
            return False

        try:
            genai = await self._get_client()
            # Test with a simple model list request
            list(genai.list_models())
            return True
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available Gemini models."""
        return [
            "gemini-pro",
            "gemini-pro-vision",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-1.0-pro",
        ]
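
For reference, a minimal usage sketch of this provider (not part of the measured module). The config keys match what __init__ reads, but the LLMMessage keyword constructor, the API key, and the model name are illustrative assumptions:

import asyncio

from session_buddy.llm.models import LLMMessage
from session_buddy.llm.providers.gemini_provider import GeminiProvider


async def main() -> None:
    # Hypothetical config values; "api_key" and "default_model" are the
    # keys GeminiProvider.__init__ reads.
    provider = GeminiProvider(
        {"api_key": "YOUR_API_KEY", "default_model": "gemini-1.5-flash"},
    )

    messages = [
        LLMMessage(role="system", content="You are terse."),
        LLMMessage(role="user", content="Name one moon of Jupiter."),
    ]

    # One-shot completion.
    response = await provider.generate(messages, temperature=0.2, max_tokens=64)
    print(response.content, response.usage)

    # Streaming completion.
    async for chunk in provider.stream_generate(messages):
        print(chunk, end="", flush=True)


asyncio.run(main())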