Coverage for agentos/mcp/sampling.py: 0%

203 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v1.14.0 — MCP Sampling 支持。 

3 

4MCP Sampling 允许 MCP Server 向 Client 发起 LLM 请求。 

5这是 MCP 协议中 server→client 方向的核心能力,对标 Claude Desktop 的采样功能。 

6 

7协议流程: 

81. Server 发送 `sampling/createMessage` 请求到 Client 

92. Client 调用 LLM 生成回复 

103. Client 返回结果给 Server 

11""" 

12 

13from __future__ import annotations 

14 

15import asyncio 

16import json 

17from dataclasses import dataclass, field 

18from enum import Enum 

19from typing import Any, Callable, Coroutine, Dict, List, Optional, Union 

20 

21 

22# ── Sampling Data Models ──────────────────── 

23 

24 

class SamplingRole(str, Enum):
    """Author role of a sampling message.

    Because of the ``str`` mixin, each member compares equal to its plain
    wire string (``"user"`` / ``"assistant"``).
    """

    USER = "user"  # the requesting side's turn
    ASSISTANT = "assistant"  # the model's turn

29 

30 

@dataclass
class SamplingContentBlock:
    """A single content item inside a sampling message.

    Two kinds are modelled: ``"text"`` (payload in ``text``) and ``"image"``
    (base64 payload in ``data`` plus ``mime_type``). Any other ``type`` value
    serializes to just its type tag.
    """

    type: str = "text"  # "text" or "image"
    text: str = ""
    data: str = ""  # base64-encoded image bytes
    mime_type: str = ""

    def to_dict(self) -> dict:
        """Serialize to the camelCase wire shape for this block's kind."""
        kind = self.type
        if kind == "image":
            return {"type": kind, "data": self.data, "mimeType": self.mime_type}
        if kind == "text":
            return {"type": kind, "text": self.text}
        return {"type": kind}

    @classmethod
    def from_dict(cls, d: dict) -> "SamplingContentBlock":
        """Build a block from a wire dict; absent keys fall back to defaults
        (type ``"text"``, everything else empty)."""
        get = d.get
        return cls(
            type=get("type", "text"),
            text=get("text", ""),
            data=get("data", ""),
            mime_type=get("mimeType", ""),
        )

    @classmethod
    def text_block(cls, text: str) -> "SamplingContentBlock":
        """Convenience constructor for a text block."""
        return cls(text=text, type="text")

    @classmethod
    def image_block(cls, base64_data: str, mime_type: str = "image/png") -> "SamplingContentBlock":
        """Convenience constructor for an image block from base64 data."""
        return cls(mime_type=mime_type, data=base64_data, type="image")

68 

69 

@dataclass
class SamplingMessage:
    """One conversational turn in a sampling request.

    ``content`` is either a plain string (shorthand for one text block) or an
    explicit list of content blocks.
    """

    role: SamplingRole
    content: Union[str, List[SamplingContentBlock]]

    def to_dict(self) -> dict:
        """Serialize to the wire shape.

        A string body becomes a single text-block dict; a list body becomes a
        list of block dicts.
        """
        body = self.content
        if not isinstance(body, str):
            serialized = [block.to_dict() for block in body]
        else:
            serialized = {"type": "text", "text": body}
        return {"role": self.role.value, "content": serialized}

    @classmethod
    def from_dict(cls, d: dict) -> "SamplingMessage":
        """Parse a wire dict.

        A string ``content`` is kept as-is. A single block dict or a list of
        block dicts is normalized to a list of blocks. Any other value is
        coerced with ``str()``.

        Raises:
            KeyError: if ``role`` or ``content`` is missing.
            ValueError: if ``role`` is not a known SamplingRole value.
        """
        role = SamplingRole(d["role"])
        raw = d["content"]
        if isinstance(raw, dict):
            parsed = [SamplingContentBlock.from_dict(raw)]
        elif isinstance(raw, list):
            parsed = [SamplingContentBlock.from_dict(item) for item in raw]
        elif isinstance(raw, str):
            parsed = raw
        else:
            parsed = str(raw)
        return cls(role=role, content=parsed)

100 

101 

