Coverage for agentos/mcp/sampling.py: 0%
203 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2AgentOS v1.14.0 — MCP Sampling 支持。
4MCP Sampling 允许 MCP Server 向 Client 发起 LLM 请求。
5这是 MCP 协议中 server→client 方向的核心能力,对标 Claude Desktop 的采样功能。
7协议流程:
81. Server 发送 `sampling/createMessage` 请求到 Client
92. Client 调用 LLM 生成回复
103. Client 返回结果给 Server
11"""
13from __future__ import annotations
15import asyncio
16import json
17from dataclasses import dataclass, field
18from enum import Enum
19from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
22# ── Sampling Data Models ────────────────────
class SamplingRole(str, Enum):
    """Role attached to a message exchanged during MCP sampling.

    Mixes in ``str`` so each member compares equal to its raw wire value.
    """

    # Declaration order is preserved: it defines the Enum's iteration order.
    USER = "user"
    ASSISTANT = "assistant"
@dataclass
class SamplingContentBlock:
    """A single content block inside a sampling message.

    ``type`` decides which fields are serialized by ``to_dict``:

    * ``"text"``  -> ``text``
    * ``"image"`` -> ``data`` (base64) and ``mimeType``
    * ``"audio"`` -> ``data`` (base64) and ``mimeType``

    Any other ``type`` is serialized as ``{"type": ...}`` only.
    """

    type: str = "text"  # "text" | "image" | "audio"
    text: str = ""
    data: str = ""  # base64-encoded payload for image/audio blocks
    mime_type: str = ""

    def to_dict(self) -> dict:
        """Serialize to the MCP wire format (camelCase ``mimeType``)."""
        d: dict = {"type": self.type}
        if self.type == "text":
            d["text"] = self.text
        elif self.type in ("image", "audio"):
            # Audio shares the image layout (base64 ``data`` + ``mimeType``);
            # both must carry their payload so from_dict/to_dict round-trips.
            d["data"] = self.data
            d["mimeType"] = self.mime_type
        return d

    @classmethod
    def from_dict(cls, d: dict) -> "SamplingContentBlock":
        """Build a block from its wire dict; missing keys use the field defaults."""
        return cls(
            type=d.get("type", "text"),
            text=d.get("text", ""),
            data=d.get("data", ""),
            mime_type=d.get("mimeType", ""),
        )

    @classmethod
    def text_block(cls, text: str) -> "SamplingContentBlock":
        """Create a text block."""
        return cls(type="text", text=text)

    @classmethod
    def image_block(cls, base64_data: str, mime_type: str = "image/png") -> "SamplingContentBlock":
        """Create an image block from base64 data (PNG unless told otherwise)."""
        return cls(type="image", data=base64_data, mime_type=mime_type)

    @classmethod
    def audio_block(cls, base64_data: str, mime_type: str) -> "SamplingContentBlock":
        """Create an audio block from base64 data and its MIME type."""
        return cls(type="audio", data=base64_data, mime_type=mime_type)
@dataclass
class SamplingMessage:
    """One message of a sampling conversation.

    ``content`` is either a plain string (sent as a single text block) or a
    list of ``SamplingContentBlock`` objects.
    """

    role: SamplingRole
    content: Union[str, List[SamplingContentBlock]]

    def to_dict(self) -> dict:
        """Serialize to the wire format.

        String content becomes one text-block dict; otherwise each block is
        serialized into a list.
        """
        body = self.content
        if isinstance(body, str):
            wire_content = {"type": "text", "text": body}
        else:
            wire_content = [block.to_dict() for block in body]
        return {"role": self.role.value, "content": wire_content}

    @classmethod
    def from_dict(cls, d: dict) -> "SamplingMessage":
        """Parse a wire dict.

        ``content`` may be a string (kept verbatim), a single block dict, or a
        list of block dicts; any other value is coerced with ``str()``.

        Raises:
            KeyError: if ``role`` or ``content`` is missing.
            ValueError: if ``role`` is not a valid ``SamplingRole`` value.
        """
        role = SamplingRole(d["role"])
        raw = d["content"]
        if isinstance(raw, dict):
            # A lone block is normalised into a one-element list.
            raw = [raw]
        if isinstance(raw, list):
            blocks = [SamplingContentBlock.from_dict(item) for item in raw]
            return cls(role=role, content=blocks)
        return cls(role=role, content=raw if isinstance(raw, str) else str(raw))
@dataclass
class SamplingRequest:
    """Sampling request issued by an MCP server.

    Mirrors the parameters of the MCP ``sampling/createMessage`` method.
    """

    messages: List[SamplingMessage]
    model_preferences: Optional[Dict[str, Any]] = None
    system_prompt: str = ""
    include_context: str = "none"  # "none" | "thisServer" | "allServers"
    temperature: float = 0.7
    max_tokens: int = 4096
    stop_sequences: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to the camelCase wire format.

        ``messages`` and ``maxTokens`` are always emitted. The optional keys
        are omitted when they hold their "unset" value: falsy
        ``model_preferences`` / ``system_prompt`` / ``stop_sequences`` /
        ``metadata``, ``include_context == "none"``, ``temperature == 0.7``.
        """
        wire: dict = {
            "messages": [msg.to_dict() for msg in self.messages],
            "maxTokens": self.max_tokens,
        }
        # (wire key, value, include?) in the order keys are emitted.
        optional = (
            ("modelPreferences", self.model_preferences, bool(self.model_preferences)),
            ("systemPrompt", self.system_prompt, bool(self.system_prompt)),
            ("includeContext", self.include_context, self.include_context != "none"),
            ("temperature", self.temperature, self.temperature != 0.7),
            ("stopSequences", self.stop_sequences, bool(self.stop_sequences)),
            ("metadata", self.metadata, bool(self.metadata)),
        )
        for key, value, present in optional:
            if present:
                wire[key] = value
        return wire

    @classmethod
    def from_dict(cls, d: dict) -> "SamplingRequest":
        """Parse the ``sampling/createMessage`` params dict.

        Missing keys fall back to the dataclass defaults (e.g. ``maxTokens``
        -> 4096, ``temperature`` -> 0.7). Values are not type-checked.
        """
        get = d.get
        parsed_messages = [SamplingMessage.from_dict(raw) for raw in get("messages", [])]
        return cls(
            messages=parsed_messages,
            model_preferences=get("modelPreferences"),
            system_prompt=get("systemPrompt", ""),
            include_context=get("includeContext", "none"),
            temperature=get("temperature", 0.7),
            max_tokens=get("maxTokens", 4096),
            stop_sequences=get("stopSequences", []),
            metadata=get("metadata", {}),
        )
