Coverage for mcpgateway/handlers/sampling.py: 95%
87 statements
coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
# -*- coding: utf-8 -*-
"""MCP Sampling Handler Implementation.

Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

This module implements the sampling handler for MCP LLM interactions.
It handles model selection, sampling preferences, and message generation.
"""

# Standard
import logging
from typing import Any, Dict, List

# Third-Party
from sqlalchemy.orm import Session

# First-Party
from mcpgateway.models import CreateMessageResult, ModelPreferences, Role, TextContent

logger = logging.getLogger(__name__)

class SamplingError(Exception):
    """Base class for sampling errors."""


class SamplingHandler:
    """MCP sampling request handler.

    Handles:
    - Model selection based on preferences
    - Message sampling requests
    - Context management
    - Content validation
    """

    def __init__(self):
        """Initialize sampling handler."""
        self._supported_models = {
            # Maps model names to capability scores in [0, 1]:
            # (cost-efficiency, speed, intelligence); higher is better.
            "claude-3-haiku": (0.8, 0.9, 0.7),
            "claude-3-sonnet": (0.5, 0.7, 0.9),
            "claude-3-opus": (0.2, 0.5, 1.0),
            "gemini-1.5-pro": (0.6, 0.8, 0.8),
        }
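        # Example: "claude-3-sonnet" -> (0.5, 0.7, 0.9) reads as moderately
        # cost-efficient, fairly fast, and highly capable, while
        # "claude-3-opus" trades cost and speed for the top intelligence score.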

    async def initialize(self) -> None:
        """Initialize sampling handler."""
        logger.info("Initializing sampling handler")

    async def shutdown(self) -> None:
        """Shutdown sampling handler."""
        logger.info("Shutting down sampling handler")

    async def create_message(self, db: Session, request: Dict[str, Any]) -> CreateMessageResult:
        """Create message from sampling request.

        Args:
            db: Database session
            request: Sampling request parameters

        Returns:
            Sampled message result

        Raises:
            SamplingError: If sampling fails
        """
        try:
            # Extract request parameters
            messages = request.get("messages", [])
            max_tokens = request.get("maxTokens")
            model_prefs = ModelPreferences.model_validate(request.get("modelPreferences", {}))
            include_context = request.get("includeContext", "none")
            request.get("metadata", {})  # Metadata is accepted but currently unused

            # Validate request
            if not messages:
                raise SamplingError("No messages provided")
            if not max_tokens:
                raise SamplingError("Max tokens not specified")

            # Select model
            model = self._select_model(model_prefs)
            logger.info(f"Selected model: {model}")

            # Include context if requested
            if include_context != "none":  # coverage: branch never taken in tests
                messages = await self._add_context(db, messages, include_context)

            # Validate messages
            for msg in messages:
                if not self._validate_message(msg):
                    raise SamplingError(f"Invalid message format: {msg}")

            # pylint: disable=fixme
            # TODO: Sample from selected model
            # For now return mock response
            response = self._mock_sample(messages=messages)

            # Convert to result
            return CreateMessageResult(
                content=TextContent(type="text", text=response),
                model=model,
                role=Role.ASSISTANT,
                stop_reason="maxTokens",
            )

        except Exception as e:
            logger.error(f"Sampling error: {e}")
            raise SamplingError(str(e)) from e

    def _select_model(self, preferences: ModelPreferences) -> str:
        """Select model based on preferences.

        Args:
            preferences: Model selection preferences

        Returns:
            Selected model name

        Raises:
            SamplingError: If no suitable model found
        """
        # Check model hints first
        if preferences.hints:
            # coverage: in tests a hint always matched, so neither loop ran to completion
            for hint in preferences.hints:
                for model in self._supported_models:
                    if hint.name and hint.name in model:
                        return model

        # Score models on preferences (equal-weight average of the three axes)
        best_score = -1
        best_model = None

        for model, caps in self._supported_models.items():
            # Each axis is weighted by how much the caller cares about it.
            # Note: the original code used caps[0] * (1 - preferences.cost_priority),
            # which inverted the cost weighting relative to the other two axes.
            cost_score = caps[0] * preferences.cost_priority
            speed_score = caps[1] * preferences.speed_priority
            intel_score = caps[2] * preferences.intelligence_priority

            total_score = (cost_score + speed_score + intel_score) / 3

            if total_score > best_score:
                best_score = total_score
                best_model = model

        if not best_model:
            raise SamplingError("No suitable model found")

        return best_model
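
    # Worked example of the scoring above (illustrative, with all three
    # priorities at 0.5): haiku scores (0.8 + 0.9 + 0.7) * 0.5 / 3 = 0.40,
    # gemini ~0.37, sonnet 0.35, opus ~0.28, so "claude-3-haiku" is selected.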

    async def _add_context(self, _db: Session, messages: List[Dict[str, Any]], _context_type: str) -> List[Dict[str, Any]]:
        """Add context to messages.

        Args:
            _db: Database session
            messages: Message list
            _context_type: Context inclusion type ("thisServer" or "allServers")

        Returns:
            Messages with added context
        """
        # pylint: disable=fixme
        # TODO: Implement context gathering based on type
        # For now return original messages
        return messages

    def _validate_message(self, message: Dict[str, Any]) -> bool:
        """Validate message format.

        Args:
            message: Message to validate

        Returns:
            True if valid
        """
        try:
            # Must have a role of "user" or "assistant" and some content
            if "role" not in message or "content" not in message or message["role"] not in ("user", "assistant"):
                return False

            # Content must be valid text or image content
            content = message["content"]
            if content.get("type") == "text":
                if not isinstance(content.get("text"), str):
                    return False
            elif content.get("type") == "image":
                if not (content.get("data") and content.get("mime_type")):
                    return False
            else:
                return False

            return True

        except Exception:
            return False
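
    # Examples of what the validator accepts (illustrative):
    #   valid:   {"role": "user", "content": {"type": "text", "text": "hi"}}
    #   valid:   {"role": "assistant", "content": {"type": "image", "data": "...", "mime_type": "image/png"}}
    #   invalid: {"role": "system", "content": {"type": "text", "text": "hi"}}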

    def _mock_sample(
        self,
        messages: List[Dict[str, Any]],
    ) -> str:
        """Mock sampling response for testing.

        Args:
            messages: Input messages

        Returns:
            Sampled response text
        """
        # Extract last user message
        last_msg = None
        for msg in reversed(messages):
            if msg["role"] == "user":
                last_msg = msg
                break

        if not last_msg:
            return "I'm not sure what to respond to."

        # Get user text
        user_text = ""
        content = last_msg["content"]
        if content["type"] == "text":
            user_text = content["text"]
        elif content["type"] == "image":  # coverage: always true when reached in tests
            user_text = "I see the image you shared."

        # Generate simple response
        return f"You said: {user_text}\nHere is my response..."