Coverage for mcpgateway/handlers/sampling.py: 95%

87 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""MCP Sampling Handler Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements the sampling handler for MCP LLM interactions. 

9It handles model selection, sampling preferences, and message generation. 

10""" 

11 

12# Standard 

13import logging 

14from typing import Any, Dict, List 

15 

16# Third-Party 

17from sqlalchemy.orm import Session 

18 

19# First-Party 

20from mcpgateway.models import CreateMessageResult, ModelPreferences, Role, TextContent 

21 

22logger = logging.getLogger(__name__) 

23 

24 

class SamplingError(Exception):
    """Base exception for failures raised by the sampling handler."""

27 

28 

class SamplingHandler:
    """MCP sampling request handler.

    Handles:
    - Model selection based on preferences
    - Message sampling requests
    - Context management
    - Content validation
    """

    def __init__(self):
        """Initialize sampling handler with the static model capability table."""
        # Maps model names to capability scores (cost, speed, intelligence),
        # each in [0, 1].  Higher cost score appears to mean "cheaper"
        # (haiku 0.8 vs opus 0.2) — see NOTE(review) in _select_model.
        self._supported_models = {
            "claude-3-haiku": (0.8, 0.9, 0.7),
            "claude-3-sonnet": (0.5, 0.7, 0.9),
            "claude-3-opus": (0.2, 0.5, 1.0),
            "gemini-1.5-pro": (0.6, 0.8, 0.8),
        }

    async def initialize(self) -> None:
        """Initialize sampling handler."""
        logger.info("Initializing sampling handler")

    async def shutdown(self) -> None:
        """Shutdown sampling handler."""
        logger.info("Shutting down sampling handler")

    async def create_message(self, db: "Session", request: Dict[str, Any]) -> "CreateMessageResult":
        """Create message from sampling request.

        Args:
            db: Database session
            request: Sampling request parameters (expects "messages",
                "maxTokens", and optionally "modelPreferences",
                "includeContext", "metadata")

        Returns:
            Sampled message result

        Raises:
            SamplingError: If sampling fails (missing messages/maxTokens,
                invalid message format, or any unexpected error, which is
                wrapped and chained)
        """
        try:
            # Extract request parameters.
            messages = request.get("messages", [])
            max_tokens = request.get("maxTokens")
            model_prefs = ModelPreferences.model_validate(request.get("modelPreferences", {}))
            include_context = request.get("includeContext", "none")
            # NOTE(review): "metadata" is accepted by the protocol but is not
            # consumed anywhere yet; the original fetched and discarded it.
            # Wire it in when real model sampling is implemented.

            # Validate request.  A falsy maxTokens (missing or 0) is rejected.
            if not messages:
                raise SamplingError("No messages provided")
            if not max_tokens:
                raise SamplingError("Max tokens not specified")

            # Select model.
            model = self._select_model(model_prefs)
            logger.info("Selected model: %s", model)

            # Include context if requested.
            if include_context != "none":
                messages = await self._add_context(db, messages, include_context)

            # Validate messages.
            for msg in messages:
                if not self._validate_message(msg):
                    raise SamplingError(f"Invalid message format: {msg}")

            # pylint: disable=fixme
            # TODO: Sample from selected model; mock response for now.
            response = self._mock_sample(messages=messages)

            # Convert to result.
            return CreateMessageResult(
                content=TextContent(type="text", text=response),
                model=model,
                role=Role.ASSISTANT,
                stop_reason="maxTokens",
            )

        except SamplingError as e:
            # Already a well-formed sampling error: log and re-raise as-is
            # (the original re-wrapped it through str(), losing the type
            # identity and traceback context).
            logger.error("Sampling error: %s", e)
            raise
        except Exception as e:
            logger.error("Sampling error: %s", e)
            # Chain the cause so the original traceback is preserved.
            raise SamplingError(str(e)) from e

    def _select_model(self, preferences: "ModelPreferences") -> str:
        """Select model based on preferences.

        Args:
            preferences: Model selection preferences

        Returns:
            Selected model name

        Raises:
            SamplingError: If no suitable model found
        """
        # Honor explicit model hints first: a hint matches any supported
        # model whose name contains the hint as a substring.
        if preferences.hints:
            for hint in preferences.hints:
                for model in self._supported_models:
                    if hint.name and hint.name in model:
                        return model

        # Otherwise score every model against the stated priorities.
        best_score = -1.0
        best_model = None

        for model, caps in self._supported_models.items():
            # NOTE(review): cost multiplies by (1 - cost_priority) while
            # speed and intelligence multiply by their priority directly —
            # a high cost_priority *suppresses* the cheapness advantage,
            # which looks inverted.  Preserved as-is; confirm intended
            # semantics before changing.
            cost_score = caps[0] * (1 - preferences.cost_priority)
            speed_score = caps[1] * preferences.speed_priority
            intel_score = caps[2] * preferences.intelligence_priority

            total_score = (cost_score + speed_score + intel_score) / 3

            if total_score > best_score:
                best_score = total_score
                best_model = model

        if not best_model:
            raise SamplingError("No suitable model found")

        return best_model

    async def _add_context(self, _db: "Session", messages: List[Dict[str, Any]], _context_type: str) -> List[Dict[str, Any]]:
        """Add context to messages.

        Args:
            _db: Database session (unused until context gathering is implemented)
            messages: Message list
            _context_type: Context inclusion type (unused placeholder)

        Returns:
            Messages with added context (currently returned unchanged)
        """
        # pylint: disable=fixme
        # TODO: Implement context gathering based on type.
        # For now return original messages.
        return messages

    def _validate_message(self, message: Dict[str, Any]) -> bool:
        """Validate message format.

        A valid message has role "user" or "assistant" and a content dict
        that is either {"type": "text", "text": str} or
        {"type": "image", "data": ..., "mime_type": ...} with both truthy.

        Args:
            message: Message to validate

        Returns:
            True if valid
        """
        try:
            # Must have role and content, and a recognized role.
            if "role" not in message or "content" not in message or message["role"] not in ("user", "assistant"):
                return False

            # Content must be valid for its declared type.
            content = message["content"]
            if content.get("type") == "text":
                if not isinstance(content.get("text"), str):
                    return False
            elif content.get("type") == "image":
                if not (content.get("data") and content.get("mime_type")):
                    return False
            else:
                # Unknown or missing content type.
                return False

            return True

        except Exception:
            # Malformed structures (e.g. content not a dict) are invalid,
            # not errors — deliberate best-effort validation.
            return False

    def _mock_sample(
        self,
        messages: List[Dict[str, Any]],
    ) -> str:
        """Mock sampling response for testing.

        Args:
            messages: Input messages

        Returns:
            Sampled response text echoing the most recent user message
        """
        # Extract last user message.
        last_msg = None
        for msg in reversed(messages):
            if msg["role"] == "user":
                last_msg = msg
                break

        if not last_msg:
            return "I'm not sure what to respond to."

        # Get user text.
        user_text = ""
        content = last_msg["content"]
        if content["type"] == "text":
            user_text = content["text"]
        elif content["type"] == "image":
            user_text = "I see the image you shared."

        # Generate simple response.
        return f"You said: {user_text}\nHere is my response..."