@dataclass
class SamplingRequest:
    """Parameters of an MCP ``sampling/createMessage`` request.

    Attributes use snake_case. ``to_dict``/``from_dict`` translate to and
    from the camelCase keys used on the wire.
    """

    messages: List[SamplingMessage]
    model_preferences: Optional[Dict[str, Any]] = None
    system_prompt: str = ""
    include_context: str = "none"  # "none", "thisServer" or "allServers"
    temperature: float = 0.7
    max_tokens: int = 4096
    stop_sequences: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to wire params.

        ``messages`` and ``maxTokens`` are always present. Every other key is
        emitted only when it differs from this class's default (non-empty
        values, a context other than ``"none"``, a temperature other than
        ``0.7``). As a result, a temperature of exactly 0.7 is never sent.
        """
        out: dict = {
            "messages": [msg.to_dict() for msg in self.messages],
            "maxTokens": self.max_tokens,
        }
        # (wire key, value, include?) -- order here is the key order emitted
        candidates = (
            ("modelPreferences", self.model_preferences, bool(self.model_preferences)),
            ("systemPrompt", self.system_prompt, bool(self.system_prompt)),
            ("includeContext", self.include_context, self.include_context != "none"),
            ("temperature", self.temperature, self.temperature != 0.7),
            ("stopSequences", self.stop_sequences, bool(self.stop_sequences)),
            ("metadata", self.metadata, bool(self.metadata)),
        )
        out.update((key, value) for key, value, keep in candidates if keep)
        return out

    @classmethod
    def from_dict(cls, d: dict) -> "SamplingRequest":
        """Parse wire params, using this class's defaults for absent keys."""
        get = d.get
        return cls(
            messages=list(map(SamplingMessage.from_dict, get("messages", []))),
            model_preferences=get("modelPreferences"),
            system_prompt=get("systemPrompt", ""),
            include_context=get("includeContext", "none"),
            temperature=get("temperature", 0.7),
            max_tokens=get("maxTokens", 4096),
            stop_sequences=get("stopSequences", []),
            metadata=get("metadata", {}),
        )

149 

150 