@dataclass
class SamplingResponse:
    """Result the client returns to the server after calling the LLM.

    ``content`` may be a plain string or a list of ``SamplingContentBlock``.
    """

    model: str = ""
    role: SamplingRole = SamplingRole.ASSISTANT
    content: Union[str, List[SamplingContentBlock]] = ""
    stop_reason: str = ""  # e.g. "endTurn" | "stopSequence" | "maxTokens"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to the wire format.

        A list of blocks is serialized block by block; a string (or any other
        value, via ``str()``) becomes a single text-block dict. An empty
        ``stop_reason`` is reported as ``"endTurn"``. ``metadata`` is not
        included.
        """
        body = self.content
        wire_content: Union[dict, list]
        if isinstance(body, list):
            wire_content = [block.to_dict() for block in body]
        else:
            text = body if isinstance(body, str) else str(body)
            wire_content = {"type": "text", "text": text}
        return {
            "model": self.model,
            "role": self.role.value,
            "content": wire_content,
            "stopReason": self.stop_reason or "endTurn",
        }

    @classmethod
    def from_dict(cls, d: dict) -> "SamplingResponse":
        """Parse a wire dict.

        ``content`` may be a string, a single block dict, or a list of block
        dicts; other values are coerced with ``str()``. ``metadata`` is not
        read. Missing ``role`` defaults to ``"assistant"``.
        """
        raw = d.get("content", "")
        if isinstance(raw, dict):
            # A lone block is normalised into a one-element list.
            raw = [raw]
        if isinstance(raw, list):
            content = [SamplingContentBlock.from_dict(item) for item in raw]
        elif isinstance(raw, str):
            content = raw
        else:
            content = str(raw)
        return cls(
            model=d.get("model", ""),
            role=SamplingRole(d.get("role", "assistant")),
            content=content,
            stop_reason=d.get("stopReason", ""),
        )
199# ── Sampling Handler ────────────────────────
class SamplingError(Exception):
    """Error raised while handling an MCP sampling request.

    Attributes:
        code: numeric error code.
        message: human-readable description (without the code prefix).
    """

    def __init__(self, code: int, message: str):
        super().__init__(f"Sampling Error [{code}]: {message}")
        self.code = code
        self.message = message
# Signature of the LLM backend: an async callable that takes a
# SamplingRequest and resolves to a SamplingResponse.
LLMCallFn = Callable[[SamplingRequest], Coroutine[Any, Any, SamplingResponse]]
class MCPClientSampling:
    """Client-side handler for MCP sampling.

    Once this handler is attached to an MCPClient, an MCP server can ask the
    client to run an LLM completion by sending ``sampling/createMessage``.

    Usage:
        client = MCPClient()
        sampling = MCPClientSampling(my_llm_call_fn)
        client.set_sampling_handler(sampling)
    """

    def __init__(
        self,
        llm_call_fn: LLMCallFn,
        default_model: str = "claude-sonnet-4-20250514",
        max_tokens_limit: int = 8192,
        allow_image_input: bool = True,
    ):
        """
        Args:
            llm_call_fn: async callable that performs the actual LLM call.
            default_model: model name reported when the LLM response leaves
                ``model`` empty.
            max_tokens_limit: upper bound; larger ``maxTokens`` values in a
                request are clamped down to it.
            allow_image_input: when False, requests that contain image content
                blocks are rejected.
        """
        self._llm_call = llm_call_fn
        self.default_model = default_model
        self.max_tokens_limit = max_tokens_limit
        self.allow_image_input = allow_image_input

    @staticmethod
    def _contains_image(request: SamplingRequest) -> bool:
        """Return True if any message's block list contains an image block."""
        return any(
            isinstance(msg.content, list)
            and any(block.type == "image" for block in msg.content)
            for msg in request.messages
        )

    async def handle_create_message(
        self,
        params: dict[str, Any],
    ) -> dict[str, Any]:
        """Handle a ``sampling/createMessage`` request.

        Args:
            params: the ``params`` object of the incoming request.

        Returns:
            The ``result`` part of the response: ``model``, ``role``,
            ``content`` and ``stopReason``.

        Raises:
            SamplingError: ``-32602`` if ``params`` cannot be parsed or
                ``maxTokens`` is not a positive number; ``-32000`` if the
                request contains an image while image input is disabled;
                ``-32001`` if the LLM call fails. A ``SamplingError`` raised by
                the LLM callable itself is propagated unchanged.
        """
        try:
            request = SamplingRequest.from_dict(params)
        except Exception as e:
            # Any parse failure (missing keys, bad role, wrong shapes) is
            # reported as invalid params; keep the original cause chained.
            raise SamplingError(-32602, f"Invalid sampling request: {e}") from e

        # from_dict does no type checking, so a non-numeric maxTokens would
        # otherwise escape as a bare TypeError from the clamp comparison.
        max_tokens = request.max_tokens
        if (
            isinstance(max_tokens, bool)
            or not isinstance(max_tokens, (int, float))
            or max_tokens <= 0
        ):
            raise SamplingError(
                -32602,
                "Invalid sampling request: maxTokens must be a positive number, "
                f"got {max_tokens!r}",
            )
        # Over-limit requests are clamped rather than rejected.
        if max_tokens > self.max_tokens_limit:
            request.max_tokens = self.max_tokens_limit

        if not self.allow_image_input and self._contains_image(request):
            raise SamplingError(
                -32000,
                "Image input is not allowed by client policy",
            )

        try:
            response = await self._llm_call(request)
        except SamplingError:
            # The backend already chose a specific error code; keep it.
            raise
        except Exception as e:
            raise SamplingError(-32001, f"LLM call failed: {e}") from e

        return {
            "model": response.model or self.default_model,
            "role": response.role.value,
            "content": response.to_dict()["content"],
            "stopReason": response.stop_reason or "endTurn",
        }
