Coverage for session_mgmt_mcp/llm_providers.py: 14.02%

325 statements  

coverage.py v7.10.6, created at 2025-09-01 05:22 -0700

#!/usr/bin/env python3
"""Cross-LLM Compatibility for Session Management MCP Server.

Provides a unified interface for multiple LLM providers including OpenAI, Google Gemini, and Ollama.
"""

import json
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any


@dataclass
class LLMMessage:
    """Standardized message format across LLM providers."""

    role: str  # 'system', 'user', 'assistant'
    content: str
    timestamp: str | None = None
    metadata: dict[str, Any] | None = None

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = datetime.now().isoformat()
        if self.metadata is None:
            self.metadata = {}


@dataclass
class LLMResponse:
    """Standardized response format from LLM providers."""

    content: str
    model: str
    provider: str
    usage: dict[str, Any]
    finish_reason: str
    timestamp: str
    metadata: dict[str, Any] | None = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}

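
# Illustrative usage sketch (not part of the original module): shows how
# LLMMessage.__post_init__ fills in a timestamp and empty metadata when the
# caller omits them. The helper name below is hypothetical.
def _example_build_conversation() -> list[LLMMessage]:
    """Build a small conversation; defaults are filled by __post_init__."""
    system = LLMMessage(role="system", content="You are a terse assistant.")
    user = LLMMessage(role="user", content="Summarize the last session.")
    # Both messages now carry an ISO-8601 timestamp and an empty metadata dict.
    assert system.metadata == {} and user.timestamp is not None
    return [system, user]
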

class LLMProvider(ABC):
    """Abstract base class for LLM providers."""

    def __init__(self, config: dict[str, Any]) -> None:
        self.config = config
        self.name = self.__class__.__name__.replace("Provider", "").lower()
        self.logger = logging.getLogger(f"llm_providers.{self.name}")

    @abstractmethod
    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate a response from the LLM."""

    @abstractmethod
    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str]:
        """Generate a streaming response from the LLM."""

    @abstractmethod
    async def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""

    @abstractmethod
    def get_models(self) -> list[str]:
        """Get list of available models for this provider."""

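
# Illustrative sketch (not part of the original module): a minimal provider
# that satisfies the LLMProvider contract by echoing the last user message.
# It shows the four methods every concrete provider must implement; the class
# name and behaviour are hypothetical, no external service is contacted.
class _EchoProvider(LLMProvider):
    """Echoes the last message back; useful only as a reference example."""

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        content = messages[-1].content if messages else ""
        return LLMResponse(
            content=content,
            model=model or "echo",
            provider="echo",
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
            finish_reason="stop",
            timestamp=datetime.now().isoformat(),
        )

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str]:
        # Yield the echo in a single chunk to keep the sketch short.
        yield messages[-1].content if messages else ""

    async def is_available(self) -> bool:
        return True

    def get_models(self) -> list[str]:
        return ["echo"]
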

class OpenAIProvider(LLMProvider):
    """OpenAI API provider."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.api_key = config.get("api_key") or os.getenv("OPENAI_API_KEY")
        self.base_url = config.get("base_url", "https://api.openai.com/v1")
        self.default_model = config.get("default_model", "gpt-4")
        self._client = None

    async def _get_client(self):
        """Get or create OpenAI client."""
        if self._client is None:
            try:
                import openai

                self._client = openai.AsyncOpenAI(
                    api_key=self.api_key,
                    base_url=self.base_url,
                )
            except ImportError:
                msg = "OpenAI package not installed. Install with: pip install openai"
                raise ImportError(
                    msg,
                )
        return self._client

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert LLMMessage objects to OpenAI format."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate response using OpenAI API."""
        if not await self.is_available():
            msg = "OpenAI provider not available"
            raise RuntimeError(msg)

        client = await self._get_client()
        model_name = model or self.default_model

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=self._convert_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )

            return LLMResponse(
                content=response.choices[0].message.content,
                model=model_name,
                provider="openai",
                usage={
                    "prompt_tokens": response.usage.prompt_tokens
                    if response.usage
                    else 0,
                    "completion_tokens": response.usage.completion_tokens
                    if response.usage
                    else 0,
                    "total_tokens": response.usage.total_tokens
                    if response.usage
                    else 0,
                },
                finish_reason=response.choices[0].finish_reason,
                timestamp=datetime.now().isoformat(),
                metadata={"response_id": response.id},
            )

        except Exception as e:
            self.logger.exception(f"OpenAI generation failed: {e}")
            raise

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str]:
        """Stream response using OpenAI API."""
        if not await self.is_available():
            msg = "OpenAI provider not available"
            raise RuntimeError(msg)

        client = await self._get_client()
        model_name = model or self.default_model

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=self._convert_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )

            async for chunk in response:
                if chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            self.logger.exception(f"OpenAI streaming failed: {e}")
            raise

    async def is_available(self) -> bool:
        """Check if OpenAI API is available."""
        if not self.api_key:
            return False

        try:
            client = await self._get_client()
            # Test with a simple request
            await client.models.list()
            return True
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available OpenAI models."""
        return [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ]

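
# Illustrative usage sketch (not part of the original module): calling
# OpenAIProvider directly, assuming OPENAI_API_KEY is set in the environment
# and the `openai` package is installed. The function name is hypothetical.
async def _example_openai_generate() -> LLMResponse:
    provider = OpenAIProvider({"default_model": "gpt-4o-mini"})
    return await provider.generate(
        [LLMMessage(role="user", content="Say hello in five words.")],
        temperature=0.2,
        max_tokens=32,
    )
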

class GeminiProvider(LLMProvider):
    """Google Gemini API provider."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.api_key = (
            config.get("api_key")
            or os.getenv("GEMINI_API_KEY")
            or os.getenv("GOOGLE_API_KEY")
        )
        self.default_model = config.get("default_model", "gemini-pro")
        self._client = None

    async def _get_client(self):
        """Get or create Gemini client."""
        if self._client is None:
            try:
                import google.generativeai as genai

                genai.configure(api_key=self.api_key)
                self._client = genai
            except ImportError:
                msg = "Google Generative AI package not installed. Install with: pip install google-generativeai"
                raise ImportError(
                    msg,
                )
        return self._client

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert LLMMessage objects to Gemini format."""
        converted = []
        for msg in messages:
            if msg.role == "system":
                # Gemini doesn't have system role, prepend to first user message
                if converted and converted[-1]["role"] == "user":
                    converted[-1]["parts"] = [
                        f"System: {msg.content}\n\nUser: {converted[-1]['parts'][0]}",
                    ]
                else:
                    converted.append(
                        {"role": "user", "parts": [f"System: {msg.content}"]},
                    )
            elif msg.role == "user":
                converted.append({"role": "user", "parts": [msg.content]})
            elif msg.role == "assistant":
                converted.append({"role": "model", "parts": [msg.content]})
        return converted

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate response using Gemini API."""
        if not await self.is_available():
            msg = "Gemini provider not available"
            raise RuntimeError(msg)

        genai = await self._get_client()
        model_name = model or self.default_model

        try:
            model_instance = genai.GenerativeModel(model_name)

            # Convert messages to Gemini chat format
            chat_messages = self._convert_messages(messages)

            # Create chat or generate single response
            if len(chat_messages) > 1:
                chat = model_instance.start_chat(history=chat_messages[:-1])
                response = await chat.send_message_async(
                    chat_messages[-1]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                )
            else:
                response = await model_instance.generate_content_async(
                    chat_messages[0]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                )

            return LLMResponse(
                content=response.text,
                model=model_name,
                provider="gemini",
                usage={
                    "prompt_tokens": response.usage_metadata.prompt_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                    "completion_tokens": response.usage_metadata.candidates_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                    "total_tokens": response.usage_metadata.total_token_count
                    if hasattr(response, "usage_metadata")
                    else 0,
                },
                finish_reason="stop",  # Gemini doesn't provide detailed finish reasons
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            self.logger.exception(f"Gemini generation failed: {e}")
            raise

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str]:
        """Stream response using Gemini API."""
        if not await self.is_available():
            msg = "Gemini provider not available"
            raise RuntimeError(msg)

        genai = await self._get_client()
        model_name = model or self.default_model

        try:
            model_instance = genai.GenerativeModel(model_name)
            chat_messages = self._convert_messages(messages)

            if len(chat_messages) > 1:
                chat = model_instance.start_chat(history=chat_messages[:-1])
                response = chat.send_message(
                    chat_messages[-1]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                    stream=True,
                )
            else:
                response = model_instance.generate_content(
                    chat_messages[0]["parts"][0],
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens,
                    },
                    stream=True,
                )

            for chunk in response:
                if chunk.text:
                    yield chunk.text

        except Exception as e:
            self.logger.exception(f"Gemini streaming failed: {e}")
            raise

    async def is_available(self) -> bool:
        """Check if Gemini API is available."""
        if not self.api_key:
            return False

        try:
            genai = await self._get_client()
            # Test with a simple model list request
            list(genai.list_models())
            return True
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available Gemini models."""
        return [
            "gemini-pro",
            "gemini-pro-vision",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-1.0-pro",
        ]


class OllamaProvider(LLMProvider):
    """Ollama local LLM provider."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.base_url = config.get("base_url", "http://localhost:11434")
        self.default_model = config.get("default_model", "llama2")
        self._available_models = []

    async def _make_request(
        self,
        endpoint: str,
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Make HTTP request to Ollama API."""
        try:
            import aiohttp

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/{endpoint}",
                    json=data,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as response:
                    return await response.json()
        except ImportError:
            msg = "aiohttp package not installed. Install with: pip install aiohttp"
            raise ImportError(
                msg,
            )

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert LLMMessage objects to Ollama format."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate response using Ollama API."""
        if not await self.is_available():
            msg = "Ollama provider not available"
            raise RuntimeError(msg)

        model_name = model or self.default_model

        try:
            data = {
                "model": model_name,
                "messages": self._convert_messages(messages),
                "options": {"temperature": temperature},
            }

            if max_tokens:
                data["options"]["num_predict"] = max_tokens

            response = await self._make_request("api/chat", data)

            return LLMResponse(
                content=response.get("message", {}).get("content", ""),
                model=model_name,
                provider="ollama",
                usage={
                    "prompt_tokens": response.get("prompt_eval_count", 0),
                    "completion_tokens": response.get("eval_count", 0),
                    "total_tokens": response.get("prompt_eval_count", 0)
                    + response.get("eval_count", 0),
                },
                finish_reason=response.get("done_reason", "stop"),
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            self.logger.exception(f"Ollama generation failed: {e}")
            raise

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str]:
        """Stream response using Ollama API."""
        if not await self.is_available():
            msg = "Ollama provider not available"
            raise RuntimeError(msg)

        model_name = model or self.default_model

        try:
            import aiohttp

            data = {
                "model": model_name,
                "messages": self._convert_messages(messages),
                "stream": True,
                "options": {"temperature": temperature},
            }

            if max_tokens:
                data["options"]["num_predict"] = max_tokens

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/api/chat",
                    json=data,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as response:
                    async for line in response.content:
                        if line:
                            try:
                                chunk_data = json.loads(line.decode("utf-8"))
                                if (
                                    "message" in chunk_data
                                    and "content" in chunk_data["message"]
                                ):
                                    yield chunk_data["message"]["content"]
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            self.logger.exception(f"Ollama streaming failed: {e}")
            raise

    async def is_available(self) -> bool:
        """Check if Ollama is available."""
        try:
            import aiohttp

            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.base_url}/api/tags",
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        self._available_models = [
                            model["name"] for model in data.get("models", [])
                        ]
                        return True
                    return False
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Get available Ollama models."""
        return (
            self._available_models
            if self._available_models
            else [
                "llama2",
                "llama2:13b",
                "llama2:70b",
                "codellama",
                "mistral",
                "mixtral",
            ]
        )

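
# Illustrative usage sketch (not part of the original module): consuming a
# streaming response from a local Ollama daemon, assuming one is listening on
# http://localhost:11434 with a model pulled. The function name is hypothetical.
async def _example_ollama_stream() -> str:
    provider = OllamaProvider({"default_model": "llama2"})
    collected: list[str] = []
    async for chunk in provider.stream_generate(
        [LLMMessage(role="user", content="Name three prime numbers.")],
        temperature=0.1,
    ):
        collected.append(chunk)
    return "".join(collected)
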

class LLMManager:
    """Manager for multiple LLM providers with fallback support."""

    def __init__(self, config_path: str | None = None) -> None:
        self.providers: dict[str, LLMProvider] = {}
        self.config = self._load_config(config_path)
        self.logger = logging.getLogger("llm_providers.manager")
        self._initialize_providers()

    def _load_config(self, config_path: str | None) -> dict[str, Any]:
        """Load configuration from file or environment."""
        config = {
            "providers": {},
            "default_provider": "openai",
            "fallback_providers": ["gemini", "ollama"],
        }

        if config_path and Path(config_path).exists():
            try:
                with open(config_path) as f:
                    file_config = json.load(f)
                    config.update(file_config)
            except (OSError, json.JSONDecodeError):
                pass

        # Add environment-based provider configs
        if not config["providers"].get("openai"):
            config["providers"]["openai"] = {
                "api_key": os.getenv("OPENAI_API_KEY"),
                "default_model": "gpt-4",
            }

        if not config["providers"].get("gemini"):
            config["providers"]["gemini"] = {
                "api_key": os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"),
                "default_model": "gemini-pro",
            }

        if not config["providers"].get("ollama"):
            config["providers"]["ollama"] = {
                "base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
                "default_model": "llama2",
            }

        return config

    def _initialize_providers(self) -> None:
        """Initialize all configured providers."""
        provider_classes = {
            "openai": OpenAIProvider,
            "gemini": GeminiProvider,
            "ollama": OllamaProvider,
        }

        for provider_name, provider_config in self.config["providers"].items():
            if provider_name in provider_classes:
                try:
                    self.providers[provider_name] = provider_classes[provider_name](
                        provider_config,
                    )
                except Exception as e:
                    self.logger.warning(
                        f"Failed to initialize {provider_name} provider: {e}",
                    )

    async def get_available_providers(self) -> list[str]:
        """Get list of available providers."""
        available = []
        for name, provider in self.providers.items():
            if await provider.is_available():
                available.append(name)
        return available

    async def generate(
        self,
        messages: list[LLMMessage],
        provider: str | None = None,
        model: str | None = None,
        use_fallback: bool = True,
        **kwargs,
    ) -> LLMResponse:
        """Generate response with optional fallback."""
        target_provider = provider or self.config["default_provider"]

        # Try primary provider
        if target_provider in self.providers:
            try:
                provider_instance = self.providers[target_provider]
                if await provider_instance.is_available():
                    return await provider_instance.generate(messages, model, **kwargs)
            except Exception as e:
                self.logger.warning(f"Provider {target_provider} failed: {e}")

        # Try fallback providers if enabled
        if use_fallback:
            for fallback_name in self.config.get("fallback_providers", []):
                if fallback_name in self.providers and fallback_name != target_provider:
                    try:
                        provider_instance = self.providers[fallback_name]
                        if await provider_instance.is_available():
                            self.logger.info(f"Falling back to {fallback_name}")
                            return await provider_instance.generate(
                                messages,
                                model,
                                **kwargs,
                            )
                    except Exception as e:
                        self.logger.warning(
                            f"Fallback provider {fallback_name} failed: {e}",
                        )

        msg = "No available LLM providers"
        raise RuntimeError(msg)

    async def stream_generate(
        self,
        messages: list[LLMMessage],
        provider: str | None = None,
        model: str | None = None,
        use_fallback: bool = True,
        **kwargs,
    ) -> AsyncGenerator[str]:
        """Stream generate response with optional fallback."""
        target_provider = provider or self.config["default_provider"]

        # Try primary provider
        if target_provider in self.providers:
            try:
                provider_instance = self.providers[target_provider]
                if await provider_instance.is_available():
                    async for chunk in provider_instance.stream_generate(
                        messages,
                        model,
                        **kwargs,
                    ):
                        yield chunk
                    return
            except Exception as e:
                self.logger.warning(f"Provider {target_provider} failed: {e}")

        # Try fallback providers if enabled
        if use_fallback:
            for fallback_name in self.config.get("fallback_providers", []):
                if fallback_name in self.providers and fallback_name != target_provider:
                    try:
                        provider_instance = self.providers[fallback_name]
                        if await provider_instance.is_available():
                            self.logger.info(f"Falling back to {fallback_name}")
                            async for chunk in provider_instance.stream_generate(
                                messages,
                                model,
                                **kwargs,
                            ):
                                yield chunk
                            return
                    except Exception as e:
                        self.logger.warning(
                            f"Fallback provider {fallback_name} failed: {e}",
                        )

        msg = "No available LLM providers"
        raise RuntimeError(msg)

    def get_provider_info(self) -> dict[str, Any]:
        """Get information about all providers."""
        info = {
            "providers": {},
            "config": {
                "default_provider": self.config["default_provider"],
                "fallback_providers": self.config.get("fallback_providers", []),
            },
        }

        for name, provider in self.providers.items():
            info["providers"][name] = {
                "models": provider.get_models(),
                "config": {
                    k: v for k, v in provider.config.items() if "key" not in k.lower()
                },
            }

        return info

    async def test_providers(self) -> dict[str, Any]:
        """Test all providers and return status."""
        test_message = [
            LLMMessage(role="user", content='Hello, respond with just "OK"'),
        ]
        results = {}

        for name, provider in self.providers.items():
            try:
                available = await provider.is_available()
                if available:
                    # Quick test generation
                    response = await provider.generate(test_message, max_tokens=10)
                    results[name] = {
                        "available": True,
                        "test_successful": True,
                        "response_length": len(response.content),
                        "model": response.model,
                    }
                else:
                    results[name] = {
                        "available": False,
                        "test_successful": False,
                        "error": "Provider not available",
                    }
            except Exception as e:
                results[name] = {
                    "available": False,
                    "test_successful": False,
                    "error": str(e),
                }

        return results
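

# Illustrative usage sketch (not part of the original module): driving the
# manager end to end with fallback enabled. Assumes at least one provider is
# configured (an API key in the environment or a local Ollama instance); the
# names under the __main__ guard are hypothetical.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        manager = LLMManager()
        print("Available providers:", await manager.get_available_providers())
        response = await manager.generate(
            [LLMMessage(role="user", content='Reply with just "OK".')],
            use_fallback=True,
            max_tokens=10,
        )
        print(f"{response.provider}/{response.model}: {response.content}")

    asyncio.run(_demo())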