Coverage for session_buddy/llm/providers/ollama_provider.py: 42.41% (124 statements)
1"""Ollama local LLM provider implementation.
3This module provides the Ollama provider implementation using the mcp-common
4HTTPClientAdapter for connection pooling and aiohttp fallback for HTTP communications.
5"""
7from __future__ import annotations
9import json
10from datetime import datetime
11from typing import TYPE_CHECKING, Any
13from session_buddy.llm.base import LLMProvider
14from session_buddy.llm.models import LLMMessage, LLMResponse
16if TYPE_CHECKING:
17 from collections.abc import AsyncGenerator
19 from session_buddy.di.container import depends
21# mcp-common HTTP client adapter (httpx based)
22try:
23 from mcp_common.adapters.http.client import HTTPClientAdapter
24 from session_buddy.di.container import depends
26 HTTP_ADAPTER_AVAILABLE = True
27except Exception:
28 HTTPClientAdapter = None # type: ignore[assignment]
29 HTTP_ADAPTER_AVAILABLE = False

class OllamaProvider(LLMProvider):
    """Ollama local LLM provider using HTTPClientAdapter for connection pooling."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.base_url = config.get("base_url", "http://localhost:11434")
        self.default_model = config.get("default_model", "llama2")
        self._available_models: list[str] = []

        # Initialize HTTP client adapter if available
        self._http_adapter = None
        if HTTP_ADAPTER_AVAILABLE and HTTPClientAdapter is not None:
            try:
                self._http_adapter = depends.get_sync(HTTPClientAdapter)
            except Exception:
                self._http_adapter = None
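    # Illustrative sketch of the config dict consumed by __init__ above; both
    # keys are optional and fall back to the defaults shown, and the full dict
    # is also handed to LLMProvider.__init__ via super():
    #
    #   {"base_url": "http://localhost:11434", "default_model": "llama2"}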

    async def _make_api_request(
        self,
        endpoint: str,
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Make an API request to the Ollama service with connection pooling."""
        url = f"{self.base_url}/{endpoint}"

        if self._http_adapter is not None:
            try:
                async with self._http_adapter as client:
                    resp = await client.post(url, json=data, timeout=300)
                    return resp.json()  # type: ignore[no-any-return]
            except Exception as e:
                self.logger.exception(f"HTTP request failed: {e}")
                raise
        # Fallback to aiohttp (legacy)
        try:
            import aiohttp

            async with (
                aiohttp.ClientSession() as session,
                session.post(
                    url,
                    json=data,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as response,
            ):
                return await response.json()  # type: ignore[no-any-return]
        except ImportError:
            msg = (
                "aiohttp package not installed and HTTPClientAdapter not available. "
                "Install with: pip install aiohttp or configure mcp-common HTTPClientAdapter"
            )
            raise ImportError(msg) from None

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert LLMMessage objects to Ollama format."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate a response using the Ollama API."""
        if not await self.is_available():
            msg = "Ollama provider not available"
            raise RuntimeError(msg)

        model_name = model or self.default_model

        try:
            data: dict[str, Any] = {
                "model": model_name,
                "messages": self._convert_messages(messages),
                # Ollama's /api/chat streams by default; request a single
                # JSON response for this non-streaming path.
                "stream": False,
                "options": {"temperature": temperature},
            }

            if max_tokens:
                data["options"]["num_predict"] = max_tokens

            response = await self._make_api_request("api/chat", data)

            return LLMResponse(
                content=response.get("message", {}).get("content", ""),
                model=model_name,
                provider="ollama",
                usage={
                    "prompt_tokens": response.get("prompt_eval_count", 0),
                    "completion_tokens": response.get("eval_count", 0),
                    "total_tokens": response.get("prompt_eval_count", 0)
                    + response.get("eval_count", 0),
                },
                finish_reason=response.get("done_reason", "stop"),
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            self.logger.exception(f"Ollama generation failed: {e}")
            raise

    def _prepare_stream_data(
        self,
        model_name: str,
        messages: list[LLMMessage],
        temperature: float,
        max_tokens: int | None,
    ) -> dict[str, Any]:
        """Prepare data payload for streaming request."""
        data: dict[str, Any] = {
            "model": model_name,
            "messages": self._convert_messages(messages),
            "stream": True,
            "options": {"temperature": temperature},
        }
        if max_tokens:
            data["options"]["num_predict"] = max_tokens
        return data

    def _extract_chunk_content(self, line: bytes) -> str | None:
        """Extract content from a streaming chunk line."""
        if not line:
            return None

        try:
            chunk_data = json.loads(line.decode("utf-8"))
            if isinstance(chunk_data, dict) and "message" in chunk_data:
                message = chunk_data["message"]
                if isinstance(message, dict) and "content" in message:
                    return str(message["content"])
        except json.JSONDecodeError:
            pass
        return None
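    # Illustrative sketch of the NDJSON lines _extract_chunk_content expects
    # from the streaming chat endpoint; the exact field set is an assumption
    # inferred from the parsing above, not a complete description of the API:
    #
    #   {"message": {"role": "assistant", "content": "Hel"}, "done": false}
    #   {"message": {"role": "assistant", "content": "lo"}, "done": false}
    #   {"done": true}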

    async def _stream_from_response_aiohttp(self, response: Any) -> AsyncGenerator[str]:
        """Process streaming response from aiohttp and yield content chunks."""
        async for line in response.content:
            content = self._extract_chunk_content(line)
            if content:
                yield content

    async def _stream_from_response_httpx(self, response: Any) -> AsyncGenerator[str]:
        """Process streaming response from httpx and yield content chunks."""
        # NOTE: aiter_bytes() yields raw byte chunks that are not guaranteed to
        # align with newlines, so this path assumes each chunk holds one
        # complete JSON line; chunks that split a line are silently dropped by
        # _extract_chunk_content.
        async for line in response.aiter_bytes():
            content = self._extract_chunk_content(line)
            if content:
                yield content

    async def _stream_with_mcp_common(
        self,
        url: str,
        data: dict[str, Any],
    ) -> AsyncGenerator[str]:
        """Stream using the mcp-common HTTP adapter."""
        # Note: http_adapter access requires mcp-common integration setup.
        # Until that integration lands, this is a placeholder that delegates
        # to the aiohttp fallback.
        async for chunk in self._stream_with_aiohttp(url, data):
            yield chunk

    async def _stream_with_aiohttp(
        self,
        url: str,
        data: dict[str, Any],
    ) -> AsyncGenerator[str]:
        """Stream using aiohttp fallback."""
        try:
            import aiohttp

            async with (
                aiohttp.ClientSession() as session,
                session.post(
                    url,
                    json=data,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as response,
            ):
                async for chunk in self._stream_from_response_aiohttp(response):
                    yield chunk
        except ImportError:
            msg = "aiohttp not installed and mcp-common not available"
            raise ImportError(msg) from None

    async def stream_generate(  # type: ignore[override]
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Stream response using Ollama API with connection pooling."""
        if not await self.is_available():
            msg = "Ollama provider not available"
            raise RuntimeError(msg)

        model_name = model or self.default_model
        data = self._prepare_stream_data(model_name, messages, temperature, max_tokens)
        url = f"{self.base_url}/api/chat"

        try:
            # Note: mcp-common integration deferred - using aiohttp fallback
            # if self._use_mcp_common and self.http_adapter:
            #     async for chunk in self._stream_with_mcp_common(url, data):
            #         yield chunk
            # else:
            async for chunk in self._stream_with_aiohttp(url, data):
                yield chunk
        except Exception as e:
            self.logger.exception(f"Ollama streaming failed: {e}")
            raise

    async def _check_with_mcp_common(self, url: str) -> bool:
        """Check availability using MCP-common HTTP adapter."""
        # Note: http_adapter access requires mcp-common integration setup
        # This is a placeholder for future mcp-common integration
        return False  # Disabled until http_adapter is properly initialized

    async def _check_with_aiohttp(self, url: str) -> bool:
        """Check availability using aiohttp fallback."""
        try:
            import aiohttp

            async with (
                aiohttp.ClientSession() as session,
                session.get(
                    url,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response,
            ):
                if response.status == 200:
                    data = await response.json()
                    self._available_models = [
                        model["name"] for model in data.get("models", [])
                    ]
                    return True
                return False
        except Exception:
            return False

    async def is_available(self) -> bool:
        """Check if Ollama is available with connection pooling."""
        try:
            url = f"{self.base_url}/api/tags"

            # Note: mcp-common integration deferred - using aiohttp fallback
            # if self._use_mcp_common and self.http_adapter:
            #     return await self._check_with_mcp_common(url)
            return await self._check_with_aiohttp(url)
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available Ollama models."""
        return self._available_models or [
            "llama2",
            "llama2:13b",
            "llama2:70b",
            "codellama",
            "mistral",
            "mixtral",
        ]
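

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the provider): assumes a
# local Ollama server on the default port with "llama2" pulled, and that
# LLMMessage accepts role/content keyword arguments as used in
# _convert_messages above.
#
#   import asyncio
#
#   async def _demo() -> None:
#       provider = OllamaProvider({"base_url": "http://localhost:11434"})
#       if not await provider.is_available():
#           return
#       response = await provider.generate(
#           [LLMMessage(role="user", content="Say hello in one sentence.")]
#       )
#       print(response.content)
#       async for chunk in provider.stream_generate(
#           [LLMMessage(role="user", content="Count to three.")]
#       ):
#           print(chunk, end="", flush=True)
#
#   asyncio.run(_demo())
# ---------------------------------------------------------------------------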