Coverage for session_buddy/tools/llm_tools.py: 19.15%

148 statements  

coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

#!/usr/bin/env python3
"""LLM provider management MCP tools.

This module provides tools for managing and interacting with LLM providers
following crackerjack architecture patterns.

Refactored to use utility modules for reduced code duplication.
"""

from __future__ import annotations

import typing as t
from typing import TYPE_CHECKING, Any

from session_buddy.utils.error_handlers import _get_logger
from session_buddy.utils.instance_managers import (
    get_llm_manager as resolve_llm_manager,
)
from session_buddy.utils.messages import ToolMessages

if TYPE_CHECKING:
    from fastmcp import FastMCP


# Lazy loading flag for optional LLM dependencies
_llm_available: bool | None = None

LLM_NOT_AVAILABLE_MSG = (
    "LLM providers not available. "
    "Install dependencies: pip install openai google-generativeai aiohttp"
)


# ============================================================================
# Service Resolution and Availability Checks
# ============================================================================


def _check_llm_available() -> bool:
    """Check if LLM providers are available."""
    global _llm_available

    if _llm_available is None:
        try:
            import importlib.util

            spec = importlib.util.find_spec("session_buddy.llm_providers")
            _llm_available = spec is not None
        except ImportError:
            _llm_available = False

    return _llm_available
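
The probe above relies on `importlib.util.find_spec` returning `None` for a missing module rather than raising; the `ImportError` guard covers the case where a parent package exists but fails to import. A minimal sketch of that behavior (the missing module name is illustrative):

```python
import importlib.util

# A present module yields a ModuleSpec; an absent top-level name yields None.
assert importlib.util.find_spec("json") is not None
assert importlib.util.find_spec("definitely_not_installed_pkg") is None
```
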

async def _get_llm_manager() -> Any:
    """Get LLM manager instance with lazy loading."""
    global _llm_available

    if _llm_available is False:
        return None

    manager = await resolve_llm_manager()
    if manager is None:
        _get_logger().warning("LLM providers not available.")
        _llm_available = False
        return None

    _llm_available = True
    return manager
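
A minimal sketch of driving the lazy resolver from an async entry point; the function name here is illustrative, and `get_available_providers()` is the manager method used by the list operation further down:

```python
import asyncio

async def _show_llm_status() -> None:
    # Returns None (and caches the failure) when providers are unavailable.
    manager = await _get_llm_manager()
    if manager is None:
        print(LLM_NOT_AVAILABLE_MSG)
    else:
        print(await manager.get_available_providers())

# asyncio.run(_show_llm_status())
```
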

async def _require_llm_manager() -> Any:
    """Get LLM manager or raise with a helpful error message."""
    if not _check_llm_available():
        raise RuntimeError(LLM_NOT_AVAILABLE_MSG)

    manager = await _get_llm_manager()
    if not manager:
        msg = "Failed to initialize LLM manager"
        raise RuntimeError(msg)

    return manager


async def _execute_llm_operation(
    operation_name: str,
    operation: t.Callable[[Any], t.Awaitable[str]],
) -> str:
    """Execute an LLM operation with error handling."""
    try:
        manager = await _require_llm_manager()
        return await operation(manager)
    except RuntimeError as e:
        return str(e)
    except Exception as e:
        _get_logger().exception(f"Error in {operation_name}: {e}")
        return ToolMessages.operation_failed(operation_name, e)
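
Any coroutine that takes the manager and returns a string can be routed through this wrapper. A hypothetical sketch (`_count_providers` is not part of the module; `get_available_providers()` is borrowed from the list operation below):

```python
# Hypothetical operation for illustration only.
async def _count_providers(manager: Any) -> str:
    providers = await manager.get_available_providers()
    return f"🤖 {len(providers)} provider(s) available"

async def _demo_count() -> str:
    # RuntimeErrors and unexpected exceptions come back as formatted strings.
    return await _execute_llm_operation("Count providers", _count_providers)
```
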

# ============================================================================
# Output Formatting Helpers
# ============================================================================


def _add_provider_details(
    output: list[str],
    providers: dict[str, Any],
    available_providers: set[str],
) -> None:
    """Add provider details to the output list."""
    for provider_name, info in providers.items():
        status = "✅" if provider_name in available_providers else "❌"
        output.append(f"{status} {provider_name.title()}")

        if provider_name in available_providers:
            _add_model_list(output, info["models"])
        output.append("")


def _add_model_list(output: list[str], models: list[str]) -> None:
    """Add a model list to the output, truncated to the first five entries."""
    displayed_models = models[:5]  # Show first 5 models
    for model in displayed_models:
        # Bullet each model to match the overflow line below
        output.append(f"  • {model}")

    if len(models) > 5:
        output.append(f"  • ... and {len(models) - 5} more")
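
A quick check of the truncation behavior, assuming the module namespace is available:

```python
lines: list[str] = []
_add_model_list(lines, [f"model-{i}" for i in range(7)])
print("\n".join(lines))
# Five bulleted model names, then "  • ... and 2 more"
```
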

def _add_config_summary(output: list[str], config: dict[str, Any]) -> None:
    """Add configuration summary to the output."""
    output.extend(
        [
            f"🎯 Default Provider: {config['default_provider']}",
            f"🔄 Fallback Providers: {', '.join(config['fallback_providers'])}",
        ],
    )


def _format_provider_list(provider_data: dict[str, Any]) -> str:
    """Format provider information into a readable list."""
    available_providers = provider_data["available_providers"]
    provider_info = provider_data["provider_info"]

    output = ["🤖 Available LLM Providers", ""]
    _add_provider_details(output, provider_info["providers"], available_providers)
    _add_config_summary(output, provider_info["config"])

    return "\n".join(output)
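
The dict shape these formatters expect is implied by the lookups above; a self-contained sketch with sample data (provider and model names are illustrative):

```python
sample = {
    "available_providers": {"openai"},
    "provider_info": {
        "providers": {
            "openai": {"models": ["gpt-4o", "gpt-4o-mini"]},
            "ollama": {"models": []},
        },
        "config": {
            "default_provider": "openai",
            "fallback_providers": ["ollama"],
        },
    },
}
print(_format_provider_list(sample))  # ✅/❌ per provider, then the config summary
```
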

def _format_generation_result(result: dict[str, Any]) -> str:
    """Format LLM generation result."""
    output = ["✨ LLM Generation Result", ""]
    output.extend(
        (
            f"🤖 Provider: {result['metadata']['provider']}",
            f"🎯 Model: {result['metadata']['model']}",
            f"⚡ Response time: {result['metadata']['response_time_ms']:.0f}ms",
            f"📊 Tokens: {result['metadata'].get('tokens_used', 'N/A')}",
            "",
            "💬 Generated text:",
            "─" * 40,
            result["text"],
        )
    )

    return "\n".join(output)


def _format_chat_result(result: dict[str, Any], message_count: int) -> str:
    """Format LLM chat result."""
    output = ["💬 LLM Chat Result", ""]
    output.extend(
        (
            f"🤖 Provider: {result['metadata']['provider']}",
            f"🎯 Model: {result['metadata']['model']}",
            f"⚡ Response time: {result['metadata']['response_time_ms']:.0f}ms",
            f"📊 Messages: {message_count} → 1",
            "",
            "🎭 Assistant response:",
            "─" * 40,
            result["response"],
        )
    )

    return "\n".join(output)


def _format_provider_config_output(
    provider: str,
    api_key: str | None = None,
    base_url: str | None = None,
    default_model: str | None = None,
) -> str:
    """Format the provider configuration output."""
    output = ["⚙️ Provider Configuration Updated", ""]
    output.append(f"🤖 Provider: {provider}")

    if api_key:
        # Don't show the full API key for security
        masked_key = api_key[:8] + "..." + api_key[-4:] if len(api_key) > 12 else "***"
        output.append(f"🔑 API Key: {masked_key}")

    if base_url:
        output.append(f"🌐 Base URL: {base_url}")

    if default_model:
        output.append(f"🎯 Default Model: {default_model}")

    output.extend(
        (
            "",
            "✅ Configuration saved successfully!",
            "💡 Use `test_llm_providers` to verify the configuration",
        )
    )

    return "\n".join(output)
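
The masking logic keeps the first eight and last four characters of long keys; keys of twelve characters or fewer collapse to `***`. For example (both key values are fake):

```python
print(_format_provider_config_output("openai", api_key="sk-test-1234567890abcdef"))
# 🔑 API Key line renders as "sk-test-...cdef"
print(_format_provider_config_output("ollama", api_key="short-key"))
# 🔑 API Key line renders as "***"
```
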

# ============================================================================
# LLM Operation Implementations
# ============================================================================


async def _list_llm_providers_operation(manager: Any) -> str:
    """List all available LLM providers and their models."""
    provider_data = {
        "available_providers": await manager.get_available_providers(),
        "provider_info": manager.get_provider_info(),
    }
    return _format_provider_list(provider_data)


async def _list_llm_providers_impl() -> str:
    """List all available LLM providers and their models."""
    return await _execute_llm_operation(
        "List LLM providers",
        _list_llm_providers_operation,
    )


async def _test_llm_providers_operation(manager: Any) -> str:
    """Test all LLM providers to check their availability and functionality."""
    test_results = await manager.test_all_providers()

    output = ["🧪 LLM Provider Test Results", ""]

    for provider, result in test_results.items():
        status = "✅" if result["success"] else "❌"
        output.append(f"{status} {provider.title()}")

        if result["success"]:
            output.extend(
                (
                    f"  ⚡ Response time: {result['response_time_ms']:.0f}ms",
                    f"  🎯 Model: {result['model']}",
                )
            )
        else:
            output.append(f"  ❌ Error: {result['error']}")
        output.append("")

    working_count = sum(1 for r in test_results.values() if r["success"])
    total_count = len(test_results)
    output.append(f"📊 Summary: {working_count}/{total_count} providers working")

    return "\n".join(output)
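
The per-provider result shape consumed above can be exercised without real providers; `FakeManager` is hypothetical and the sample values are made up:

```python
import asyncio

class FakeManager:
    async def test_all_providers(self) -> dict[str, dict]:
        return {
            "openai": {"success": True, "response_time_ms": 412.0, "model": "gpt-4o-mini"},
            "ollama": {"success": False, "error": "connection refused"},
        }

print(asyncio.run(_test_llm_providers_operation(FakeManager())))
# Ends with: 📊 Summary: 1/2 providers working
```
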

async def _test_llm_providers_impl() -> str:
    """Test all LLM providers to check their availability and functionality."""
    return await _execute_llm_operation(
        "Test LLM providers",
        _test_llm_providers_operation,
    )


async def _generate_with_llm_impl(
    prompt: str,
    provider: str | None = None,
    model: str | None = None,
    temperature: float = 0.7,
    max_tokens: int | None = None,
    use_fallback: bool = True,
) -> str:
    """Generate text using the specified LLM provider."""

    async def operation(manager: Any) -> str:
        result = await manager.generate_text(
            prompt=prompt,
            provider=provider,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            use_fallback=use_fallback,
        )

        if result["success"]:
            return _format_generation_result(result)
        return f"❌ Generation failed: {result['error']}"

    return await _execute_llm_operation("Generate with LLM", operation)
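
From an async context, a call might look like this (provider choice and parameter values are illustrative):

```python
async def _demo_generate() -> str:
    return await _generate_with_llm_impl(
        "Summarize this session in one sentence.",
        provider="ollama",  # optional; other providers are tried if use_fallback=True
        temperature=0.2,
        max_tokens=128,
    )
```
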

async def _chat_with_llm_impl(
    messages: list[dict[str, str]],
    provider: str | None = None,
    model: str | None = None,
    temperature: float = 0.7,
    max_tokens: int | None = None,
) -> str:
    """Have a conversation with an LLM provider."""

    async def operation(manager: Any) -> str:
        result = await manager.chat(
            messages=messages,
            provider=provider,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        if result["success"]:
            return _format_chat_result(result, len(messages))
        return f"❌ Chat failed: {result['error']}"

    return await _execute_llm_operation("Chat with LLM", operation)
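
The `messages` payload follows the role/content convention documented in the MCP tool below; a sketch with illustrative content:

```python
async def _demo_chat() -> str:
    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Which providers are configured?"},
    ]
    return await _chat_with_llm_impl(messages, temperature=0.3)
```
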

async def _configure_llm_provider_impl(
    provider: str,
    api_key: str | None = None,
    base_url: str | None = None,
    default_model: str | None = None,
) -> str:
    """Configure an LLM provider with API credentials and settings."""

    async def operation(manager: Any) -> str:
        config_data = {}
        if api_key:
            config_data["api_key"] = api_key
        if base_url:
            config_data["base_url"] = base_url
        if default_model:
            config_data["default_model"] = default_model

        result = await manager.configure_provider(provider, config_data)

        if result["success"]:
            return _format_provider_config_output(
                provider,
                api_key,
                base_url,
                default_model,
            )
        return f"❌ Configuration failed: {result['error']}"

    return await _execute_llm_operation("Configure LLM provider", operation)


# ============================================================================
# MCP Tool Registration
# ============================================================================


def register_llm_tools(mcp: FastMCP) -> None:
    """Register all LLM provider management MCP tools.

    Args:
        mcp: FastMCP server instance

    """

    @mcp.tool()
    async def list_llm_providers() -> str:
        """List all available LLM providers and their models."""
        return await _list_llm_providers_impl()

    @mcp.tool()
    async def test_llm_providers() -> str:
        """Test all LLM providers to check their availability and functionality."""
        return await _test_llm_providers_impl()

    @mcp.tool()
    async def generate_with_llm(
        prompt: str,
        provider: str | None = None,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        use_fallback: bool = True,
    ) -> str:
        """Generate text using the specified LLM provider.

        Args:
            prompt: The text prompt to generate from
            provider: LLM provider to use (openai, gemini, ollama)
            model: Specific model to use
            temperature: Generation temperature (0.0-1.0)
            max_tokens: Maximum tokens to generate
            use_fallback: Whether to use fallback providers if the primary fails

        """
        return await _generate_with_llm_impl(
            prompt,
            provider,
            model,
            temperature,
            max_tokens,
            use_fallback,
        )

    @mcp.tool()
    async def chat_with_llm(
        messages: list[dict[str, str]],
        provider: str | None = None,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> str:
        """Have a conversation with an LLM provider.

        Args:
            messages: List of messages in the format
                [{"role": "user/assistant/system", "content": "text"}]
            provider: LLM provider to use (openai, gemini, ollama)
            model: Specific model to use
            temperature: Generation temperature (0.0-1.0)
            max_tokens: Maximum tokens to generate

        """
        return await _chat_with_llm_impl(
            messages,
            provider,
            model,
            temperature,
            max_tokens,
        )

    @mcp.tool()
    async def configure_llm_provider(
        provider: str,
        api_key: str | None = None,
        base_url: str | None = None,
        default_model: str | None = None,
    ) -> str:
        """Configure an LLM provider with API credentials and settings.

        Args:
            provider: Provider name (openai, gemini, ollama)
            api_key: API key for the provider
            base_url: Base URL for the provider API
            default_model: Default model to use

        """
        return await _configure_llm_provider_impl(
            provider,
            api_key,
            base_url,
            default_model,
        )