@dataclass
class SamplingResponse:
    """Result of a sampling request.

    This is what the client sends back to the server after calling the LLM.
    """

    model: str = ""
    role: SamplingRole = SamplingRole.ASSISTANT
    content: Union[str, List[SamplingContentBlock]] = ""
    stop_reason: str = ""  # e.g. "endTurn", "stopSequence", "maxTokens"
    # NOTE: metadata is neither emitted by to_dict() nor populated by
    # from_dict(); it only lives on the Python object.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to the MCP ``CreateMessageResult`` shape.

        A string body becomes a single text-block dict. A list body becomes a
        list of block dicts. Any other value is stringified into a text block.
        An empty ``stop_reason`` is reported as ``"endTurn"``.
        """
        # A list body serializes to a list, so the local must admit both shapes
        # (previously annotated as plain ``dict``, which was wrong).
        content_block: Union[Dict[str, Any], List[Dict[str, Any]]]
        if isinstance(self.content, str):
            content_block = {"type": "text", "text": self.content}
        elif isinstance(self.content, list):
            content_block = [c.to_dict() for c in self.content]
        else:
            content_block = {"type": "text", "text": str(self.content)}

        return {
            "model": self.model,
            "role": self.role.value,
            "content": content_block,
            "stopReason": self.stop_reason or "endTurn",
        }

    @classmethod
    def from_dict(cls, d: dict) -> "SamplingResponse":
        """Parse a ``CreateMessageResult`` dict.

        A string ``content`` is kept as-is. A single block dict or a list of
        block dicts is normalized to a list of blocks. Any other value is
        stringified. A missing ``role`` defaults to ``"assistant"``.

        Raises:
            ValueError: if ``role`` is not a known SamplingRole value.
        """
        content_raw = d.get("content", "")
        content: Union[str, List[SamplingContentBlock]]
        if isinstance(content_raw, str):
            content = content_raw
        elif isinstance(content_raw, dict):
            content = [SamplingContentBlock.from_dict(content_raw)]
        elif isinstance(content_raw, list):
            content = [SamplingContentBlock.from_dict(c) for c in content_raw]
        else:
            content = str(content_raw)
        return cls(
            model=d.get("model", ""),
            role=SamplingRole(d.get("role", "assistant")),
            content=content,
            stop_reason=d.get("stopReason", ""),
        )

197 

198 

199# ── Sampling Handler ──────────────────────── 

200 

201 

class SamplingError(Exception):
    """Raised when a sampling request cannot be served.

    Attributes:
        code: Integer error code (JSON-RPC style, e.g. -32602).
        message: Human-readable description, without the code prefix.
    """

    def __init__(self, code: int, message: str):
        self.code, self.message = code, message
        super().__init__(f"Sampling Error [{code}]: {message}")

208 

209 

# Signature of the user-supplied LLM backend: an async callable that takes a
# parsed SamplingRequest and resolves to a SamplingResponse.
LLMCallFn = Callable[[SamplingRequest], Coroutine[Any, Any, SamplingResponse]]

215 

216 

class MCPClientSampling:
    """Client-side handler for MCP ``sampling/createMessage`` requests.

    Attaching this handler to an MCP client lets a connected server ask the
    client to run an LLM completion. Incoming params are:

    1. parsed into a SamplingRequest,
    2. checked against client policy (token cap, image input),
    3. forwarded to the user-supplied ``llm_call_fn``.

    The resulting SamplingResponse is turned back into the JSON-RPC
    ``result`` payload.

    Usage:
        client = MCPClient()
        sampling = MCPClientSampling(my_llm_call_fn)
        client.set_sampling_handler(sampling)
    """

    def __init__(
        self,
        llm_call_fn: LLMCallFn,
        default_model: str = "claude-sonnet-4-20250514",
        max_tokens_limit: int = 8192,
        allow_image_input: bool = True,
    ):
        """
        Args:
            llm_call_fn: Async callable that performs the actual LLM call.
            default_model: Model name reported when the backend's response
                leaves ``model`` empty.
            max_tokens_limit: Upper bound on ``maxTokens``; larger requests
                are clamped down to it.
            allow_image_input: When False, any request containing an image
                block is rejected.
        """
        self._llm_call = llm_call_fn
        self.default_model = default_model
        self.max_tokens_limit = max_tokens_limit
        self.allow_image_input = allow_image_input

    @staticmethod
    def _contains_image(request: SamplingRequest) -> bool:
        """Return True if any message of ``request`` carries an image block."""
        return any(
            isinstance(msg.content, list)
            and any(block.type == "image" for block in msg.content)
            for msg in request.messages
        )

    async def handle_create_message(
        self,
        params: dict[str, Any],
    ) -> dict[str, Any]:
        """Handle one ``sampling/createMessage`` request.

        Args:
            params: The ``params`` object of the JSON-RPC request. It comes
                from the server and is treated as untrusted.

        Returns:
            The JSON-RPC ``result`` payload with keys ``model``, ``role``,
            ``content`` and ``stopReason``.

        Raises:
            SamplingError:
                - ``-32602`` if params cannot be parsed or ``maxTokens`` is
                  not a positive integer.
                - ``-32000`` if images are sent while image input is disabled.
                - ``-32001`` if the LLM callback fails. A SamplingError raised
                  by the callback itself is propagated unchanged.
        """
        try:
            request = SamplingRequest.from_dict(params)
        except Exception as e:
            raise SamplingError(-32602, f"Invalid sampling request: {e}") from e

        # maxTokens comes straight from the server. A non-int (e.g. a string)
        # would otherwise escape below as a raw TypeError instead of a
        # protocol error. bool is an int subclass, so exclude it explicitly.
        max_tokens = request.max_tokens
        if (
            isinstance(max_tokens, bool)
            or not isinstance(max_tokens, int)
            or max_tokens <= 0
        ):
            raise SamplingError(
                -32602,
                "Invalid sampling request: maxTokens must be a positive "
                f"integer, got {max_tokens!r}",
            )
        # Over-large requests are clamped to policy rather than rejected.
        request.max_tokens = min(max_tokens, self.max_tokens_limit)

        if not self.allow_image_input and self._contains_image(request):
            raise SamplingError(
                -32000,
                "Image input is not allowed by client policy",
            )

        try:
            response = await self._llm_call(request)
        except SamplingError:
            # Already carries a meaningful protocol error code; don't mask it.
            raise
        except Exception as e:
            raise SamplingError(-32001, f"LLM call failed: {e}") from e

        return {
            "model": response.model or self.default_model,
            "role": response.role.value,
            "content": response.to_dict()["content"],
            "stopReason": response.stop_reason or "endTurn",
        }

286 

287 

288# ── Mock LLM for Testing ─────────────────── 

289 

290 

async def mock_llm_call(request: SamplingRequest) -> SamplingResponse:
    """Test double for an LLM backend.

    Echoes up to 100 characters of the text in the last message of the
    request. For a block list, the text blocks are joined with spaces and
    image blocks are ignored. The response always reports model
    ``"mock-model"`` and stop reason ``"endTurn"``.
    """
    last = request.messages[-1] if request.messages else None
    body = last.content if last else None

    if isinstance(body, str):
        reply = f"[Mock] Echo: {body[:100]}"
    elif isinstance(body, list):
        joined = " ".join(block.text for block in body if block.type == "text")
        reply = f"[Mock] Echo: {joined[:100]}"
    else:
        reply = "[Mock] No input messages."

    return SamplingResponse(model="mock-model", content=reply, stop_reason="endTurn")

310 

311 

312# ── Resource Templates ────────────────────── 

313 

314 

@dataclass
class MCPResourceTemplate:
    """An MCP resource template: a parameterized URI advertised by a server.

    Clients fill in the template's parameters to address concrete resources.
    """

    uri_template: str
    name: str = ""
    description: str = ""
    mime_type: str = ""
    annotations: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to wire form.

        ``uriTemplate`` and ``name`` are always present; ``description``,
        ``mimeType`` and ``annotations`` appear only when non-empty.
        """
        payload: dict = {"uriTemplate": self.uri_template, "name": self.name}
        optional = {
            "description": self.description,
            "mimeType": self.mime_type,
            "annotations": self.annotations,
        }
        payload.update({key: value for key, value in optional.items() if value})
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "MCPResourceTemplate":
        """Build from a wire dict; missing keys become empty values."""
        get = d.get
        return cls(
            uri_template=get("uriTemplate", ""),
            name=get("name", ""),
            description=get("description", ""),
            mime_type=get("mimeType", ""),
            annotations=get("annotations", {}),
        )

