Coverage for agentos/channels/router.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS Channels — 消息路由引擎。 

3 

4负责: 

5 1. 根据 webhook URL 路径匹配目标渠道适配器 

6 2. 维持渠道适配器实例注册表 

7 3. 提供统一的 send_message 接口(自动路由到正确渠道) 

8""" 

9 

10from __future__ import annotations 

11 

12import time 

13from typing import Optional 

14 

15from agentos.channels.base import BaseChannelAdapter, ChannelConfig, ReplyResult 

16from agentos.channels.message import ChannelMessage, ChannelType, ConversationContext 

17 

18 

class ChannelRouter:
    """Multi-channel router: register, look up, and dispatch to channel adapters.

    Every adapter is indexed by its ``channel_type`` and, optionally, by a
    ``webhook_path``. ``find`` resolves an adapter by webhook path first, then
    by channel type, and finally falls back to the default adapter (the first
    adapter that was registered with a webhook path).

    The router also keeps an in-memory conversation history per session
    (``get_context`` / ``update_context``), capped at ``max_history`` entries.
    """

    def __init__(self, max_history: int = 50):
        """Create an empty router.

        Args:
            max_history: Maximum number of history entries kept per
                conversation context; the oldest entries are dropped first.

        Raises:
            ValueError: If ``max_history`` is less than 1.
        """
        if max_history < 1:
            # A cap of 0 would make the ``history[-0:]`` trim keep everything.
            raise ValueError(f"max_history must be >= 1, got {max_history}")
        self._max_history = max_history
        self._adapters: dict[str, BaseChannelAdapter] = {}  # webhook_path -> adapter
        self._by_channel: dict[ChannelType, BaseChannelAdapter] = {}  # channel type -> adapter
        self._default: Optional[BaseChannelAdapter] = None  # fallback used by find()
        # session_id -> context. NOTE(review): entries are never evicted, so this
        # grows with the number of distinct sessions — confirm that is acceptable.
        self._contexts: dict[str, ConversationContext] = {}

    # ── Registration ──

    def register(self, adapter: BaseChannelAdapter, webhook_path: str = "") -> None:
        """Register a channel adapter.

        Args:
            adapter: The adapter to register; it is indexed by
                ``adapter.channel_type``.
            webhook_path: Optional webhook path routed to this adapter. When
                non-empty it is also written to ``adapter.config.webhook_path``.

        If a different adapter was already registered for the same channel
        type, it is unregistered first so none of its webhook paths (or its
        default status) survive. The first adapter registered with a webhook
        path becomes the default; adapters registered without a path never
        become the default here.
        """
        channel = adapter.channel_type
        previous = self._by_channel.get(channel)
        if previous is not None and previous is not adapter:
            # Drop the replaced adapter's paths/default so find() cannot keep
            # returning an adapter that is no longer registered.
            self.unregister(channel)
        self._by_channel[channel] = adapter
        if webhook_path:
            self._adapters[webhook_path] = adapter
            adapter.config.webhook_path = webhook_path
            if len(self._adapters) == 1:
                self._default = adapter

    def unregister(self, channel: ChannelType) -> None:
        """Unregister the adapter for ``channel``; a no-op if none is registered.

        All webhook paths pointing at that adapter are removed. If it was the
        default, the default moves to another path-registered adapter (or
        ``None`` when none remain).
        """
        adapter = self._by_channel.pop(channel, None)
        if adapter is None:
            return
        # Compare by identity: only paths bound to this exact instance go away,
        # regardless of how the adapter class defines __eq__.
        stale_paths = [path for path, bound in self._adapters.items() if bound is adapter]
        for path in stale_paths:
            del self._adapters[path]
        if self._default is adapter:
            self._default = next(iter(self._adapters.values()), None)

    # ── Lookup ──

    def find(self, webhook_path: str = "", channel: Optional[ChannelType] = None) -> Optional[BaseChannelAdapter]:
        """Find an adapter by webhook path, then by channel type, else the default.

        Returns:
            The matching adapter, the default adapter when nothing matches,
            or ``None`` when there is no default either.
        """
        if webhook_path and webhook_path in self._adapters:
            return self._adapters[webhook_path]
        # Explicit None check: a falsy enum member must still be looked up.
        if channel is not None and channel in self._by_channel:
            return self._by_channel[channel]
        return self._default

    def get(self, channel: ChannelType) -> Optional[BaseChannelAdapter]:
        """Return the adapter registered for ``channel``, or ``None`` (no default fallback)."""
        return self._by_channel.get(channel)

    # ── Session management ──

    def _append_history(self, ctx: ConversationContext, entry: dict) -> None:
        """Append ``entry`` to ``ctx.history`` and keep only the newest ``max_history`` entries."""
        ctx.history.append(entry)
        if len(ctx.history) > self._max_history:
            ctx.history = ctx.history[-self._max_history:]

    def get_context(self, msg: ChannelMessage) -> ConversationContext:
        """Get or create the conversation context for ``msg`` and record it as a user turn.

        The context key is ``msg.session_id``, falling back to ``msg.sender_id``
        when the session id is empty. NOTE(review): the key does not include the
        channel, so identical ids on different channels share one context —
        confirm ids are globally unique.
        """
        sid = msg.session_id or msg.sender_id
        ctx = self._contexts.get(sid)
        if ctx is None:
            ctx = ConversationContext(
                channel=msg.channel,
                user_id=msg.sender_id,
                session_id=sid,
                channel_config={"conversation_id": msg.conversation_id},
                metadata={"first_message_at": time.time()},
            )
            self._contexts[sid] = ctx
        self._append_history(ctx, {"role": "user", "content": msg.content, "timestamp": msg.timestamp})
        return ctx

    def update_context(self, session_id: str, reply_text: str) -> None:
        """Record an agent reply in the session's history; unknown sessions are ignored.

        The history is trimmed to ``max_history`` entries, the same cap that
        ``get_context`` applies to user messages.
        """
        ctx = self._contexts.get(session_id)
        if ctx is None:
            return
        self._append_history(ctx, {
            "role": "assistant",
            "content": reply_text,
            "timestamp": time.time(),
        })

    # ── Unified sending ──

    async def send(self, channel: ChannelType, user_id: str, content: str,
                   msg_type: str = "text") -> ReplyResult:
        """Send a message through the adapter registered for ``channel``.

        Returns:
            The adapter's ``send_message`` result, or a failed ``ReplyResult``
            when no adapter is registered for ``channel``. Note that
            ``adapter.config.enabled`` is not consulted here.
        """
        adapter = self._by_channel.get(channel)
        if not adapter:
            return ReplyResult(success=False, error=f"No adapter for {channel.value}")
        return await adapter.send_message(user_id, content, msg_type)

    async def broadcast(self, content: str, msg_type: str = "text",
                        exclude: Optional[ChannelType] = None,
                        user_ids: Optional[dict[ChannelType, str]] = None) -> list[ReplyResult]:
        """Send ``content`` to every registered channel except ``exclude``.

        Args:
            content: Message body.
            msg_type: Message type forwarded to ``send_message``.
            exclude: Channel type to skip.
            user_ids: Optional mapping of channel type to recipient user id.
                Channels with a non-empty recipient are sent to via their
                adapter. Channels without one yield a failed ``ReplyResult``,
                because a message cannot be addressed without a user id.

        Returns:
            One result per non-excluded channel, in registration order.
        """
        recipients = user_ids or {}
        results: list[ReplyResult] = []
        # Snapshot: awaiting below may let other code (un)register adapters.
        for channel, adapter in list(self._by_channel.items()):
            if channel == exclude:
                continue
            user_id = recipients.get(channel)
            if not user_id:
                results.append(ReplyResult(success=False, error=f"broadcast to {channel.value} requires explicit user_id"))
                continue
            results.append(await adapter.send_message(user_id, content, msg_type))
        return results

    # ── Status ──

    @property
    def active_channels(self) -> list[dict]:
        """Summaries (channel value, webhook path, enabled flag) of all registered adapters."""
        return [
            {
                "channel": a.channel_type.value,
                "webhook_path": a.config.webhook_path,
                "enabled": a.config.enabled,
            }
            for a in self._by_channel.values()
        ]

    @property
    def active_count(self) -> int:
        """Number of registered adapters whose ``config.enabled`` is truthy."""
        return sum(1 for a in self._by_channel.values() if a.config.enabled)