Coverage for agentos/channels/gateway.py: 0%

128 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS Channels — 消息网关。 

3 

4提供统一的 FastAPI 应用,一站式接入所有渠道 webhook: 

5 - POST /webhook/wechat 微信公众号 

6 - POST /webhook/wecom 企业微信 

7 - POST /webhook/feishu 飞书 

8 - POST /webhook/dingtalk 钉钉 

9 - POST /webhook/qq QQ 

10 

11同时提供管理接口: 

12 - GET /channels 列出所有已注册渠道 

13 - GET /health 健康检查 

14 

15依赖: 

16 pip install fastapi uvicorn 

17""" 

18 

from __future__ import annotations

import asyncio
import hashlib
import hmac
import inspect
import json
import logging
import time
from pathlib import Path
from typing import Any, Callable, Optional

from agentos.channels.base import BaseChannelAdapter, ChannelConfig, ReplyResult
from agentos.channels.message import ChannelMessage, ChannelType, MessageType, ConversationContext
from agentos.channels.router import ChannelRouter

32 

# Logger shared by everything in the gateway module.
logger: logging.Logger = logging.getLogger("agentos.channels.gateway")

# Global router instance: the FastAPI routes built by create_app() and the
# get_router() / register_adapter() helpers all operate on this object.
_router: ChannelRouter = ChannelRouter()

# ── User-defined message handler ──

# Set by the @on_message decorator. While it is None, incoming messages get a
# placeholder "no handler registered" reply instead of a handler result.
_on_message: Optional[Callable[[ChannelMessage, ConversationContext], Any]] = None

40 

41 

def on_message(handler: Callable[[ChannelMessage, ConversationContext], Any]):
    """Decorator that installs *handler* as the gateway's message callback.

    Whenever a user message arrives on any channel, the gateway calls the
    installed handler. Only one handler is kept: decorating another function
    replaces the previous one. The function is returned unchanged, so the
    decorated name still refers to it.

    Expected handler shape::

        async def handler(msg: ChannelMessage, ctx: ConversationContext) -> str | dict

    A ``str`` result is sent back as the reply text. A ``dict`` result
    carries the reply under its ``"reply"`` key and may hold extra entries
    such as ``{"reply": "...", "image": "..."}``.
    """
    global _on_message
    _on_message = handler
    return handler

52 

53 

def _check_signature(adapter: BaseChannelAdapter, raw_data: bytes, headers: dict) -> bool:
    """Security gate: ask *adapter* whether the request signature is valid.

    Args:
        adapter: channel adapter that implements its platform's signing scheme.
        raw_data: raw request body.
        headers: request headers.

    Returns:
        Exactly what ``adapter.verify_signature`` returns (truthy means valid).
    """
    verify = adapter.verify_signature
    return verify(raw_data, headers)

57 

58 

def _process_message_async(adapter: BaseChannelAdapter, raw_data: bytes,
                           headers: dict) -> Any:
    """Synchronous entry point to the message pipeline.

    Wraps ``_process_message_coro`` (signature check → parse → route →
    user handler → reply) for callers that are not coroutines themselves.

    * No event loop running in the current thread: the pipeline is driven
      to completion with ``asyncio.run`` and its result is returned.
    * An event loop is already running (``asyncio.run`` cannot be nested):
      the *un-awaited coroutine object* is returned and the caller must
      ``await`` it.

    Args:
        adapter: adapter for the channel the request arrived on.
        raw_data: raw request body.
        headers: request headers.

    Returns:
        The pipeline result, or the pending coroutine (see above).
    """
    # get_running_loop() only reports a loop that is actually running and
    # never creates one implicitly. Unlike get_event_loop(), it behaves the
    # same in worker threads (no "no current event loop" RuntimeError) and is
    # not deprecated when called outside a loop.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_running = False
    else:
        loop_running = True

    coro = _process_message_coro(adapter, raw_data, headers)
    if loop_running:
        return coro
    return asyncio.run(coro)

71 

72 

def _reply_text_from_result(result: Any) -> str:
    """Turn a message handler's return value into reply text.

    * ``str``: used as-is.
    * ``dict``: its ``"reply"`` entry. A missing or ``None`` entry gives
      ``""``; any other value is converted with ``str()`` so adapters always
      receive text.
    * Anything else (including ``None``): ``""``, meaning no reply.
    """
    if isinstance(result, str):
        return result
    if isinstance(result, dict):
        reply = result.get("reply")
        return "" if reply is None else str(reply)
    return ""


async def _process_message_coro(adapter: BaseChannelAdapter, raw_data: bytes,
                                headers: dict) -> Any:
    """Process one webhook request end to end.

    Steps: signature check → parse → conversation context → user handler →
    record reply in context → build the channel-specific reply.

    Args:
        adapter: adapter for the channel the request arrived on.
        raw_data: raw request body.
        headers: request headers.

    Returns:
        The value of ``adapter.build_reply(...)``. Signature and parse
        failures are not raised; they are returned as a reply built with
        ``success=False``. Exceptions from the user handler are logged and
        turned into an error reply text.
    """
    # 1. Signature check: reject unauthenticated requests before parsing.
    if not _check_signature(adapter, raw_data, headers):
        logger.warning("[%s] 签名校验失败", adapter.channel_type.value)
        return adapter.build_reply(
            ChannelMessage(channel=adapter.channel_type, content="signature_error"),
            "签名校验失败", success=False,
        )

    # 2. Parse the platform payload into a ChannelMessage.
    try:
        msg: ChannelMessage = adapter.parse_webhook(raw_data, headers)
    except Exception as e:
        logger.exception("[%s] 消息解析失败: %s", adapter.channel_type.value, e)
        return adapter.build_reply(
            ChannelMessage(channel=adapter.channel_type, content="parse_error"),
            "消息解析失败", success=False,
        )

    if not msg.content and not msg.media_url:
        # Empty message / heartbeat: acknowledge with an empty successful reply.
        return adapter.build_reply(msg, "", success=True)

    # 3. Conversation context for this message, supplied by the router.
    ctx = _router.get_context(msg)

    # 4. Invoke the user-registered handler, which may be sync or async.
    try:
        if _on_message:
            # NOTE(review): a synchronous handler runs directly on the event
            # loop and blocks it while it executes.
            result = _on_message(msg, ctx)
            # isawaitable covers coroutines as well as Futures/Tasks and
            # custom awaitables, which iscoroutine() would miss.
            if inspect.isawaitable(result):
                result = await result
            reply_text = _reply_text_from_result(result)
        else:
            reply_text = f"[AgentOS Gateway] 收到来自 {msg.channel.value} 的消息,但未注册消息处理器。"
    except Exception as e:
        logger.exception("[%s] handler 异常: %s", adapter.channel_type.value, e)
        # NOTE(review): this sends the raw exception text back to the end user.
        reply_text = f"处理失败: {e}"

    # 5. Record the reply in the conversation context.
    if reply_text:
        _router.update_context(msg.session_id or msg.sender_id, reply_text)

    # 6. Build the channel-specific reply.
    return adapter.build_reply(msg, reply_text, success=True)

124 

125 

126# ── FASTAPI 应用工厂 ── 

127 

def _wechat_signature_matches(token: str, timestamp: str, nonce: str,
                              signature: str) -> bool:
    """Check a WeChat Official Account URL-verification signature.

    The expected value is the hex SHA-1 of ``token``, ``timestamp`` and
    ``nonce``, sorted and concatenated. The comparison is constant-time, so
    response timing does not reveal how much of a guessed signature matched.

    Returns:
        True when *signature* equals the expected digest.
    """
    expected = hashlib.sha1("".join(sorted((token, timestamp, nonce))).encode()).hexdigest()
    # Compare as bytes: compare_digest raises TypeError for non-ASCII str
    # input, and the query string is attacker-controlled.
    return hmac.compare_digest(signature.encode(), expected.encode())


def create_app(title: str = "AgentOS Channel Gateway",
               version: str = "1.0.0",
               adapter_webhook_paths: Optional[dict[ChannelType, str]] = None,
               ) -> Any:
    """Create the FastAPI application for the channel gateway.

    Args:
        title: application title.
        version: application version.
        adapter_webhook_paths: channel → webhook path entries merged over the
            built-in defaults, e.g. ``{ChannelType.WECHAT_MP: "/webhook/wechat"}``.
            An adapter's own ``config.webhook_path``, when set, still takes
            precedence.

    Returns:
        The FastAPI app instance.

    Raises:
        ImportError: if fastapi is not installed.

    Webhook routes are added when the app starts (startup event), so adapters
    registered after ``create_app`` returns are still picked up.

    Usage::

        from agentos.channels.gateway import create_app, on_message
        from agentos.channels.adapters import WeChatAdapter, WeComAdapter, FeishuAdapter
        from agentos.channels.message import ChannelType

        app = create_app(adapter_webhook_paths={
            ChannelType.WECHAT_MP: "/webhook/wechat",
            ChannelType.WECOM: "/webhook/wecom",
        })

        @on_message
        async def handle_message(msg, ctx):
            return f"Echo: {msg.content}"

        if __name__ == "__main__":
            import uvicorn
            uvicorn.run(app, host="0.0.0.0", port=8000)
    """
    try:
        from fastapi import FastAPI, Request, HTTPException
        from fastapi.responses import JSONResponse, PlainTextResponse, Response
    except ImportError as exc:
        raise ImportError(
            "agentos.channels.gateway requires fastapi and uvicorn. "
            "Install with: pip install fastapi uvicorn"
        ) from exc

    app = FastAPI(title=title, version=version)

    # Default webhook path per channel; caller-supplied entries override these.
    default_paths: dict[ChannelType, str] = {
        ChannelType.WECHAT_MP: "/webhook/wechat",
        ChannelType.WECOM: "/webhook/wecom",
        ChannelType.FEISHU: "/webhook/feishu",
        ChannelType.DINGTALK: "/webhook/dingtalk",
        ChannelType.QQ: "/webhook/qq",
    }
    if adapter_webhook_paths:
        default_paths.update(adapter_webhook_paths)

    # Paths that already have a webhook route on this app. The startup event
    # can fire more than once for the same app object (e.g. repeated
    # test-client lifespans); re-adding would stack duplicate routes.
    routed_paths: set[str] = set()

    # ── Channel webhook routes ──

    def _make_webhook_handler(adapter: BaseChannelAdapter):
        """Build the GET/POST webhook endpoint bound to one channel adapter."""

        async def handler(request: Request):
            # GET: platform URL-verification handshakes.
            if request.method == "GET":
                query = dict(request.query_params)
                if adapter.channel_type == ChannelType.FEISHU:
                    return adapter.build_reply(
                        ChannelMessage(channel=ChannelType.FEISHU, content=""),
                        query.get("challenge", ""), success=True,
                    )
                if adapter.channel_type == ChannelType.WECHAT_MP:
                    # WeChat Official Account URL verification.
                    echostr = query.get("echostr", "")
                    token = adapter.config.token
                    if not token:
                        # NOTE(review): with no token configured, the echo is
                        # returned without any signature check.
                        return PlainTextResponse(echostr)
                    if _wechat_signature_matches(
                        token,
                        query.get("timestamp", ""),
                        query.get("nonce", ""),
                        query.get("signature", ""),
                    ):
                        return PlainTextResponse(echostr)
                    raise HTTPException(status_code=403, detail="Signature verification failed")
                return JSONResponse({"status": "ok"})

            # POST: message delivery. Reading the body is best-effort; an
            # unreadable body is treated as empty and left to the pipeline's
            # signature/parse checks.
            try:
                raw_data = await request.body()
            except Exception:
                raw_data = b""
            headers = dict(request.headers)
            try:
                result = await _process_message_coro(adapter, raw_data, headers)
                if result is None:
                    return JSONResponse({"status": "ok"})
                if isinstance(result, (str, bytes)):
                    # Text replies for WeChat MP / WeCom are sent as XML; other
                    # channels as JSON.
                    is_xml = adapter.channel_type in (ChannelType.WECHAT_MP, ChannelType.WECOM)
                    return Response(
                        content=result if isinstance(result, bytes) else result.encode(),
                        media_type="application/xml" if is_xml else "application/json",
                    )
                if isinstance(result, dict):
                    return JSONResponse(result)
                return result
            except Exception as e:
                logger.exception("Gateway handler error: %s", e)
                raise HTTPException(status_code=500, detail=str(e)) from e

        # This module uses ``from __future__ import annotations``, so the
        # ``request: Request`` annotation above is the string "Request".
        # FastAPI resolves such strings against module globals, where the
        # locally imported Request is absent. Attach the real class so FastAPI
        # injects the request object instead of treating it as a query param.
        handler.__annotations__ = {"request": Request}
        return handler

    # Routes are added at startup rather than here, because adapters may be
    # registered after create_app() returns.
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of lifespan handlers.
    @app.on_event("startup")
    async def _register_routes():
        """Add a webhook route for every adapter known to the global router."""
        # One adapter per channel: path-registered adapters first...
        channel_adapters: dict[ChannelType, BaseChannelAdapter] = {}
        for adapter in _router._adapters.values():
            channel_adapters[adapter.channel_type] = adapter
        # ...then channel-registered adapters not seen yet, which are also
        # mapped onto their default path in the router.
        for channel, adapter in _router._by_channel.items():
            if channel not in channel_adapters:
                channel_adapters[channel] = adapter
                if channel in default_paths:
                    _router._adapters[default_paths[channel]] = adapter

        for channel, adapter in channel_adapters.items():
            # The adapter's own configured path takes precedence over defaults.
            webhook_path = adapter.config.webhook_path or default_paths.get(channel, "")
            if not webhook_path:
                continue
            _router._adapters[webhook_path] = adapter
            if webhook_path in routed_paths:
                continue
            routed_paths.add(webhook_path)
            app.add_api_route(webhook_path, _make_webhook_handler(adapter),
                              methods=["GET", "POST"], name=f"webhook_{channel.value}")

    # ── Management endpoints ──

    @app.get("/channels")
    async def list_channels():
        """List the router's active channels."""
        return {
            "channels": _router.active_channels,
            "active_count": _router.active_count,
            "timestamp": time.time(),
        }

    @app.get("/health")
    async def health_check():
        """Health check reporting the active channel count."""
        return {"status": "healthy", "channels": _router.active_count, "timestamp": time.time()}

    return app

274 

275 

276# ── 便利函数 ── 

277 

def get_router() -> ChannelRouter:
    """Return the module-wide ``ChannelRouter`` instance.

    This is the same object that the app built by ``create_app`` reads when
    it registers webhook routes at startup.
    """
    router = _router
    return router

281 

282 

283def register_adapter(adapter: BaseChannelAdapter, webhook_path: str = "") -> ChannelRouter: 

284 """手动注册渠道适配器到全局路由器。""" 

285 _router.register(adapter, webhook_path) 

286 return _router