350 

351 

352# ── MCP Logging ──────────────────────────── 

353 

354 

class MCPLogLevel(str, Enum):
    """Severity levels for MCP log messages.

    Members are declared from least to most severe. The declaration order
    itself is the severity ranking, so do not reorder them.
    """

    DEBUG = "debug"
    INFO = "info"
    NOTICE = "notice"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
    ALERT = "alert"
    EMERGENCY = "emergency"

365 

366 

class MCPLoggingHandler:
    """Client-side receiver and filter for log messages pushed by an MCP server.

    The client picks a severity threshold (``logging/setLevel``); the server
    pushes messages via ``notifications/message``. Messages at or above the
    threshold are forwarded to an optional callback.

    Note: despite its name, ``max_level`` is the *minimum* severity that gets
    delivered.
    """

    def __init__(self, max_level: MCPLogLevel = MCPLogLevel.INFO):
        self._log_callback: Optional[Callable[[MCPLogLevel, str, str], None]] = None
        self._max_level = max_level

    def set_log_callback(
        self,
        callback: Callable[[MCPLogLevel, str, str], None],
    ) -> None:
        """Register ``callback(level, logger_name, message)`` for delivered
        messages, replacing any previously registered callback."""
        self._log_callback = callback

    @property
    def max_level(self) -> MCPLogLevel:
        """The current severity threshold."""
        return self._max_level

    def set_level(self, level: MCPLogLevel) -> None:
        """Change the severity threshold (client-side ``logging/setLevel``)."""
        self._max_level = level

    def _level_rank(self, level: MCPLogLevel) -> int:
        """Index of ``level`` in declaration (severity) order; 0 is least severe.

        Lookup is by equality, so a plain string equal to a member's value is
        accepted too. Raises ValueError for an unknown level.
        """
        return list(MCPLogLevel).index(level)

    def should_log(self, level: MCPLogLevel) -> bool:
        """Return True when ``level`` is at or above the current threshold."""
        rank = self._level_rank(level)
        return rank >= self._level_rank(self._max_level)

    def handle_log_message(
        self,
        level: MCPLogLevel,
        logger_name: str,
        message: str,
    ) -> None:
        """Forward a server log message to the callback if it passes the threshold."""
        if not self.should_log(level):
            return
        callback = self._log_callback
        if callback:
            callback(level, logger_name, message)

414 

415 

416# ── MCP Roots ────────────────────────────── 

417 

418 

@dataclass
class MCPRoot:
    """A filesystem root that the client exposes to servers (an MCP "root")."""

    uri: str
    name: str = ""

    def to_dict(self) -> dict:
        """Serialize; ``name`` is included only when non-empty."""
        return {"uri": self.uri, **({"name": self.name} if self.name else {})}

    @classmethod
    def from_dict(cls, d: dict) -> "MCPRoot":
        """Build from a wire dict; missing fields become empty strings."""
        uri = d.get("uri", "")
        name = d.get("name", "")
        return cls(uri=uri, name=name)