Coverage for agentos/channels/gateway.py: 0%
128 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
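AgentOS Channels — message gateway.

Provides a single FastAPI application that exposes every channel webhook
(each webhook route also answers GET, which some platforms use for URL
verification):
    - POST /webhook/wechat    WeChat Official Account
    - POST /webhook/wecom     WeCom (WeChat Work)
    - POST /webhook/feishu    Feishu
    - POST /webhook/dingtalk  DingTalk
    - POST /webhook/qq        QQ

Management endpoints:
    - GET /channels   list all registered channels
    - GET /health     health check

Dependencies:
    pip install fastapi uvicorn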
17"""
19from __future__ import annotations
21import asyncio
22import hashlib
23import json
24import logging
25import time
26from pathlib import Path
27from typing import Any, Callable, Optional
29from agentos.channels.base import BaseChannelAdapter, ChannelConfig, ReplyResult
30from agentos.channels.message import ChannelMessage, ChannelType, MessageType, ConversationContext
31from agentos.channels.router import ChannelRouter
# Module-level logger used by every function in this gateway.
logger = logging.getLogger("agentos.channels.gateway")

# Single router instance shared by the whole module (see get_router()).
_router = ChannelRouter()

# ── User-defined message handler ──

# Callback installed through the `on_message` decorator; stays None until a
# handler is registered.
_on_message: Optional[Callable[[ChannelMessage, ConversationContext], Any]] = None
def on_message(handler: Callable[[ChannelMessage, ConversationContext], Any]):
    """Decorator that installs ``handler`` as the gateway's message callback.

    The callback is stored in the module-level ``_on_message`` slot, replacing
    any handler registered earlier, and is invoked by the message pipeline for
    every non-empty inbound message.

    Expected handler shape::

        async def handler(msg: ChannelMessage, ctx: ConversationContext) -> str | dict

    A returned string is used as the reply text; a returned dict may carry the
    reply under the ``"reply"`` key (e.g. ``{"reply": "...", "image": "..."}``).

    Args:
        handler: The callable to register.

    Returns:
        ``handler`` itself, unchanged, so the function works as a decorator.
    """
    global _on_message
    _on_message = handler
    return handler
54def _check_signature(adapter: BaseChannelAdapter, raw_data: bytes, headers: dict) -> bool:
55 """检查请求签名(安全校验)。"""
56 return adapter.verify_signature(raw_data, headers)
def _process_message_async(adapter: BaseChannelAdapter, raw_data: bytes,
                           headers: dict) -> Any:
    """Synchronous-friendly entry point for the message pipeline.

    Pipeline: signature check -> parse -> route -> call handler -> reply.

    Args:
        adapter: Channel adapter that owns the incoming request.
        raw_data: Raw request body.
        headers: Request headers.

    Returns:
        * When called from inside a running event loop: the *un-awaited*
          coroutine for the pipeline. The caller must ``await`` it; nothing
          is scheduled automatically.
        * Otherwise: the pipeline result, obtained by driving the coroutine
          to completion with ``asyncio.run``.
    """
    coro = _process_message_coro(adapter, raw_data, headers)
    try:
        # get_running_loop() is the supported way to ask "am I inside a loop?".
        # asyncio.get_event_loop() is deprecated when no loop is current and
        # raises RuntimeError in non-main threads, which made the synchronous
        # path crash when called from a worker thread.
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread: run the pipeline to completion.
        return asyncio.run(coro)
    # Already inside an event loop: hand the awaitable back to the caller.
    return coro
async def _process_message_coro(adapter: BaseChannelAdapter, raw_data: bytes,
                                headers: dict) -> Any:
    """Run the full inbound pipeline for one webhook request.

    Steps: verify signature -> parse the payload -> fetch the conversation
    context -> invoke the registered handler -> record the reply in the
    context -> build the channel-specific reply.

    Args:
        adapter: Channel adapter that owns the incoming request.
        raw_data: Raw request body.
        headers: Request headers.

    Returns:
        The value produced by ``adapter.build_reply``. Signature and parse
        failures are reported through ``build_reply(..., success=False)``
        rather than raised. Handler exceptions are logged and turned into a
        "处理失败" reply.
    """
    # Function-scope import so this block does not depend on the module's
    # import header being changed.
    import inspect

    # 1. Signature check
    if not _check_signature(adapter, raw_data, headers):
        logger.warning(f"[{adapter.channel_type.value}] 签名校验失败")
        return adapter.build_reply(
            ChannelMessage(channel=adapter.channel_type, content="signature_error"),
            "签名校验失败", success=False,
        )

    # 2. Parse the payload
    try:
        msg: ChannelMessage = adapter.parse_webhook(raw_data, headers)
    except Exception as e:
        logger.exception(f"[{adapter.channel_type.value}] 消息解析失败: {e}")
        return adapter.build_reply(
            ChannelMessage(channel=adapter.channel_type, content="parse_error"),
            "消息解析失败", success=False,
        )

    if not msg.content and not msg.media_url:
        # Empty message / heartbeat: acknowledge without invoking the handler.
        return adapter.build_reply(msg, "", success=True)

    # 3. Conversation context
    ctx = _router.get_context(msg)

    # 4. Invoke the user-registered handler
    reply_text = ""
    try:
        if _on_message:
            result = _on_message(msg, ctx)
            # isawaitable() also covers Futures, Tasks and objects defining
            # __await__; iscoroutine() alone silently dropped those results.
            if inspect.isawaitable(result):
                result = await result
            if isinstance(result, str):
                reply_text = result
            elif isinstance(result, dict):
                # Treat a missing or None "reply" as an empty reply so that
                # build_reply never receives None as the text.
                reply_text = result.get("reply") or ""
        else:
            reply_text = f"[AgentOS Gateway] 收到来自 {msg.channel.value} 的消息,但未注册消息处理器。"
    except Exception as e:
        logger.exception(f"[{adapter.channel_type.value}] handler 异常: {e}")
        reply_text = f"处理失败: {e}"

    # 5. Record the reply in the conversation context
    if reply_text:
        _router.update_context(msg.session_id or msg.sender_id, reply_text)

    # 6. Build the channel-specific reply
    return adapter.build_reply(msg, reply_text, success=True)
# ── FastAPI application factory ──
def create_app(title: str = "AgentOS Channel Gateway",
               version: str = "1.0.0",
               adapter_webhook_paths: Optional[dict[ChannelType, str]] = None,
               ) -> Any:
    """Create the FastAPI application that serves all channel webhooks.

    Webhook routes are added in a startup hook rather than immediately,
    because adapters may be registered on the global router after
    ``create_app`` has returned. Each webhook route accepts GET (platform URL
    verification) and POST (message delivery). ``GET /channels`` and
    ``GET /health`` are always available.

    Args:
        title: Application title.
        version: Application version string.
        adapter_webhook_paths: Optional channel -> webhook path overrides that
            are merged over the built-in defaults, e.g.
            ``{ChannelType.WECHAT_MP: "/webhook/wechat"}``.

    Returns:
        The configured FastAPI app instance.

    Raises:
        ImportError: If ``fastapi`` is not installed.

    Usage::

        from agentos.channels.gateway import create_app, on_message
        from agentos.channels.message import ChannelType

        app = create_app(adapter_webhook_paths={
            ChannelType.WECHAT_MP: "/webhook/wechat",
            ChannelType.WECOM: "/webhook/wecom",
        })

        @on_message
        async def handle_message(msg, ctx):
            return f"Echo: {msg.content}"

        if __name__ == "__main__":
            import uvicorn
            uvicorn.run(app, host="0.0.0.0", port=8000)
    """
    # Function-scope import so this block does not depend on the module's
    # import header being changed.
    import hmac

    try:
        from fastapi import FastAPI, Request, HTTPException
        from fastapi.responses import JSONResponse, PlainTextResponse, Response
    except ImportError as err:
        # Chain the original error so the missing module stays visible.
        raise ImportError(
            "agentos.channels.gateway requires fastapi and uvicorn. "
            "Install with: pip install fastapi uvicorn"
        ) from err

    app = FastAPI(title=title, version=version)

    # Built-in webhook path for each channel; caller overrides win.
    default_paths: dict[ChannelType, str] = {
        ChannelType.WECHAT_MP: "/webhook/wechat",
        ChannelType.WECOM: "/webhook/wecom",
        ChannelType.FEISHU: "/webhook/feishu",
        ChannelType.DINGTALK: "/webhook/dingtalk",
        ChannelType.QQ: "/webhook/qq",
    }

    if adapter_webhook_paths:
        default_paths.update(adapter_webhook_paths)

    # ── Channel webhook routes ──

    def _make_webhook_handler(adapter: BaseChannelAdapter):
        """Build the GET/POST webhook endpoint bound to ``adapter``."""
        async def handler(request: Request):
            # Read the raw body; a failed read is treated as an empty body.
            try:
                raw_data = await request.body()
            except Exception:
                raw_data = b""

            # GET requests: platform URL / token verification.
            if request.method == "GET":
                query = dict(request.query_params)
                if adapter.channel_type == ChannelType.FEISHU:
                    # Feishu: answer with the supplied challenge value.
                    return adapter.build_reply(
                        ChannelMessage(channel=ChannelType.FEISHU, content=""),
                        query.get("challenge", ""), success=True,
                    )
                if adapter.channel_type == ChannelType.WECHAT_MP:
                    # WeChat Official Account URL verification:
                    # sha1 over the sorted (token, timestamp, nonce) triple.
                    token = adapter.config.token
                    sig = query.get("signature", "")
                    ts = query.get("timestamp", "")
                    nonce = query.get("nonce", "")
                    echostr = query.get("echostr", "")
                    if token:
                        joined = "".join(sorted([token, ts, nonce]))
                        expected = hashlib.sha1(joined.encode()).hexdigest()
                        # Constant-time comparison of an untrusted signature.
                        # Compare bytes: compare_digest rejects non-ASCII str.
                        if hmac.compare_digest(sig.encode(), expected.encode()):
                            return PlainTextResponse(echostr)
                        raise HTTPException(status_code=403, detail="Signature verification failed")
                    # No token configured: echo back without verification.
                    return PlainTextResponse(echostr)
                return JSONResponse({"status": "ok"})

            # POST requests: run the message pipeline.
            headers = dict(request.headers)
            try:
                result = await _process_message_coro(adapter, raw_data, headers)
                if result is None:
                    return JSONResponse({"status": "ok"})
                if isinstance(result, (str, bytes)):
                    # Raw payloads are XML for WeChat MP / WeCom, JSON otherwise.
                    return Response(
                        content=result if isinstance(result, bytes) else result.encode(),
                        media_type="application/xml" if adapter.channel_type in
                        (ChannelType.WECHAT_MP, ChannelType.WECOM) else "application/json",
                    )
                if isinstance(result, dict):
                    return JSONResponse(result)
                return result
            except Exception as e:
                logger.exception(f"Gateway handler error: {e}")
                raise HTTPException(status_code=500, detail=str(e)) from e

        # This module uses `from __future__ import annotations`, so the
        # annotation above is the string "Request". FastAPI resolves string
        # annotations against the function's module globals, where the locally
        # imported Request is not visible. Bind the real class so `request` is
        # injected as the Request object.
        handler.__annotations__["request"] = Request
        return handler

    # Routes are created at startup because adapters may be registered on the
    # router after create_app() has returned.
    @app.on_event("startup")
    async def _register_routes():
        """Register one webhook route per known channel adapter."""
        channel_adapters: dict[ChannelType, BaseChannelAdapter] = {}
        # Adapters already bound to a webhook path.
        for adapter in _router._adapters.values():
            channel_adapters[adapter.channel_type] = adapter
        # Adapters known only by channel: give them the default path if any.
        for channel, adapter in _router._by_channel.items():
            if channel not in channel_adapters:
                channel_adapters[channel] = adapter
                if channel in default_paths:
                    _router._adapters[default_paths[channel]] = adapter

        for channel, adapter in channel_adapters.items():
            # The adapter's own configured path wins over the default mapping.
            webhook_path = adapter.config.webhook_path or default_paths.get(channel, "")
            if webhook_path:
                _router._adapters[webhook_path] = adapter
                app.add_api_route(webhook_path, _make_webhook_handler(adapter),
                                  methods=["GET", "POST"], name=f"webhook_{channel.value}")

    # ── Management endpoints ──

    @app.get("/channels")
    async def list_channels():
        """List the router's active channels with a timestamp."""
        return {
            "channels": _router.active_channels,
            "active_count": _router.active_count,
            "timestamp": time.time(),
        }

    @app.get("/health")
    async def health_check():
        """Report a static healthy status and the active channel count."""
        return {"status": "healthy", "channels": _router.active_count, "timestamp": time.time()}

    return app
# ── Convenience functions ──
def get_router() -> ChannelRouter:
    """Return the module-level ``ChannelRouter`` used by the gateway."""
    router = _router
    return router
def register_adapter(adapter: BaseChannelAdapter, webhook_path: str = "") -> ChannelRouter:
    """Register ``adapter`` on the module-level router.

    Args:
        adapter: Channel adapter to register.
        webhook_path: Path forwarded to ``ChannelRouter.register``; defaults
            to an empty string.

    Returns:
        The module-level router, so calls can be chained.
    """
    router = _router
    router.register(adapter, webhook_path)
    return router