#!/usr/bin/env python3
"""Cross-LLM Compatibility for Session Management MCP Server.

Provides a unified interface for multiple LLM providers including OpenAI, Anthropic, Google Gemini, and Ollama.
"""

import contextlib
import json
import logging
import os
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any

from session_buddy.llm import (
    GeminiProvider,
    LLMMessage,
    LLMProvider,
    LLMResponse,
    OllamaProvider,
    OpenAIProvider,
    StreamChunk,
    StreamGenerationOptions,
)
from session_buddy.settings import get_llm_api_key, get_settings

# Security utilities for API key validation/masking
try:
    from mcp_common.security import APIKeyValidator

    SECURITY_AVAILABLE = True
except ImportError:
    APIKeyValidator = None  # type: ignore[no-redef]
    SECURITY_AVAILABLE = False

# Re-export for backwards compatibility
__all__ = [
    "SECURITY_AVAILABLE",
    "LLMManager",
    "LLMMessage",
    "LLMProvider",
    "LLMResponse",
    "StreamChunk",
    "StreamGenerationOptions",
]


def _get_provider_api_key_and_env(
    provider: str,
) -> tuple[str | None, str | None]:
    """Return the provider API key and its environment variable name."""
    configured_key = get_llm_api_key(provider)
    if configured_key:
        return configured_key, f"settings.{provider}_api_key"
    if provider == "openai":
        return os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY"
    if provider == "anthropic":
        return os.getenv("ANTHROPIC_API_KEY"), "ANTHROPIC_API_KEY"
    if provider == "gemini":
        if os.getenv("GEMINI_API_KEY"):
            return os.getenv("GEMINI_API_KEY"), "GEMINI_API_KEY"
        if os.getenv("GOOGLE_API_KEY"):
            return os.getenv("GOOGLE_API_KEY"), "GOOGLE_API_KEY"
        return None, "GEMINI_API_KEY"
    if provider == "ollama":
        return None, None
    return None, None

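# Example (illustrative, assuming get_llm_api_key("gemini") returns nothing):
# with only GOOGLE_API_KEY exported, the gemini lookup above falls through to
# the second environment variable:
#
#   >>> os.environ["GOOGLE_API_KEY"] = "fake-key-for-demo"  # hypothetical value
#   >>> _get_provider_api_key_and_env("gemini")
#   ('fake-key-for-demo', 'GOOGLE_API_KEY')

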
def _get_configured_providers() -> list[str]:
    """Get the list of configured providers based on settings and environment variables."""
    providers: set[str] = set()
    if get_llm_api_key("openai"):
        providers.add("openai")
    if get_llm_api_key("gemini"):
        providers.add("gemini")
    if get_llm_api_key("anthropic"):
        providers.add("anthropic")
    if os.getenv("OPENAI_API_KEY"):
        providers.add("openai")
    if os.getenv("ANTHROPIC_API_KEY"):
        providers.add("anthropic")
    if os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"):
        providers.add("gemini")
    return sorted(providers)


def _validate_provider_basic(provider: str, api_key: str) -> str:
    """Basic API key validation without the security module."""
    import sys

    if len(api_key.strip()) < 16:
        print(
            f"API Key Warning: {provider} API key appears very short",
            file=sys.stderr,
        )
    return "basic_check"


def _validate_provider_with_security(provider: str, api_key: str) -> tuple[bool, str]:
    """Validate an API key with the security module."""
    import sys

    validator = APIKeyValidator(provider=provider) if APIKeyValidator else None
    try:
        if validator is None:
            return False, "unavailable"
        validator.validate(api_key, raise_on_invalid=True)
        print(f"✅ API Key validated for {provider}", file=sys.stderr)
        return True, "valid"
    except ValueError as exc:
        print(f"❌ API Key validation failed: {exc}", file=sys.stderr)
        sys.exit(1)


def validate_llm_api_keys_at_startup() -> dict[str, str]:
    """Validate configured LLM API keys and return status by provider."""
    import sys

    configured = _get_configured_providers()
    if not configured:
        print("No LLM Provider API Keys Configured", file=sys.stderr)
        return {}

    results: dict[str, str] = {}
    for provider in configured:
        api_key, env_var = _get_provider_api_key_and_env(provider)
        if api_key is None:
            continue
        if not api_key.strip():
            print(f"❌ {env_var} is empty", file=sys.stderr)
            sys.exit(1)

        if SECURITY_AVAILABLE:
            _, status = _validate_provider_with_security(provider, api_key)
        else:
            status = _validate_provider_basic(provider, api_key)
        results[provider] = status

    return results

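# Usage sketch (illustrative): call once at server startup, before handling
# requests. The status values depend on whether mcp_common.security imported
# successfully, e.g. {"gemini": "valid"} or {"gemini": "basic_check"}:
#
#   statuses = validate_llm_api_keys_at_startup()

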
def get_masked_api_key(provider: str = "openai") -> str:
    """Return masked API key for safe logging."""
    settings = get_settings()
    key_field_map = {
        "openai": "openai_api_key",
        "anthropic": "anthropic_api_key",
        "gemini": "gemini_api_key",
    }
    key_field = key_field_map.get(provider)
    if key_field:
        configured = getattr(settings, key_field, None)
        if isinstance(configured, str) and configured.strip():
            return settings.get_masked_key(key_name=key_field, visible_chars=4)

    api_key, _ = _get_provider_api_key_and_env(provider)

    if provider == "ollama":
        return "N/A (local service)"

    if not api_key:
        return "***"

    if SECURITY_AVAILABLE and APIKeyValidator:
        return APIKeyValidator.mask_key(api_key, visible_chars=4)

    if len(api_key) <= 4:
        return "***"
    return f"...{api_key[-4:]}"

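# Example (illustrative, assuming no key in settings and no security module):
# the fallback path keeps only the last four characters:
#
#   >>> os.environ["OPENAI_API_KEY"] = "sk-demo-1234"  # hypothetical value
#   >>> get_masked_api_key("openai")
#   '...1234'

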
class LLMManager:
    """Manager for multiple LLM providers with fallback support."""

    def __init__(self, config_path: str | None = None) -> None:
        self.providers: dict[str, LLMProvider] = {}
        self.config = self._load_config(config_path)
        self.logger = logging.getLogger("llm_providers.manager")
        self._initialize_providers()

    def _load_config(self, config_path: str | None) -> dict[str, Any]:
        """Load configuration from a file or the environment."""
        config: dict[str, Any] = {
            "providers": {},
            "default_provider": "openai",
            # Planned cascade: openai -> anthropic -> gemini (-> ollama in the future)
            "fallback_providers": ["anthropic", "gemini", "ollama"],
        }

        if config_path and Path(config_path).exists():
            with contextlib.suppress(OSError, json.JSONDecodeError):
                with open(config_path, encoding="utf-8") as f:
                    file_config = json.load(f)
                config.update(file_config)

        # Add environment-based provider configs
        if not config["providers"].get("openai"):
            config["providers"]["openai"] = {
                "api_key": os.getenv("OPENAI_API_KEY"),
                "default_model": "gpt-4",
            }

        if not config["providers"].get("anthropic"):
            config["providers"]["anthropic"] = {
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
                "default_model": "claude-3-5-haiku-20241022",
            }

        if not config["providers"].get("gemini"):
            config["providers"]["gemini"] = {
                "api_key": os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"),
                "default_model": "gemini-pro",
            }

        if not config["providers"].get("ollama"):
            config["providers"]["ollama"] = {
                "base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
                "default_model": "llama2",
            }

        return config

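    # Illustrative config file shape (top-level keys are merged over the
    # defaults above via config.update; values here are hypothetical):
    #
    #   {
    #       "default_provider": "gemini",
    #       "fallback_providers": ["openai", "ollama"],
    #       "providers": {
    #           "gemini": {"api_key": "...", "default_model": "gemini-pro"}
    #       }
    #   }
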
    def _initialize_providers(self) -> None:
        """Initialize all configured providers."""
        provider_classes = {
            "openai": OpenAIProvider,
            # AnthropicProvider is imported lazily here, unlike the other
            # providers, which come from the top-level session_buddy.llm imports.
            "anthropic": __import__(
                "session_buddy.llm.providers.anthropic_provider",
                fromlist=["AnthropicProvider"],
            ).AnthropicProvider,
            "gemini": GeminiProvider,
            "ollama": OllamaProvider,
        }

        for provider_name, provider_config in self.config["providers"].items():
            if provider_name in provider_classes:
                try:
                    self.providers[provider_name] = provider_classes[provider_name](
                        provider_config,
                    )
                except Exception as e:
                    self.logger.warning(
                        f"Failed to initialize {provider_name} provider: {e}",
                    )

    async def get_available_providers(self) -> list[str]:
        """Get list of available providers."""
        return [
            name
            for name, provider in self.providers.items()
            if await provider.is_available()
        ]

    async def generate(
        self,
        messages: list[LLMMessage],
        provider: str | None = None,
        model: str | None = None,
        use_fallback: bool = True,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate response with optional fallback."""
        target_provider = provider or self.config["default_provider"]

        # Try primary provider
        result = await self._try_primary_provider_generate(
            target_provider,
            messages,
            model,
            **kwargs,
        )
        if result is not None:
            return result

        # Try fallback providers if enabled
        if use_fallback:
            result = await self._try_fallback_providers_generate(
                target_provider,
                messages,
                model,
                **kwargs,
            )
            if result is not None:
                return result

        msg = "No available LLM providers"
        raise RuntimeError(msg)

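    # Example (illustrative): with the default config and only GEMINI_API_KEY
    # exported, generate() finds openai unavailable and the fallback cascade
    # returns gemini's response:
    #
    #   manager = LLMManager()
    #   response = await manager.generate(
    #       [LLMMessage(role="user", content="Hello")],
    #   )
    #   print(response.content)
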
    async def _try_primary_provider_generate(
        self,
        target_provider: str,
        messages: list[LLMMessage],
        model: str | None,
        **kwargs: Any,
    ) -> LLMResponse | None:
        """Try generating with the primary provider."""
        if target_provider not in self.providers:
            return None

        try:
            provider_instance = self.providers[target_provider]
            if await provider_instance.is_available():
                return await provider_instance.generate(messages, model, **kwargs)
        except Exception as e:
            self.logger.warning(f"Provider {target_provider} failed: {e}")
        return None

    async def _try_fallback_providers_generate(
        self,
        target_provider: str,
        messages: list[LLMMessage],
        model: str | None,
        **kwargs: Any,
    ) -> LLMResponse | None:
        """Try generating with fallback providers."""
        for fallback_name in self.config.get("fallback_providers", []):
            if fallback_name in self.providers and fallback_name != target_provider:
                try:
                    provider_instance = self.providers[fallback_name]
                    if await provider_instance.is_available():
                        self.logger.info(f"Falling back to {fallback_name}")
                        return await provider_instance.generate(
                            messages,
                            model,
                            **kwargs,
                        )
                except Exception as e:
                    self.logger.warning(
                        f"Fallback provider {fallback_name} failed: {e}",
                    )
        return None

    def _get_fallback_providers(self, target_provider: str) -> list[str]:
        """Get list of fallback providers excluding the target provider."""
        return [
            name
            for name in self.config.get("fallback_providers", [])
            if name in self.providers and name != target_provider
        ]

    def _is_valid_provider(self, provider_name: str) -> bool:
        """Check if a provider is valid and available."""
        return provider_name in self.providers

    async def _get_provider_stream(
        self,
        provider_name: str,
        messages: list[LLMMessage],
        model: str | None,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Get stream from provider (assumes provider is available)."""
        provider_instance = self.providers[provider_name]
        async for chunk in provider_instance.stream_generate(  # type: ignore[attr-defined]
            messages,
            model,
            **kwargs,
        ):
            yield chunk

    async def _try_provider_streaming(
        self,
        provider_name: str,
        messages: list[LLMMessage],
        model: str | None,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Try streaming from a provider with error handling."""
        try:
            provider_instance = self.providers[provider_name]
            if await provider_instance.is_available():
                async for chunk in self._get_provider_stream(
                    provider_name,
                    messages,
                    model,
                    **kwargs,
                ):
                    yield chunk
        except Exception as e:
            self.logger.warning(f"Provider {provider_name} failed: {e}")

    async def _select_primary_provider(self, options: StreamGenerationOptions) -> str:
        """Select primary provider. Target complexity: ≤3."""
        target_provider = options.provider or self.config["default_provider"]
        if not self._is_valid_provider(target_provider):
            msg = f"Invalid provider: {target_provider}"
            raise RuntimeError(msg)
        return target_provider

    async def _try_streaming_from_provider(
        self,
        provider_name: str,
        messages: list[LLMMessage],
        options: StreamGenerationOptions,
    ) -> AsyncGenerator[StreamChunk]:
        """Try streaming from a specific provider. Target complexity: ≤6."""
        try:
            stream_started = False
            async for chunk_content in self._try_provider_streaming(
                provider_name,
                messages,
                options.model,
                temperature=options.temperature,
                max_tokens=options.max_tokens,
            ):
                stream_started = True
                yield StreamChunk.content_chunk(chunk_content, provider_name)

            if not stream_started:
                yield StreamChunk.error_chunk(f"No response from {provider_name}")

        except Exception as e:
            self.logger.warning(f"Provider {provider_name} failed: {e}")
            yield StreamChunk.error_chunk(str(e))

    async def _stream_from_primary_provider(
        self,
        primary_provider: str,
        messages: list[LLMMessage],
        options: StreamGenerationOptions,
    ) -> AsyncGenerator[str]:
        """Stream from primary provider. Target complexity: ≤4."""
        has_content = False
        async for chunk in self._try_streaming_from_provider(
            primary_provider,
            messages,
            options,
        ):
            if chunk.is_error:
                if not has_content:  # Log errors only if no content received
                    self.logger.warning(
                        f"Primary provider error: {chunk.metadata.get('error', 'Unknown')}",
                    )
                continue

            has_content = True
            yield chunk.content

        if not has_content:
            self.logger.debug(
                f"Primary provider {primary_provider} produced no content",
            )

    async def _stream_from_fallback_providers(
        self,
        primary_provider: str,
        messages: list[LLMMessage],
        options: StreamGenerationOptions,
    ) -> AsyncGenerator[str]:
        """Stream from fallback providers. Target complexity: ≤5."""
        if not options.use_fallback:
            return

        fallback_providers = self._get_fallback_providers(primary_provider)
        for fallback_name in fallback_providers:
            self.logger.info(f"Falling back to {fallback_name}")
            has_content = False
            async for chunk in self._try_streaming_from_provider(
                fallback_name,
                messages,
                options,
            ):
                if chunk.is_error:
                    continue
                has_content = True
                yield chunk.content
            if has_content:
                return

    async def stream_generate(  # type: ignore[override]
        self,
        messages: list[LLMMessage],
        provider: str | None = None,
        model: str | None = None,
        use_fallback: bool = True,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Stream generate response with optional fallback. Target complexity: ≤8."""
        options = StreamGenerationOptions(
            provider=provider,
            model=model,
            use_fallback=use_fallback,
            temperature=kwargs.get("temperature", 0.7),
            max_tokens=kwargs.get("max_tokens"),
        )

        try:
            # Try primary provider first; only return early if it actually
            # yielded content, otherwise the fallback loop below is unreachable.
            primary_provider = await self._select_primary_provider(options)
            has_content = False
            async for chunk_content in self._stream_from_primary_provider(
                primary_provider,
                messages,
                options,
            ):
                has_content = True
                yield chunk_content
            if has_content:
                return  # Success - exit early

            # Try fallback providers if primary produced no content
            async for chunk_content in self._stream_from_fallback_providers(
                primary_provider,
                messages,
                options,
            ):
                has_content = True
                yield chunk_content
            if has_content:
                return  # Success - exit early

            # All providers failed
            msg = "No available LLM providers"
            raise RuntimeError(msg)

        except Exception as e:
            self.logger.exception(f"Stream generation failed: {e}")
            raise

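    # Streaming usage sketch (illustrative; assumes a running event loop and
    # at least one available provider):
    #
    #   async for text in manager.stream_generate(
    #       [LLMMessage(role="user", content="Hello")],
    #       temperature=0.2,
    #   ):
    #       print(text, end="", flush=True)
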
    def get_provider_info(self) -> dict[str, Any]:
        """Get information about all providers."""
        info: dict[str, Any] = {
            "providers": {},
            "config": {
                "default_provider": self.config["default_provider"],
                "fallback_providers": self.config.get("fallback_providers", []),
            },
        }

        for name, provider in self.providers.items():
            info["providers"][name] = {
                "models": provider.get_models(),
                "config": {
                    k: v for k, v in provider.config.items() if "key" not in k.lower()
                },
            }

        return info

    async def test_providers(self) -> dict[str, Any]:
        """Test all providers and return status."""
        test_message = [
            LLMMessage(role="user", content='Hello, respond with just "OK"'),
        ]
        results = {}

        for name, provider in self.providers.items():
            try:
                available = await provider.is_available()
                if available:
                    # Quick test generation
                    response = await provider.generate(test_message, max_tokens=10)
                    results[name] = {
                        "available": True,
                        "test_successful": True,
                        "response_length": len(response.content),
                        "model": response.model,
                    }
                else:
                    results[name] = {
                        "available": False,
                        "test_successful": False,
                        "error": "Provider not available",
                    }
            except Exception as e:
                results[name] = {
                    "available": False,
                    "test_successful": False,
                    "error": str(e),
                }

        return results


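if __name__ == "__main__":  # pragma: no cover
    # Minimal smoke-test sketch (illustrative, not the MCP server entry point):
    # validate configured keys, then report provider status. Assumes at least
    # one provider API key is exported in the environment.
    import asyncio

    validate_llm_api_keys_at_startup()
    manager = LLMManager()
    print(json.dumps(asyncio.run(manager.test_providers()), indent=2))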