Coverage for session_buddy/llm_providers.py: 47.87%

254 statements

coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

#!/usr/bin/env python3
"""Cross-LLM Compatibility for Session Management MCP Server.

Provides a unified interface for multiple LLM providers, including OpenAI,
Anthropic, Google Gemini, and Ollama.
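
Example (illustrative sketch; assumes at least one provider API key, such
as OPENAI_API_KEY, is configured):

    manager = LLMManager()
    response = await manager.generate(
        [LLMMessage(role="user", content="Hello")],
    )
    print(response.content)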

5""" 

6 

7import contextlib 

8import json 

9import logging 

10import os 

11from collections.abc import AsyncGenerator 

12from pathlib import Path 

13from typing import Any 

14 

15from session_buddy.llm import ( 

16 GeminiProvider, 

17 LLMMessage, 

18 LLMProvider, 

19 LLMResponse, 

20 OllamaProvider, 

21 OpenAIProvider, 

22 StreamChunk, 

23 StreamGenerationOptions, 

24) 

25from session_buddy.settings import get_llm_api_key, get_settings 

26 

27# Security utilities for API key validation/masking 

28try: 

29 from mcp_common.security import APIKeyValidator 

30 

31 SECURITY_AVAILABLE = True 

32except ImportError: 

33 APIKeyValidator = None # type: ignore[no-redef] 

34 SECURITY_AVAILABLE = False 

35 

36# Re-export for backwards compatibility 

37__all__ = [ 

38 "SECURITY_AVAILABLE", 

39 "LLMManager", 

40 "LLMMessage", 

41 "LLMProvider", 

42 "LLMResponse", 

43 "StreamChunk", 

44 "StreamGenerationOptions", 

45] 

46 

47 

48def _get_provider_api_key_and_env( 

49 provider: str, 

50) -> tuple[str | None, str | None]: 

51 """Return the provider API key and its environment variable name.""" 

    configured_key = get_llm_api_key(provider)
    if configured_key:
        return configured_key, f"settings.{provider}_api_key"
    if provider == "openai":
        return os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY"
    if provider == "anthropic":
        return os.getenv("ANTHROPIC_API_KEY"), "ANTHROPIC_API_KEY"
    if provider == "gemini":
        if os.getenv("GEMINI_API_KEY"):
            return os.getenv("GEMINI_API_KEY"), "GEMINI_API_KEY"
        if os.getenv("GOOGLE_API_KEY"):
            return os.getenv("GOOGLE_API_KEY"), "GOOGLE_API_KEY"
        return None, "GEMINI_API_KEY"
    if provider == "ollama":
        return None, None
    return None, None


def _get_configured_providers() -> list[str]:
71 """Get list of configured providers based on environment variables.""" 

    providers: set[str] = set()
    if get_llm_api_key("openai"):
        providers.add("openai")
    if get_llm_api_key("gemini"):
        providers.add("gemini")
    if get_llm_api_key("anthropic"):
        providers.add("anthropic")
    if os.getenv("OPENAI_API_KEY"):
        providers.add("openai")
    if os.getenv("ANTHROPIC_API_KEY"):
        providers.add("anthropic")
    if os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"):
        providers.add("gemini")
    return sorted(providers)


def _validate_provider_basic(provider: str, api_key: str) -> str:
    """Basic API key validation without the security module."""
    import sys

    if len(api_key.strip()) < 16:
        print(
            f"API Key Warning: {provider} API key appears very short",
            file=sys.stderr,
        )
    return "basic_check"


def _validate_provider_with_security(provider: str, api_key: str) -> tuple[bool, str]:
    """Validate an API key with the security module."""
    import sys

    validator = APIKeyValidator(provider=provider) if APIKeyValidator else None
    try:
        if validator is None:
            return False, "unavailable"
        validator.validate(api_key, raise_on_invalid=True)
        print(f"✅ API Key validated for {provider}", file=sys.stderr)
        return True, "valid"
    except ValueError as exc:
        print(f"❌ API Key validation failed: {exc}", file=sys.stderr)
        sys.exit(1)


def validate_llm_api_keys_at_startup() -> dict[str, str]:
117 """Validate configured LLM API keys and return status by provider.""" 

    import sys

    configured = _get_configured_providers()
    if not configured:
        print("No LLM Provider API Keys Configured", file=sys.stderr)
        return {}

    results: dict[str, str] = {}
    for provider in configured:
        api_key, env_var = _get_provider_api_key_and_env(provider)
        if api_key is None:
            continue
        if not api_key.strip():
            print(f"{env_var} is empty", file=sys.stderr)
            sys.exit(1)

        if SECURITY_AVAILABLE:
            _, status = _validate_provider_with_security(provider, api_key)
        else:
            status = _validate_provider_basic(provider, api_key)
        results[provider] = status

    return results


def get_masked_api_key(provider: str = "openai") -> str:
144 """Return masked API key for safe logging.""" 

    settings = get_settings()
    key_field_map = {
        "openai": "openai_api_key",
        "anthropic": "anthropic_api_key",
        "gemini": "gemini_api_key",
    }
    key_field = key_field_map.get(provider)
    if key_field:
        configured = getattr(settings, key_field, None)
        if isinstance(configured, str) and configured.strip():
            return settings.get_masked_key(key_name=key_field, visible_chars=4)

    api_key, _ = _get_provider_api_key_and_env(provider)

    if provider == "ollama":
        return "N/A (local service)"

    if not api_key:
        return "***"

    if SECURITY_AVAILABLE and APIKeyValidator:
        return APIKeyValidator.mask_key(api_key, visible_chars=4)

    if len(api_key) <= 4:
        return "***"
    return f"...{api_key[-4:]}"


class LLMManager:
    """Manager for multiple LLM providers with fallback support."""

    def __init__(self, config_path: str | None = None) -> None:
        self.providers: dict[str, LLMProvider] = {}
        self.config = self._load_config(config_path)
        self.logger = logging.getLogger("llm_providers.manager")
        self._initialize_providers()

    def _load_config(self, config_path: str | None) -> dict[str, Any]:
183 """Load configuration from file or environment.""" 

        config: dict[str, Any] = {
            "providers": {},
            "default_provider": "openai",
            # Planned cascade: openai -> anthropic -> gemini (-> ollama in future)
            "fallback_providers": ["anthropic", "gemini", "ollama"],
        }

        if config_path and Path(config_path).exists():
            with contextlib.suppress(OSError, json.JSONDecodeError):
                with open(config_path, encoding="utf-8") as f:
                    file_config = json.load(f)
                config.update(file_config)

        # Add environment-based provider configs
        if not config["providers"].get("openai"):
            config["providers"]["openai"] = {
                "api_key": os.getenv("OPENAI_API_KEY"),
                "default_model": "gpt-4",
            }

        if not config["providers"].get("anthropic"):
            config["providers"]["anthropic"] = {
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
                "default_model": "claude-3-5-haiku-20241022",
            }

        if not config["providers"].get("gemini"):
            config["providers"]["gemini"] = {
                "api_key": os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"),
                "default_model": "gemini-pro",
            }

        if not config["providers"].get("ollama"):
            config["providers"]["ollama"] = {
                "base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
                "default_model": "llama2",
            }

        return config

    def _initialize_providers(self) -> None:
        """Initialize all configured providers."""
        provider_classes = {
            "openai": OpenAIProvider,
            # Deferred import: the Anthropic provider is loaded here rather
            # than at module import time.
            "anthropic": __import__(
                "session_buddy.llm.providers.anthropic_provider",
                fromlist=["AnthropicProvider"],
            ).AnthropicProvider,
            "gemini": GeminiProvider,
            "ollama": OllamaProvider,
        }

        for provider_name, provider_config in self.config["providers"].items():
            if provider_name in provider_classes:
                try:
                    self.providers[provider_name] = provider_classes[provider_name](
                        provider_config,
                    )
                except Exception as e:
                    self.logger.warning(
                        f"Failed to initialize {provider_name} provider: {e}",
                    )

    async def get_available_providers(self) -> list[str]:
        """Get the list of currently available providers."""
        return [
            name
            for name, provider in self.providers.items()
            if await provider.is_available()
        ]

    async def generate(
        self,
        messages: list[LLMMessage],
        provider: str | None = None,
        model: str | None = None,
        use_fallback: bool = True,
        **kwargs: Any,
    ) -> LLMResponse:
263 """Generate response with optional fallback.""" 

        target_provider = provider or self.config["default_provider"]

        # Try primary provider
        result = await self._try_primary_provider_generate(
            target_provider,
            messages,
            model,
            **kwargs,
        )
        if result is not None:
            return result

        # Try fallback providers if enabled
        if use_fallback:
            result = await self._try_fallback_providers_generate(
                target_provider,
                messages,
                model,
                **kwargs,
            )
            if result is not None:
                return result

        msg = "No available LLM providers"
        raise RuntimeError(msg)

    async def _try_primary_provider_generate(
        self,
        target_provider: str,
        messages: list[LLMMessage],
        model: str | None,
        **kwargs: Any,
    ) -> LLMResponse | None:
        """Try generating with the primary provider."""
        if target_provider not in self.providers:
            return None

        try:
            provider_instance = self.providers[target_provider]
            if await provider_instance.is_available():
                return await provider_instance.generate(messages, model, **kwargs)
        except Exception as e:
            self.logger.warning(f"Provider {target_provider} failed: {e}")
        return None

    async def _try_fallback_providers_generate(
        self,
        target_provider: str,
        messages: list[LLMMessage],
        model: str | None,
        **kwargs: Any,
    ) -> LLMResponse | None:
        """Try generating with fallback providers."""
        for fallback_name in self.config.get("fallback_providers", []):
            if fallback_name in self.providers and fallback_name != target_provider:
                try:
                    provider_instance = self.providers[fallback_name]
                    if await provider_instance.is_available():
                        self.logger.info(f"Falling back to {fallback_name}")
                        return await provider_instance.generate(
                            messages,
                            model,
                            **kwargs,
                        )
                except Exception as e:
                    self.logger.warning(
                        f"Fallback provider {fallback_name} failed: {e}",
                    )
        return None

    def _get_fallback_providers(self, target_provider: str) -> list[str]:
        """Get the fallback providers, excluding the target provider."""
        return [
            name
            for name in self.config.get("fallback_providers", [])
            if name in self.providers and name != target_provider
        ]

    def _is_valid_provider(self, provider_name: str) -> bool:
        """Check whether a provider is known to this manager."""
        return provider_name in self.providers

    async def _get_provider_stream(
        self,
        provider_name: str,
        messages: list[LLMMessage],
        model: str | None,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Get a stream from a provider (assumes the provider is available)."""
        provider_instance = self.providers[provider_name]
        async for chunk in provider_instance.stream_generate(  # type: ignore[attr-defined]
            messages,
            model,
            **kwargs,
        ):
            yield chunk

    async def _try_provider_streaming(
        self,
        provider_name: str,
        messages: list[LLMMessage],
        model: str | None,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Try streaming from a provider with error handling."""
        try:
            provider_instance = self.providers[provider_name]
            if await provider_instance.is_available():
                async for chunk in self._get_provider_stream(
                    provider_name,
                    messages,
                    model,
                    **kwargs,
                ):
                    yield chunk
        except Exception as e:
            self.logger.warning(f"Provider {provider_name} failed: {e}")

    async def _select_primary_provider(self, options: StreamGenerationOptions) -> str:
        """Select the primary provider. Target complexity: ≤3."""
        target_provider = options.provider or self.config["default_provider"]
        if not self._is_valid_provider(target_provider):
            msg = f"Invalid provider: {target_provider}"
            raise RuntimeError(msg)
        return target_provider

    async def _try_streaming_from_provider(
        self,
        provider_name: str,
        messages: list[LLMMessage],
        options: StreamGenerationOptions,
    ) -> AsyncGenerator[StreamChunk]:
        """Try streaming from a specific provider. Target complexity: ≤6."""
        try:
            stream_started = False
            async for chunk_content in self._try_provider_streaming(
                provider_name,
                messages,
                options.model,
                temperature=options.temperature,
                max_tokens=options.max_tokens,
            ):
                stream_started = True
                yield StreamChunk.content_chunk(chunk_content, provider_name)

            if not stream_started:
                yield StreamChunk.error_chunk(f"No response from {provider_name}")

        except Exception as e:
            self.logger.warning(f"Provider {provider_name} failed: {e}")
            yield StreamChunk.error_chunk(str(e))

    async def _stream_from_primary_provider(
        self,
        primary_provider: str,
        messages: list[LLMMessage],
        options: StreamGenerationOptions,
    ) -> AsyncGenerator[str]:
        """Stream from the primary provider. Target complexity: ≤4."""
        has_content = False
        async for chunk in self._try_streaming_from_provider(
            primary_provider,
            messages,
            options,
        ):
            if chunk.is_error:
                if not has_content:  # Log errors only if no content received
                    self.logger.warning(
                        f"Primary provider error: {chunk.metadata.get('error', 'Unknown')}",
                    )
                continue

            has_content = True
            yield chunk.content

        if not has_content:
            self.logger.debug(
                f"Primary provider {primary_provider} produced no content",
            )

    async def _stream_from_fallback_providers(
        self,
        primary_provider: str,
        messages: list[LLMMessage],
        options: StreamGenerationOptions,
    ) -> AsyncGenerator[str]:
        """Stream from fallback providers. Target complexity: ≤5."""
        if not options.use_fallback:
            return

        fallback_providers = self._get_fallback_providers(primary_provider)
        for fallback_name in fallback_providers:
            self.logger.info(f"Falling back to {fallback_name}")
            has_content = False
            async for chunk in self._try_streaming_from_provider(
                fallback_name,
                messages,
                options,
            ):
                if chunk.is_error:
                    continue
                has_content = True
                yield chunk.content
            if has_content:
                return

    async def stream_generate(  # type: ignore[override]
        self,
        messages: list[LLMMessage],
        provider: str | None = None,
        model: str | None = None,
        use_fallback: bool = True,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
479 """Stream generate response with optional fallback. Target complexity: ≤8.""" 

        options = StreamGenerationOptions(
            provider=provider,
            model=model,
            use_fallback=use_fallback,
            temperature=kwargs.get("temperature", 0.7),
            max_tokens=kwargs.get("max_tokens"),
        )

        try:
            # Try the primary provider first
            primary_provider = await self._select_primary_provider(options)
            has_content = False
            async for chunk_content in self._stream_from_primary_provider(
                primary_provider,
                messages,
                options,
            ):
                has_content = True
                yield chunk_content
            if has_content:
                return  # Success - exit early

            # Primary produced nothing; try fallback providers
            async for chunk_content in self._stream_from_fallback_providers(
                primary_provider,
                messages,
                options,
            ):
                has_content = True
                yield chunk_content
            if has_content:
                return  # Success - exit early

            # All providers failed
            msg = "No available LLM providers"
            raise RuntimeError(msg)

        except Exception as e:
            self.logger.exception(f"Stream generation failed: {e}")
            raise

    def get_provider_info(self) -> dict[str, Any]:
517 """Get information about all providers.""" 

        info: dict[str, Any] = {
            "providers": {},
            "config": {
                "default_provider": self.config["default_provider"],
                "fallback_providers": self.config.get("fallback_providers", []),
            },
        }

        for name, provider in self.providers.items():
            info["providers"][name] = {
                "models": provider.get_models(),
                "config": {
                    k: v for k, v in provider.config.items() if "key" not in k.lower()
                },
            }

        return info

    async def test_providers(self) -> dict[str, Any]:
537 """Test all providers and return status.""" 

        test_message = [
            LLMMessage(role="user", content='Hello, respond with just "OK"'),
        ]
        results = {}

        for name, provider in self.providers.items():
            try:
                available = await provider.is_available()
                if available:
                    # Quick test generation
                    response = await provider.generate(test_message, max_tokens=10)
                    results[name] = {
                        "available": True,
                        "test_successful": True,
                        "response_length": len(response.content),
                        "model": response.model,
                    }
                else:
                    results[name] = {
                        "available": False,
                        "test_successful": False,
                        "error": "Provider not available",
                    }
            except Exception as e:
                results[name] = {
                    "available": False,
                    "test_successful": False,
                    "error": str(e),
                }

        return results