288# ── Mock LLM for Testing ───────────────────
async def mock_llm_call(request: SamplingRequest) -> SamplingResponse:
    """Stand-in LLM backend for tests.

    Echoes the content of the last message in ``request.messages`` (whatever
    its role), truncated to 100 characters. For block-list content only the
    text blocks are echoed, joined by single spaces. With no messages (or
    content that is neither a string nor a list) a fixed placeholder is
    returned.
    """
    final = request.messages[-1] if request.messages else None
    body = final.content if final else None

    if isinstance(body, str):
        reply = f"[Mock] Echo: {body[:100]}"
    elif isinstance(body, list):
        joined = " ".join(block.text for block in body if block.type == "text")
        reply = f"[Mock] Echo: {joined[:100]}"
    else:
        reply = "[Mock] No input messages."

    return SamplingResponse(model="mock-model", content=reply, stop_reason="endTurn")
312# ── Resource Templates ──────────────────────
@dataclass
class MCPResourceTemplate:
    """An MCP resource template: a parameterised URI template.

    Servers can expose such templates so that clients can instantiate
    concrete resources from them.
    """

    uri_template: str
    name: str = ""
    description: str = ""
    mime_type: str = ""
    annotations: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to the wire format.

        ``uriTemplate`` and ``name`` are always present; ``description``,
        ``mimeType`` and ``annotations`` only when truthy.
        """
        payload: dict = {"uriTemplate": self.uri_template, "name": self.name}
        optional = {
            "description": self.description,
            "mimeType": self.mime_type,
            "annotations": self.annotations,
        }
        payload.update((key, value) for key, value in optional.items() if value)
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "MCPResourceTemplate":
        """Parse a wire dict; missing keys fall back to empty values."""
        get = d.get
        return cls(
            uri_template=get("uriTemplate", ""),
            name=get("name", ""),
            description=get("description", ""),
            mime_type=get("mimeType", ""),
            annotations=get("annotations", {}),
        )
352# ── MCP Logging ────────────────────────────
class MCPLogLevel(str, Enum):
    """MCP log levels, declared from ``debug`` up to ``emergency``.

    Declaration order matters: ``MCPLoggingHandler`` ranks a level by its
    position in this enum. Mixes in ``str`` so members equal their wire values.
    """

    DEBUG = "debug"
    INFO = "info"
    NOTICE = "notice"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
    ALERT = "alert"
    EMERGENCY = "emergency"
class MCPLoggingHandler:
    """Client-side receiver for MCP log messages.

    The client sets a level threshold (``logging/setLevel``); the server
    pushes records via ``notifications/message``. Records whose level ranks
    at or above the threshold are forwarded to the registered callback.

    Note: despite its name, ``max_level`` is the *lowest* level forwarded
    (see ``should_log``).
    """

    def __init__(self, max_level: MCPLogLevel = MCPLogLevel.INFO):
        self._log_callback: Optional[Callable[[MCPLogLevel, str, str], None]] = None
        self._max_level = max_level

    def set_log_callback(
        self,
        callback: Callable[[MCPLogLevel, str, str], None],
    ) -> None:
        """Register the log callback, replacing any previous one.

        Args:
            callback: called as ``callback(level, logger_name, message)``.
        """
        self._log_callback = callback

    @property
    def max_level(self) -> MCPLogLevel:
        """The current level threshold."""
        return self._max_level

    def set_level(self, level: MCPLogLevel) -> None:
        """Update the level threshold (client side of ``logging/setLevel``)."""
        self._max_level = level

    def _level_rank(self, level: MCPLogLevel) -> int:
        """Return ``level``'s position in ``MCPLogLevel`` declaration order.

        Raises ValueError if ``level`` equals no member.
        """
        return [*MCPLogLevel].index(level)

    def should_log(self, level: MCPLogLevel) -> bool:
        """Return True if ``level`` ranks at or above the current threshold."""
        incoming = self._level_rank(level)
        threshold = self._level_rank(self._max_level)
        return incoming >= threshold

    def handle_log_message(
        self,
        level: MCPLogLevel,
        logger_name: str,
        message: str,
    ) -> None:
        """Forward a server log record to the callback if it passes the threshold."""
        if not self.should_log(level):
            return
        if self._log_callback:
            self._log_callback(level, logger_name, message)
416# ── MCP Roots ──────────────────────────────
@dataclass
class MCPRoot:
    """A filesystem root that the client exposes to the server."""

    uri: str
    name: str = ""

    def to_dict(self) -> dict:
        """Serialize: ``uri`` always, ``name`` only when non-empty."""
        if not self.name:
            return {"uri": self.uri}
        return {"uri": self.uri, "name": self.name}

    @classmethod
    def from_dict(cls, d: dict) -> "MCPRoot":
        """Parse a wire dict; missing keys default to empty strings."""
        uri = d.get("uri", "")
        name = d.get("name", "")
        return cls(uri=uri, name=name)