Coverage for session_buddy/llm/providers/ollama_provider.py: 42.41%

124 statements  

coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1"""Ollama local LLM provider implementation. 

2 

3This module provides the Ollama provider implementation using the mcp-common 

4HTTPClientAdapter for connection pooling and aiohttp fallback for HTTP communications. 

5""" 

6 

7from __future__ import annotations 

8 

9import json 

10from datetime import datetime 

11from typing import TYPE_CHECKING, Any 

12 

13from session_buddy.llm.base import LLMProvider 

14from session_buddy.llm.models import LLMMessage, LLMResponse 

15 

16if TYPE_CHECKING: 

17 from collections.abc import AsyncGenerator 

18 

19 from session_buddy.di.container import depends 

20 

21# mcp-common HTTP client adapter (httpx based) 

22try: 

23 from mcp_common.adapters.http.client import HTTPClientAdapter 

24 from session_buddy.di.container import depends 

25 

26 HTTP_ADAPTER_AVAILABLE = True 

27except Exception: 

28 HTTPClientAdapter = None # type: ignore[assignment] 

29 HTTP_ADAPTER_AVAILABLE = False 

30 

31 

 32  class OllamaProvider(LLMProvider):
 33      """Ollama local LLM provider using HTTPClientAdapter for connection pooling."""
 34
 35      def __init__(self, config: dict[str, Any]) -> None:
 36          super().__init__(config)
 37          self.base_url = config.get("base_url", "http://localhost:11434")
 38          self.default_model = config.get("default_model", "llama2")
 39          self._available_models: list[str] = []
 40
 41          # Initialize HTTP client adapter if available
 42          self._http_adapter = None
 43          if HTTP_ADAPTER_AVAILABLE and HTTPClientAdapter is not None:    [43 ↛ 44: line 43 didn't jump to line 44 because the condition on line 43 was never true]
 44              try:
 45                  self._http_adapter = depends.get_sync(HTTPClientAdapter)
 46              except Exception:
 47                  self._http_adapter = None
 48
 49      async def _make_api_request(
 50          self,
 51          endpoint: str,
 52          data: dict[str, Any],
 53      ) -> dict[str, Any]:
 54          """Make API request to Ollama service with connection pooling."""
 55          url = f"{self.base_url}/{endpoint}"
 56
 57          if self._http_adapter is not None:
 58              try:
 59                  async with self._http_adapter as client:
 60                      resp = await client.post(url, json=data, timeout=300)
 61                      return resp.json()  # type: ignore[no-any-return]
 62              except Exception as e:
 63                  self.logger.exception(f"HTTP request failed: {e}")
 64                  raise
 65          # Fallback to aiohttp (legacy)
 66          try:
 67              import aiohttp
 68
 69              async with (
 70                  aiohttp.ClientSession() as session,
 71                  session.post(
 72                      url,
 73                      json=data,
 74                      timeout=aiohttp.ClientTimeout(total=300),
 75                  ) as response,
 76              ):
 77                  return await response.json()  # type: ignore[no-any-return]
 78          except ImportError:
 79              msg = (
 80                  "aiohttp package not installed and HTTPClientAdapter not available. "
 81                  "Install with: pip install aiohttp or configure mcp-common HTTPClientAdapter"
 82              )
 83              raise ImportError(msg)  # type: ignore[no-any-return]
 84
 85      def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
 86          """Convert LLMMessage objects to Ollama format."""
 87          return [{"role": msg.role, "content": msg.content} for msg in messages]
 88
 89      async def generate(
 90          self,
 91          messages: list[LLMMessage],
 92          model: str | None = None,
 93          temperature: float = 0.7,
 94          max_tokens: int | None = None,
 95          **kwargs: Any,
 96      ) -> LLMResponse:
 97          """Generate response using Ollama API."""
 98          if not await self.is_available():
 99              msg = "Ollama provider not available"
100              raise RuntimeError(msg)
101
102          model_name = model or self.default_model
103
104          try:
105              data: dict[str, Any] = {
106                  "model": model_name,
107                  "messages": self._convert_messages(messages),
108                  "options": {"temperature": temperature},
109              }
110
111              if max_tokens:
112                  data["options"]["num_predict"] = max_tokens
113
114              response = await self._make_api_request("api/chat", data)
115
116              return LLMResponse(
117                  content=response.get("message", {}).get("content", ""),
118                  model=model_name,
119                  provider="ollama",
120                  usage={
121                      "prompt_tokens": response.get("prompt_eval_count", 0),
122                      "completion_tokens": response.get("eval_count", 0),
123                      "total_tokens": response.get("prompt_eval_count", 0)
124                      + response.get("eval_count", 0),
125                  },
126                  finish_reason=response.get("done_reason", "stop"),
127                  timestamp=datetime.now().isoformat(),
128              )
129
130          except Exception as e:
131              self.logger.exception(f"Ollama generation failed: {e}")
132              raise
133
134      def _prepare_stream_data(
135          self,
136          model_name: str,
137          messages: list[LLMMessage],
138          temperature: float,
139          max_tokens: int | None,
140      ) -> dict[str, Any]:
141          """Prepare data payload for streaming request."""
142          data: dict[str, Any] = {
143              "model": model_name,
144              "messages": self._convert_messages(messages),
145              "stream": True,
146              "options": {"temperature": temperature},
147          }
148          if max_tokens:
149              data["options"]["num_predict"] = max_tokens
150          return data
151
152      def _extract_chunk_content(self, line: bytes) -> str | None:
153          """Extract content from a streaming chunk line."""
154          if not line:
155              return None
156
157          try:
158              chunk_data = json.loads(line.decode("utf-8"))
159              if isinstance(chunk_data, dict) and "message" in chunk_data:
160                  message = chunk_data["message"]
161                  if isinstance(message, dict) and "content" in message:    [161 ↛ 165: line 161 didn't jump to line 165 because the condition on line 161 was always true]
162                      return str(message["content"])
163          except json.JSONDecodeError:
164              pass
165          return None
166
167      async def _stream_from_response_aiohttp(self, response: Any) -> AsyncGenerator[str]:
168          """Process streaming response from aiohttp and yield content chunks."""
169          async for line in response.content:
170              content = self._extract_chunk_content(line)
171              if content:
172                  yield content
173
174      async def _stream_from_response_httpx(self, response: Any) -> AsyncGenerator[str]:
175          """Process streaming response from httpx and yield content chunks."""
176          async for line in response.aiter_bytes():
177              content = self._extract_chunk_content(line)
178              if content:
179                  yield content
180
181      async def _stream_with_mcp_common(
182          self,
183          url: str,
184          data: dict[str, Any],
185      ) -> AsyncGenerator[str]:
186          """Stream using MCP-common HTTP adapter."""
187          # Note: http_adapter access requires mcp-common integration setup
188          # This is a placeholder for future mcp-common integration
189          if False:  # Disabled until http_adapter is properly initialized
190              yield ""  # pragma: no cover
191          else:
192              # Fallback to aiohttp for now
193              async for chunk in self._stream_with_aiohttp(url, data):
194                  yield chunk
195
196      async def _stream_with_aiohttp(
197          self,
198          url: str,
199          data: dict[str, Any],
200      ) -> AsyncGenerator[str]:
201          """Stream using aiohttp fallback."""
202          try:
203              import aiohttp
204
205              async with (
206                  aiohttp.ClientSession() as session,
207                  session.post(
208                      url,
209                      json=data,
210                      timeout=aiohttp.ClientTimeout(total=300),
211                  ) as response,
212              ):
213                  async for chunk in self._stream_from_response_aiohttp(response):
214                      yield chunk
215          except ImportError:
216              msg = "aiohttp not installed and mcp-common not available"
217              raise ImportError(msg)
218
219      async def stream_generate(  # type: ignore[override]
220          self,
221          messages: list[LLMMessage],
222          model: str | None = None,
223          temperature: float = 0.7,
224          max_tokens: int | None = None,
225          **kwargs: Any,
226      ) -> AsyncGenerator[str]:
227          """Stream response using Ollama API with connection pooling."""
228          if not await self.is_available():
229              msg = "Ollama provider not available"
230              raise RuntimeError(msg)
231
232          model_name = model or self.default_model
233          data = self._prepare_stream_data(model_name, messages, temperature, max_tokens)
234          url = f"{self.base_url}/api/chat"
235
236          try:
237              # Note: mcp-common integration deferred - using aiohttp fallback
238              # if self._use_mcp_common and self.http_adapter:
239              #     async for chunk in self._stream_with_mcp_common(url, data):
240              #         yield chunk
241              # else:
242              async for chunk in self._stream_with_aiohttp(url, data):
243                  yield chunk
244          except Exception as e:
245              self.logger.exception(f"Ollama streaming failed: {e}")
246              raise
247
248      async def _check_with_mcp_common(self, url: str) -> bool:
249          """Check availability using MCP-common HTTP adapter."""
250          # Note: http_adapter access requires mcp-common integration setup
251          # This is a placeholder for future mcp-common integration
252          return False  # Disabled until http_adapter is properly initialized
253
254      async def _check_with_aiohttp(self, url: str) -> bool:
255          """Check availability using aiohttp fallback."""
256          try:
257              import aiohttp
258
259              async with (
260                  aiohttp.ClientSession() as session,
261                  session.get(
262                      url,
263                      timeout=aiohttp.ClientTimeout(total=10),
264                  ) as response,
265              ):
266                  if response.status == 200:
267                      data = await response.json()
268                      self._available_models = [
269                          model["name"] for model in data.get("models", [])
270                      ]
271                      return True
272                  return False
273          except Exception:
274              return False
275
276      async def is_available(self) -> bool:
277          """Check if Ollama is available with connection pooling."""
278          try:
279              url = f"{self.base_url}/api/tags"
280
281              # Note: mcp-common integration deferred - using aiohttp fallback
282              # if self._use_mcp_common and self.http_adapter:
283              #     return await self._check_with_mcp_common(url)
284              return await self._check_with_aiohttp(url)
285          except Exception:
286              return False
287
288      def get_models(self) -> list[str]:
289          """Get available Ollama models."""
290          return self._available_models or [
291              "llama2",
292              "llama2:13b",
293              "llama2:70b",
294              "codellama",
295              "mistral",
296              "mixtral",
297          ]
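
Usage illustration (not part of the measured file): a minimal sketch of driving the provider listed above, assuming LLMMessage accepts role and content keyword arguments, LLMResponse exposes its fields as attributes, and an Ollama server is reachable at the default base_url; the model name and script entry point are illustrative only.

import asyncio

from session_buddy.llm.models import LLMMessage
from session_buddy.llm.providers.ollama_provider import OllamaProvider


async def main() -> None:
    # Configuration keys mirror those read in __init__ above.
    provider = OllamaProvider(
        {
            "base_url": "http://localhost:11434",
            "default_model": "llama2",
        }
    )

    # is_available() probes GET {base_url}/api/tags and caches the model list.
    if not await provider.is_available():
        print("Ollama is not reachable")
        return

    # Assumed constructor signature for LLMMessage (role/content keywords).
    messages = [LLMMessage(role="user", content="Say hello in one sentence.")]

    # Non-streaming call: POSTs to {base_url}/api/chat via _make_api_request().
    response = await provider.generate(messages, temperature=0.2)
    print(response.content)

    # Streaming call: yields content fragments parsed by _extract_chunk_content().
    async for chunk in provider.stream_generate(messages):
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())

Both calls go through the same /api/chat endpoint; the untested branches reported above (lines 43 and 161) correspond to the HTTPClientAdapter path and the malformed-chunk path, which this sketch does not exercise.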