Coverage for session_mgmt_mcp/llm_providers.py: 14.02%
325 statements
#!/usr/bin/env python3
"""Cross-LLM Compatibility for Session Management MCP Server.

Provides a unified interface for multiple LLM providers including OpenAI, Google Gemini, and Ollama.
"""

import json
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any


@dataclass
class LLMMessage:
    """Standardized message format across LLM providers."""

    role: str  # 'system', 'user', 'assistant'
    content: str
    timestamp: str | None = None
    metadata: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        if self.timestamp is None:
            self.timestamp = datetime.now().isoformat()
        if self.metadata is None:
            self.metadata = {}


@dataclass
class LLMResponse:
    """Standardized response format from LLM providers."""

    content: str
    model: str
    provider: str
    usage: dict[str, Any]
    finish_reason: str
    timestamp: str
    metadata: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        if self.metadata is None:
            self.metadata = {}
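

# Illustrative usage (not part of the original module): timestamp and metadata
# are filled in by __post_init__ when omitted, so callers can simply write
#     msg = LLMMessage(role="user", content="Summarize the last session")
#     msg.timestamp  # ISO-8601 string, auto-populated
#     msg.metadata   # {} rather than None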


class LLMProvider(ABC):
    """Abstract base class for LLM providers."""

    def __init__(self, config: dict[str, Any]) -> None:
        self.config = config
        self.name = self.__class__.__name__.replace("Provider", "").lower()
        self.logger = logging.getLogger(f"llm_providers.{self.name}")

    @abstractmethod
    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate a response from the LLM."""

    @abstractmethod
    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Generate a streaming response from the LLM."""

    @abstractmethod
    async def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""

    @abstractmethod
    def get_models(self) -> list[str]:
        """Get list of available models for this provider."""


class OpenAIProvider(LLMProvider):
    """OpenAI API provider."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.api_key = config.get("api_key") or os.getenv("OPENAI_API_KEY")
        self.base_url = config.get("base_url", "https://api.openai.com/v1")
        self.default_model = config.get("default_model", "gpt-4")
        self._client = None

    async def _get_client(self):
        """Get or create OpenAI client."""
        if self._client is None:
            try:
                import openai

                self._client = openai.AsyncOpenAI(
                    api_key=self.api_key,
                    base_url=self.base_url,
                )
            except ImportError:
                msg = "OpenAI package not installed. Install with: pip install openai"
                raise ImportError(msg)
        return self._client

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert LLMMessage objects to OpenAI format."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate response using OpenAI API."""
        if not await self.is_available():
            msg = "OpenAI provider not available"
            raise RuntimeError(msg)

        client = await self._get_client()
        model_name = model or self.default_model

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=self._convert_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )

            return LLMResponse(
                content=response.choices[0].message.content,
                model=model_name,
                provider="openai",
                usage={
                    "prompt_tokens": response.usage.prompt_tokens
                    if response.usage
                    else 0,
                    "completion_tokens": response.usage.completion_tokens
                    if response.usage
                    else 0,
                    "total_tokens": response.usage.total_tokens
                    if response.usage
                    else 0,
                },
                finish_reason=response.choices[0].finish_reason,
                timestamp=datetime.now().isoformat(),
                metadata={"response_id": response.id},
            )

        except Exception as e:
            self.logger.exception(f"OpenAI generation failed: {e}")
            raise

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream response using OpenAI API."""
        if not await self.is_available():
            msg = "OpenAI provider not available"
            raise RuntimeError(msg)

        client = await self._get_client()
        model_name = model or self.default_model

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=self._convert_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )

            async for chunk in response:
                if chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            self.logger.exception(f"OpenAI streaming failed: {e}")
            raise

    async def is_available(self) -> bool:
        """Check if OpenAI API is available."""
        if not self.api_key:
            return False

        try:
            client = await self._get_client()
            # Test with a simple request
            await client.models.list()
            return True
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available OpenAI models."""
        return [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ]


class GeminiProvider(LLMProvider):
    """Google Gemini API provider."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.api_key = (
            config.get("api_key")
            or os.getenv("GEMINI_API_KEY")
            or os.getenv("GOOGLE_API_KEY")
        )
        self.default_model = config.get("default_model", "gemini-pro")
        self._client = None

    async def _get_client(self):
        """Get or create Gemini client."""
        if self._client is None:
            try:
                import google.generativeai as genai

                genai.configure(api_key=self.api_key)
                self._client = genai
            except ImportError:
                msg = "Google Generative AI package not installed. Install with: pip install google-generativeai"
                raise ImportError(msg)
        return self._client

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, Any]]:
        """Convert LLMMessage objects to Gemini format."""
        converted = []
        for msg in messages:
            if msg.role == "system":
                # Gemini has no system role; fold system text into a user turn
                if converted and converted[-1]["role"] == "user":
                    converted[-1]["parts"] = [
                        f"System: {msg.content}\n\nUser: {converted[-1]['parts'][0]}",
                    ]
                else:
                    converted.append(
                        {"role": "user", "parts": [f"System: {msg.content}"]},
                    )
            elif msg.role == "user":
                converted.append({"role": "user", "parts": [msg.content]})
            elif msg.role == "assistant":
                converted.append({"role": "model", "parts": [msg.content]})
        return converted
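
    # Illustrative mapping (not in the original source): the sequence
    #     [LLMMessage(role="system", content="Be terse"),
    #      LLMMessage(role="user", content="Hi")]
    # converts to
    #     [{"role": "user", "parts": ["System: Be terse"]},
    #      {"role": "user", "parts": ["Hi"]}]
    # and assistant turns are mapped to Gemini's "model" role.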

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate response using Gemini API."""
        if not await self.is_available():
            msg = "Gemini provider not available"
            raise RuntimeError(msg)

        genai = await self._get_client()
        model_name = model or self.default_model

        try:
            model_instance = genai.GenerativeModel(model_name)

            # Convert messages to Gemini chat format
            chat_messages = self._convert_messages(messages)

            # Create chat or generate single response
            if len(chat_messages) > 1:
                chat = model_instance.start_chat(history=chat_messages[:-1])
                response = await chat.send_message_async(
                    chat_messages[-1]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                )
            else:
                response = await model_instance.generate_content_async(
                    chat_messages[0]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                )

            return LLMResponse(
                content=response.text,
                model=model_name,
                provider="gemini",
                usage={
                    "prompt_tokens": response.usage_metadata.prompt_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                    "completion_tokens": response.usage_metadata.candidates_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                    "total_tokens": response.usage_metadata.total_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                },
                finish_reason="stop",  # Gemini doesn't provide detailed finish reasons
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            self.logger.exception(f"Gemini generation failed: {e}")
            raise

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream response using Gemini API."""
        if not await self.is_available():
            msg = "Gemini provider not available"
            raise RuntimeError(msg)

        genai = await self._get_client()
        model_name = model or self.default_model

        try:
            model_instance = genai.GenerativeModel(model_name)
            chat_messages = self._convert_messages(messages)

            # Note: the synchronous streaming API is used here, so chunks are
            # produced by blocking calls inside this async generator.
            if len(chat_messages) > 1:
                chat = model_instance.start_chat(history=chat_messages[:-1])
                response = chat.send_message(
                    chat_messages[-1]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                    stream=True,
                )
            else:
                response = model_instance.generate_content(
                    chat_messages[0]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                    stream=True,
                )

            for chunk in response:
                if chunk.text:
                    yield chunk.text

        except Exception as e:
            self.logger.exception(f"Gemini streaming failed: {e}")
            raise

    async def is_available(self) -> bool:
        """Check if Gemini API is available."""
        if not self.api_key:
            return False

        try:
            genai = await self._get_client()
            # Test with a simple model list request
            list(genai.list_models())
            return True
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available Gemini models."""
        return [
            "gemini-pro",
            "gemini-pro-vision",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-1.0-pro",
        ]


class OllamaProvider(LLMProvider):
    """Ollama local LLM provider."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.base_url = config.get("base_url", "http://localhost:11434")
        self.default_model = config.get("default_model", "llama2")
        self._available_models: list[str] = []

    async def _make_request(
        self,
        endpoint: str,
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Make HTTP request to Ollama API."""
        try:
            import aiohttp

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/{endpoint}",
                    json=data,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as response:
                    return await response.json()
        except ImportError:
            msg = "aiohttp package not installed. Install with: pip install aiohttp"
            raise ImportError(msg)

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert LLMMessage objects to Ollama format."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate response using Ollama API."""
        if not await self.is_available():
            msg = "Ollama provider not available"
            raise RuntimeError(msg)

        model_name = model or self.default_model

        try:
            data = {
                "model": model_name,
                "messages": self._convert_messages(messages),
                "options": {"temperature": temperature},
            }

            if max_tokens:
                data["options"]["num_predict"] = max_tokens

            response = await self._make_request("api/chat", data)

            return LLMResponse(
                content=response.get("message", {}).get("content", ""),
                model=model_name,
                provider="ollama",
                usage={
                    "prompt_tokens": response.get("prompt_eval_count", 0),
                    "completion_tokens": response.get("eval_count", 0),
                    "total_tokens": response.get("prompt_eval_count", 0)
                    + response.get("eval_count", 0),
                },
                finish_reason=response.get("done_reason", "stop"),
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            self.logger.exception(f"Ollama generation failed: {e}")
            raise

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream response using Ollama API."""
        if not await self.is_available():
            msg = "Ollama provider not available"
            raise RuntimeError(msg)

        model_name = model or self.default_model

        try:
            import aiohttp

            data = {
                "model": model_name,
                "messages": self._convert_messages(messages),
                "stream": True,
                "options": {"temperature": temperature},
            }

            if max_tokens:
                data["options"]["num_predict"] = max_tokens

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/api/chat",
                    json=data,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as response:
                    # Ollama streams one JSON object per line; lines that fail
                    # to parse are skipped.
                    async for line in response.content:
                        if line:
                            try:
                                chunk_data = json.loads(line.decode("utf-8"))
                                if (
                                    "message" in chunk_data
                                    and "content" in chunk_data["message"]
                                ):
                                    yield chunk_data["message"]["content"]
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            self.logger.exception(f"Ollama streaming failed: {e}")
            raise

    async def is_available(self) -> bool:
        """Check if Ollama is available."""
        try:
            import aiohttp

            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.base_url}/api/tags",
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        self._available_models = [
                            model["name"] for model in data.get("models", [])
                        ]
                        return True
                    return False
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available Ollama models."""
        return (
            self._available_models
            if self._available_models
            else [
                "llama2",
                "llama2:13b",
                "llama2:70b",
                "codellama",
                "mistral",
                "mixtral",
            ]
        )
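

# For reference (illustrative, based on the request built in generate() above):
# a non-streaming Ollama /api/chat call sends a payload shaped like
#     {"model": "llama2",
#      "messages": [{"role": "user", "content": "..."}],
#      "options": {"temperature": 0.7, "num_predict": 256}}
# and the response's "message", "prompt_eval_count", "eval_count" and
# "done_reason" fields are mapped onto LLMResponse.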


class LLMManager:
    """Manager for multiple LLM providers with fallback support."""

    def __init__(self, config_path: str | None = None) -> None:
        self.providers: dict[str, LLMProvider] = {}
        self.config = self._load_config(config_path)
        self.logger = logging.getLogger("llm_providers.manager")
        self._initialize_providers()

    def _load_config(self, config_path: str | None) -> dict[str, Any]:
        """Load configuration from file or environment."""
        config = {
            "providers": {},
            "default_provider": "openai",
            "fallback_providers": ["gemini", "ollama"],
        }

        if config_path and Path(config_path).exists():
            try:
                with open(config_path) as f:
                    file_config = json.load(f)
                    config.update(file_config)
            except (OSError, json.JSONDecodeError):
                pass

        # Add environment-based provider configs
        if not config["providers"].get("openai"):
            config["providers"]["openai"] = {
                "api_key": os.getenv("OPENAI_API_KEY"),
                "default_model": "gpt-4",
            }

        if not config["providers"].get("gemini"):
            config["providers"]["gemini"] = {
                "api_key": os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"),
                "default_model": "gemini-pro",
            }

        if not config["providers"].get("ollama"):
            config["providers"]["ollama"] = {
                "base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
                "default_model": "llama2",
            }

        return config
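
    # A file passed as config_path is plain JSON merged over the defaults built
    # above. Hypothetical example (values are illustrative only):
    #     {
    #       "default_provider": "ollama",
    #       "fallback_providers": ["openai"],
    #       "providers": {
    #         "ollama": {"base_url": "http://localhost:11434",
    #                    "default_model": "mistral"}
    #       }
    #     }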

    def _initialize_providers(self) -> None:
        """Initialize all configured providers."""
        provider_classes = {
            "openai": OpenAIProvider,
            "gemini": GeminiProvider,
            "ollama": OllamaProvider,
        }

        for provider_name, provider_config in self.config["providers"].items():
            if provider_name in provider_classes:
                try:
                    self.providers[provider_name] = provider_classes[provider_name](
                        provider_config,
                    )
                except Exception as e:
                    self.logger.warning(
                        f"Failed to initialize {provider_name} provider: {e}",
                    )

    async def get_available_providers(self) -> list[str]:
        """Get list of available providers."""
        available = []
        for name, provider in self.providers.items():
            if await provider.is_available():
                available.append(name)
        return available

    async def generate(
        self,
        messages: list[LLMMessage],
        provider: str | None = None,
        model: str | None = None,
        use_fallback: bool = True,
        **kwargs,
    ) -> LLMResponse:
        """Generate response with optional fallback."""
        target_provider = provider or self.config["default_provider"]

        # Try primary provider
        if target_provider in self.providers:
            try:
                provider_instance = self.providers[target_provider]
                if await provider_instance.is_available():
                    return await provider_instance.generate(messages, model, **kwargs)
            except Exception as e:
                self.logger.warning(f"Provider {target_provider} failed: {e}")

        # Try fallback providers if enabled
        if use_fallback:
            for fallback_name in self.config.get("fallback_providers", []):
                if fallback_name in self.providers and fallback_name != target_provider:
                    try:
                        provider_instance = self.providers[fallback_name]
                        if await provider_instance.is_available():
                            self.logger.info(f"Falling back to {fallback_name}")
                            return await provider_instance.generate(
                                messages,
                                model,
                                **kwargs,
                            )
                    except Exception as e:
                        self.logger.warning(
                            f"Fallback provider {fallback_name} failed: {e}",
                        )

        msg = "No available LLM providers"
        raise RuntimeError(msg)

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        provider: str | None = None,
        model: str | None = None,
        use_fallback: bool = True,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream generate response with optional fallback."""
        target_provider = provider or self.config["default_provider"]

        # Try primary provider
        if target_provider in self.providers:
            try:
                provider_instance = self.providers[target_provider]
                if await provider_instance.is_available():
                    async for chunk in provider_instance.stream_generate(
                        messages,
                        model,
                        **kwargs,
                    ):
                        yield chunk
                    return
            except Exception as e:
                self.logger.warning(f"Provider {target_provider} failed: {e}")

        # Try fallback providers if enabled
        if use_fallback:
            for fallback_name in self.config.get("fallback_providers", []):
                if fallback_name in self.providers and fallback_name != target_provider:
                    try:
                        provider_instance = self.providers[fallback_name]
                        if await provider_instance.is_available():
                            self.logger.info(f"Falling back to {fallback_name}")
                            async for chunk in provider_instance.stream_generate(
                                messages,
                                model,
                                **kwargs,
                            ):
                                yield chunk
                            return
                    except Exception as e:
                        self.logger.warning(
                            f"Fallback provider {fallback_name} failed: {e}",
                        )

        msg = "No available LLM providers"
        raise RuntimeError(msg)

    def get_provider_info(self) -> dict[str, Any]:
        """Get information about all providers."""
        info = {
            "providers": {},
            "config": {
                "default_provider": self.config["default_provider"],
                "fallback_providers": self.config.get("fallback_providers", []),
            },
        }

        for name, provider in self.providers.items():
            info["providers"][name] = {
                "models": provider.get_models(),
                "config": {
                    k: v for k, v in provider.config.items() if "key" not in k.lower()
                },
            }

        return info

    async def test_providers(self) -> dict[str, Any]:
        """Test all providers and return status."""
        test_message = [
            LLMMessage(role="user", content='Hello, respond with just "OK"'),
        ]
        results = {}

        for name, provider in self.providers.items():
            try:
                available = await provider.is_available()
                if available:
                    # Quick test generation
                    response = await provider.generate(test_message, max_tokens=10)
                    results[name] = {
                        "available": True,
                        "test_successful": True,
                        "response_length": len(response.content),
                        "model": response.model,
                    }
                else:
                    results[name] = {
                        "available": False,
                        "test_successful": False,
                        "error": "Provider not available",
                    }
            except Exception as e:
                results[name] = {
                    "available": False,
                    "test_successful": False,
                    "error": str(e),
                }

        return results
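

# Minimal usage sketch (not part of the original module): exercises the manager's
# default-and-fallback generation path. It assumes at least one provider is
# reachable (OPENAI_API_KEY or GEMINI_API_KEY in the environment, or a local
# Ollama server); otherwise generate() raises RuntimeError.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        manager = LLMManager()
        print("Available providers:", await manager.get_available_providers())
        response = await manager.generate(
            [LLMMessage(role="user", content="Say hello in one short sentence.")],
        )
        print(f"[{response.provider}/{response.model}] {response.content}")

    asyncio.run(_demo())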