Coverage for agentos/channels/router.py: 0%
68 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
AgentOS Channels — 消息路由引擎。

负责:
  1. 根据 webhook URL 路径匹配目标渠道适配器
  2. 维持渠道适配器实例注册表
  3. 提供统一的 send_message 接口(自动路由到正确渠道)
"""
10from __future__ import annotations
12import time
13from typing import Optional
15from agentos.channels.base import BaseChannelAdapter, ChannelConfig, ReplyResult
16from agentos.channels.message import ChannelMessage, ChannelType, ConversationContext
class ChannelRouter:
    """Multi-channel router: registers, looks up and dispatches to adapters.

    Adapters are registered under a webhook path and/or their channel type.
    Lookup order is: webhook path match, then channel type match, then the
    default adapter.
    """

    # Maximum number of history entries kept per conversation context.
    MAX_HISTORY = 50

    def __init__(self):
        self._adapters: dict[str, BaseChannelAdapter] = {}  # webhook_path -> adapter
        self._by_channel: dict[ChannelType, BaseChannelAdapter] = {}
        self._default: Optional[BaseChannelAdapter] = None
        self._contexts: dict[str, ConversationContext] = {}  # session_id -> context

    # -- Registration --

    def register(self, adapter: BaseChannelAdapter, webhook_path: str = "") -> None:
        """Register a channel adapter.

        The adapter is always indexed by its ``channel_type``. When
        ``webhook_path`` is non-empty it is additionally indexed by that path
        and the path is written to ``adapter.config.webhook_path``.

        The first path-registered adapter becomes the default. Registering an
        adapter without a path never changes the default, and re-registering
        the path currently held by the default moves the default to the new
        adapter.
        """
        self._by_channel[adapter.channel_type] = adapter
        if not webhook_path:
            # Channel-only registration: leave path table and default alone.
            return

        replaced = self._adapters.get(webhook_path)
        self._adapters[webhook_path] = adapter
        adapter.config.webhook_path = webhook_path
        # Keep the default pointing at a live path-registered adapter.
        if self._default is None or self._default is replaced:
            self._default = adapter

    def unregister(self, channel: ChannelType) -> None:
        """Unregister the adapter for ``channel``.

        Removes the adapter from the channel index and from every webhook
        path it was registered under. If it was the default, the default
        becomes the first remaining path-registered adapter, or ``None``.
        Unknown channels are ignored.
        """
        adapter = self._by_channel.pop(channel, None)
        if adapter is None:
            return

        # Identity comparison: only this exact instance is removed, even if
        # adapters implement value-based equality.
        stale_paths = [path for path, a in self._adapters.items() if a is adapter]
        for path in stale_paths:
            del self._adapters[path]

        if self._default is adapter:
            self._default = next(iter(self._adapters.values()), None)

    # -- Lookup --

    def find(
        self,
        webhook_path: str = "",
        channel: Optional[ChannelType] = None,
    ) -> Optional[BaseChannelAdapter]:
        """Find an adapter by webhook path, then channel type, then default.

        Returns the default adapter (possibly ``None``) when neither the
        path nor the channel matches a registered adapter.
        """
        if webhook_path:
            by_path = self._adapters.get(webhook_path)
            if by_path is not None:
                return by_path
        # ``is not None`` so a channel value that happens to be falsy still
        # takes part in the lookup.
        if channel is not None:
            by_channel = self._by_channel.get(channel)
            if by_channel is not None:
                return by_channel
        return self._default

    def get(self, channel: ChannelType) -> Optional[BaseChannelAdapter]:
        """Return the adapter registered for ``channel``, or ``None``."""
        return self._by_channel.get(channel)

    # -- Conversation management --

    def _append_history(self, ctx: ConversationContext, role: str,
                        content: str, timestamp: float) -> None:
        """Append one entry to ``ctx.history`` and cap it at ``MAX_HISTORY``."""
        ctx.history.append({"role": role, "content": content, "timestamp": timestamp})
        overflow = len(ctx.history) - self.MAX_HISTORY
        if overflow > 0:
            # Keep only the most recent entries.
            ctx.history = ctx.history[overflow:]

    def get_context(self, msg: ChannelMessage) -> ConversationContext:
        """Get or create the conversation context for ``msg``.

        The context key is ``msg.session_id``, falling back to
        ``msg.sender_id`` when the session id is empty. The message is
        recorded in the context history as a ``"user"`` entry.
        """
        sid = msg.session_id or msg.sender_id
        ctx = self._contexts.get(sid)
        if ctx is None:
            ctx = ConversationContext(
                channel=msg.channel,
                user_id=msg.sender_id,
                session_id=sid,
                channel_config={"conversation_id": msg.conversation_id},
                metadata={"first_message_at": time.time()},
            )
            self._contexts[sid] = ctx
        self._append_history(ctx, "user", msg.content, msg.timestamp)
        return ctx

    def update_context(self, session_id: str, reply_text: str) -> None:
        """Record the agent's reply in the conversation context.

        Does nothing when no context exists for ``session_id``. The history
        is capped at ``MAX_HISTORY`` entries, same as in ``get_context``.
        """
        ctx = self._contexts.get(session_id)
        if ctx is None:
            return
        self._append_history(ctx, "assistant", reply_text, time.time())

    # -- Unified sending --

    async def send(self, channel: ChannelType, user_id: str, content: str,
                   msg_type: str = "text") -> ReplyResult:
        """Send a message through the adapter registered for ``channel``.

        Returns a failed ``ReplyResult`` when no adapter is registered for
        the channel; otherwise returns the adapter's ``send_message`` result.
        """
        adapter = self._by_channel.get(channel)
        if adapter is None:
            return ReplyResult(success=False, error=f"No adapter for {channel.value}")
        return await adapter.send_message(user_id, content, msg_type)

    async def broadcast(self, content: str, msg_type: str = "text",
                        exclude: Optional[ChannelType] = None) -> list[ReplyResult]:
        """Produce one result per registered channel, skipping ``exclude``.

        NOTE: nothing is actually sent. Every result is a failure stating
        that broadcasting requires an explicit user id; ``content`` and
        ``msg_type`` are currently unused.
        """
        return [
            ReplyResult(
                success=False,
                error=f"broadcast to {channel.value} requires explicit user_id",
            )
            for channel in self._by_channel
            if channel != exclude
        ]

    # -- Status --

    @property
    def active_channels(self) -> list[dict]:
        """Describe every registered adapter (channel, webhook path, enabled)."""
        return [
            {
                "channel": a.channel_type.value,
                "webhook_path": a.config.webhook_path,
                "enabled": a.config.enabled,
            }
            for a in self._by_channel.values()
        ]

    @property
    def active_count(self) -> int:
        """Number of registered adapters whose config is enabled."""
        return sum(1 for a in self._by_channel.values() if a.config.enabled)