Coverage for agentos/conversation/conversation.py: 39%

205 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""AgentOS v1.3.10 - Conversation Manager 模块。 

2 

3多轮对话上下文管理:滑动窗口、自动摘要、对话分支、token 感知裁剪。 

4适用于长会话场景,防止上下文溢出,同时保持关键信息不丢失。 

5""" 

6 

7from __future__ import annotations 

8 

9import hashlib 

10import time 

11from dataclasses import dataclass, field 

12from enum import Enum, auto 

13from typing import Any, Callable 

14 

15 

class MessageRole(str, Enum):
    """Role of the party that produced a conversation message.

    The enum mixes in ``str``, so every member is itself a string and
    compares equal to its lowercase value (e.g. ``MessageRole.USER == "user"``).
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"

23 

24 

class TrimStrategy(Enum):
    """Strategy used to pick which messages are dropped when trimming.

    Member values are assigned by ``auto()`` in declaration order.
    """

    FIFO = auto()
    SUMMARIZE = auto()
    IMPORTANCE_WEIGHTED = auto()
    TOKEN_BUDGET = auto()

32 

33 

@dataclass
class Message:
    """A single message in a conversation.

    When ``message_id`` is left empty, ``__post_init__`` derives one from
    the role value, the first 50 characters of the content and the
    timestamp.
    """

    role: MessageRole
    content: str
    # Defaults to the wall-clock time at construction.
    timestamp: float = field(default_factory=time.time)
    token_count: int = 0
    importance: float = 1.0
    # A fresh dict per instance (never shared between messages).
    metadata: dict = field(default_factory=dict)
    message_id: str = ""

    def __post_init__(self):
        """Fill in ``message_id`` when the caller did not supply one."""
        if self.message_id:
            return
        seed = f"{self.role.value}:{self.content[:50]}:{self.timestamp}"
        digest = hashlib.md5(seed.encode()).hexdigest()
        # Only the first 12 hex characters of the digest are kept as the id.
        self.message_id = digest[:12]

50 

51 

@dataclass
class ConversationConfig:
    """Configuration for conversation management."""

    # Message-count limit; the manager trims once it is exceeded.
    max_messages: int = 50
    # Token limit; the manager trims once the running total exceeds it.
    max_tokens: int = 8000
    trim_strategy: TrimStrategy = TrimStrategy.FIFO
    # When true, system-role messages are skipped by the trimming routines.
    preserve_system: bool = True
    # Number of trailing messages the trimming routines leave untouched.
    preserve_last_n: int = 4
    # The manager only checks whether this is non-empty before refreshing
    # the summary under the SUMMARIZE strategy.
    summary_prompt: str = ""
    # Multiplied by max_tokens to get the target of TOKEN_BUDGET trimming.
    auto_summarize_threshold: float = 0.75
    # Optional custom token counter; the manager falls back to
    # ``len(text) // 3`` when this is not set.
    token_counter: Callable[[str], int] | None = None

64 

65 

@dataclass
class ConversationStats:
    """Counters and timestamps describing a conversation."""

    total_messages: int = 0
    total_tokens: int = 0
    trim_count: int = 0
    summarize_count: int = 0
    branch_count: int = 0
    # 0.0 means "not recorded yet".
    oldest_timestamp: float = 0.0
    newest_timestamp: float = 0.0

77 

78 

@dataclass
class ConversationSnapshot:
    """Snapshot of a conversation, used for branching and restoring."""

    messages: list[Message]
    stats: ConversationStats
    snapshot_id: str
    # Defaults to the wall-clock time at construction.
    created_at: float = field(default_factory=time.time)
    label: str = ""

88 

89 

class ConversationManager:
    """Multi-turn conversation context manager.

    Core features:
    - Sliding window: trims automatically once ``max_messages`` or
      ``max_tokens`` is exceeded.
    - Auto summary: refreshes a summary string after trimming under the
      SUMMARIZE strategy.
    - Branching: ``fork`` stores a snapshot; branches can be switched to
      or merged back.
    - Token awareness: trims against a token budget.
    """

    def __init__(self, config: ConversationConfig | None = None):
        """Create an empty conversation.

        Args:
            config: Limits and trimming options; a default
                ``ConversationConfig`` is used when omitted.
        """
        self.config = config or ConversationConfig()
        self._messages: list[Message] = []
        self._summary: str = ""
        # Initialised here so _update_summary can test it before
        # set_summarizer has ever been called.
        self._summarizer: Callable[[list[Message]], str] | None = None
        self.stats = ConversationStats()
        self._branches: dict[str, ConversationSnapshot] = {}
        self._current_branch: str = "main"
        self._message_counter: int = 0

    # ── Message management ────────────────────────────────────

    def add(self, role: MessageRole | str, content: str, **meta) -> Message:
        """Append one message and run the limit check.

        Args:
            role: A ``MessageRole`` or its string value.
            content: Message text.
            **meta: Stored as the message's ``metadata`` dict.

        Returns:
            The newly created ``Message``.

        Raises:
            ValueError: If ``role`` is a string that is not a valid
                ``MessageRole`` value.
        """
        if isinstance(role, str):
            role = MessageRole(role)
        token_count = self._count_tokens(content)
        msg = Message(
            role=role,
            content=content,
            token_count=token_count,
            message_id=self._next_id(),
            metadata=meta,
        )
        self._messages.append(msg)
        self.stats.total_messages += 1
        self.stats.total_tokens += token_count
        if not self.stats.oldest_timestamp:
            self.stats.oldest_timestamp = msg.timestamp
        self.stats.newest_timestamp = msg.timestamp
        self._enforce_limits()
        return msg

    def add_many(self, messages: list[tuple[str, str]]) -> list[Message]:
        """Add several ``(role, content)`` pairs in order via ``add``."""
        return [self.add(role, content) for role, content in messages]

    def get_context(
        self, include_summary: bool = True, limit: int | None = None
    ) -> list[dict]:
        """Return the current context as ``{"role", "content"}`` dicts.

        Args:
            include_summary: When true and a summary exists, it is emitted
                first as a system message.
            limit: Keep only the last ``limit`` messages. A falsy value
                (``None`` or ``0``) returns every message.

        Returns:
            A new list of dicts in OpenAI-compatible chat-message format.
        """
        result: list[dict] = []
        if include_summary and self._summary:
            result.append({"role": "system", "content": f"[对话摘要] {self._summary}"})
        msgs = self._messages[-limit:] if limit else self._messages
        result.extend({"role": msg.role.value, "content": msg.content} for msg in msgs)
        return result

    def get_system_prompt(self) -> str:
        """Return the content of the first system message, or ``""``."""
        for msg in self._messages:
            if msg.role == MessageRole.SYSTEM:
                return msg.content
        return ""

    # ── Trimming and compression ──────────────────────────────

    def _enforce_limits(self):
        """Trim until both limits hold or nothing more can be removed.

        Each loop stops as soon as a trim attempt removes nothing;
        otherwise a configuration in which every remaining message is
        protected would spin forever.
        """
        changed = False
        while len(self._messages) > self.config.max_messages:
            if not self._trim_one():
                break
            changed = True
        while self.stats.total_tokens > self.config.max_tokens:
            if not self._trim_one():
                break
            changed = True
        if (
            changed
            and self.config.trim_strategy == TrimStrategy.SUMMARIZE
            # A non-empty prompt enables summaries (original gate); an
            # installed summarizer callback enables them as well.
            and (self.config.summary_prompt or self._summarizer is not None)
        ):
            self._update_summary()

    def _trim_one(self) -> bool:
        """Run one trimming step for the configured strategy.

        Returns:
            True if at least one message was removed.
        """
        strategy = self.config.trim_strategy
        if strategy == TrimStrategy.IMPORTANCE_WEIGHTED:
            return self._trim_lowest_importance()
        if strategy == TrimStrategy.TOKEN_BUDGET:
            # Budget trimming is a no-op when tokens are already within
            # budget (e.g. only the message-count limit is exceeded), so
            # fall back to removing the oldest removable message.
            return self._trim_token_budget() or self._trim_fifo()
        # FIFO, SUMMARIZE and anything else use first-in-first-out.
        return self._trim_fifo()

    def _drop_at(self, index: int) -> Message:
        """Remove the message at ``index`` and update token/trim stats."""
        msg = self._messages.pop(index)
        self.stats.total_tokens -= msg.token_count
        self.stats.trim_count += 1
        return msg

    def _trim_fifo(self) -> bool:
        """Remove the oldest message that is not protected.

        System messages are skipped when ``preserve_system`` is set, and
        the last ``preserve_last_n`` positions are never removed.

        Returns:
            True if a message was removed.
        """
        preserve = self.config.preserve_last_n
        total = len(self._messages)
        for i, msg in enumerate(self._messages):
            if self.config.preserve_system and msg.role == MessageRole.SYSTEM:
                continue
            if total - i <= preserve:
                break
            self._drop_at(i)
            return True
        return False

    def _trim_lowest_importance(self) -> bool:
        """Remove the least important message that is not protected.

        System messages are excluded when ``preserve_system`` is set; of
        the remaining candidates the last ``preserve_last_n`` are kept.

        Returns:
            True if a message was removed.
        """
        preserve = self.config.preserve_last_n
        candidates = [
            (i, m)
            for i, m in enumerate(self._messages)
            if not (self.config.preserve_system and m.role == MessageRole.SYSTEM)
        ]
        if len(candidates) <= preserve:
            return False
        # Slice by explicit length: ``candidates[:-preserve]`` would be
        # empty for preserve == 0 and make min() raise ValueError.
        eligible = candidates[: len(candidates) - preserve]
        if not eligible:
            return False
        idx, _ = min(eligible, key=lambda pair: pair[1].importance)
        self._drop_at(idx)
        return True

    def _trim_token_budget(self) -> bool:
        """Trim oldest messages until the token budget is met.

        The budget is ``max_tokens * auto_summarize_threshold``. Trimming
        stops once the message count is down to ``preserve_last_n`` plus
        the number of preserved system messages.

        Returns:
            True if at least one message was removed.
        """
        budget = int(self.config.max_tokens * self.config.auto_summarize_threshold)
        preserve_last = self.config.preserve_last_n
        system_count = sum(
            1
            for m in self._messages
            if m.role == MessageRole.SYSTEM and self.config.preserve_system
        )
        removed_any = False
        while (
            self.stats.total_tokens > budget
            and len(self._messages) > preserve_last + system_count
        ):
            if not self._trim_fifo():
                break
            removed_any = True
        return removed_any

    def _update_summary(self):
        """Refresh the conversation summary and count the refresh.

        Uses the callback installed with ``set_summarizer`` when there is
        one, passing it a copy of the currently retained messages;
        otherwise falls back to a placeholder with message/token counts.
        """
        if self._summarizer is not None:
            self._summary = self._summarizer(list(self._messages))
        else:
            self._summary = f"[共 {len(self._messages)} 条消息, {self.stats.total_tokens} tokens]"
        self.stats.summarize_count += 1

    def set_summarizer(self, callback: Callable[[list[Message]], str]):
        """Install the callback used to produce summaries."""
        self._summarizer = callback

    # ── Branching ─────────────────────────────────────────────

    @staticmethod
    def _copy_stats(stats: ConversationStats) -> ConversationStats:
        """Return an independent copy of ``stats``."""
        return ConversationStats(
            total_messages=stats.total_messages,
            total_tokens=stats.total_tokens,
            trim_count=stats.trim_count,
            summarize_count=stats.summarize_count,
            branch_count=stats.branch_count,
            oldest_timestamp=stats.oldest_timestamp,
            newest_timestamp=stats.newest_timestamp,
        )

    def fork(self, label: str = "") -> ConversationSnapshot:
        """Store a snapshot of the current conversation as a branch.

        Args:
            label: Optional label; defaults to ``branch-<snapshot id>``.

        Returns:
            The stored snapshot. Its message list is a shallow copy and
            its stats are copied before ``branch_count`` is incremented.
        """
        import uuid

        sid = uuid.uuid4().hex[:8]
        snapshot = ConversationSnapshot(
            messages=list(self._messages),
            stats=self._copy_stats(self.stats),
            snapshot_id=sid,
            label=label or f"branch-{sid}",
        )
        self._branches[sid] = snapshot
        self.stats.branch_count += 1
        return snapshot

    def switch_branch(self, snapshot_id: str):
        """Replace the current messages and stats with a stored branch.

        The summary string is not part of a snapshot and is left as is.

        Raises:
            KeyError: If no branch with ``snapshot_id`` exists.
        """
        snapshot = self._branches.get(snapshot_id)
        if snapshot is None:
            raise KeyError(f"Branch '{snapshot_id}' not found")
        self._messages = list(snapshot.messages)
        self.stats = self._copy_stats(snapshot.stats)
        self._current_branch = snapshot_id

    def merge_branch(self, snapshot_id: str, strategy: str = "append"):
        """Merge a stored branch into the current conversation.

        Args:
            snapshot_id: Id of the branch to merge.
            strategy: ``"append"`` adds branch messages whose id is not
                already present; ``"replace"`` swaps in the branch's
                message list.

        Raises:
            KeyError: If no branch with ``snapshot_id`` exists.
            ValueError: If ``strategy`` is not ``"append"`` or ``"replace"``.
        """
        snapshot = self._branches.get(snapshot_id)
        if snapshot is None:
            raise KeyError(f"Branch '{snapshot_id}' not found")
        if strategy == "append":
            seen_ids = {m.message_id for m in self._messages}
            for msg in snapshot.messages:
                if msg.message_id in seen_ids:
                    continue
                seen_ids.add(msg.message_id)
                self._messages.append(msg)
                self.stats.total_messages += 1
                self.stats.total_tokens += msg.token_count
        elif strategy == "replace":
            self._messages = list(snapshot.messages)
            # Keep the running token total in step with the new message
            # list so the limit check below works on real numbers.
            self.stats.total_tokens = sum(m.token_count for m in self._messages)
        else:
            raise ValueError(f"Unknown merge strategy '{strategy}'")
        self._enforce_limits()

    def list_branches(self) -> dict[str, ConversationSnapshot]:
        """Return a shallow copy of the branch table keyed by snapshot id."""
        return dict(self._branches)

    # ── Utilities ─────────────────────────────────────────────

    def _count_tokens(self, text: str) -> int:
        """Estimate the token count of ``text``.

        Uses ``config.token_counter`` when set, else ``len(text) // 3``.
        """
        if self.config.token_counter:
            return self.config.token_counter(text)
        return len(text) // 3

    def _next_id(self) -> str:
        """Return the next sequential message id (``msg_000001``, ...)."""
        self._message_counter += 1
        return f"msg_{self._message_counter:06d}"

    def clear(self, keep_system: bool = True):
        """Clear the history, optionally keeping system messages.

        The summary is reset and stats start over; message and token
        totals are seeded from the retained system messages so they stay
        consistent with what is still stored.
        """
        kept = (
            [m for m in self._messages if m.role == MessageRole.SYSTEM]
            if keep_system
            else []
        )
        self._messages = kept
        self._summary = ""
        self.stats = ConversationStats(
            total_messages=len(kept),
            total_tokens=sum(m.token_count for m in kept),
            oldest_timestamp=kept[0].timestamp if kept else 0.0,
            newest_timestamp=kept[-1].timestamp if kept else 0.0,
        )

    @property
    def message_count(self) -> int:
        """Number of messages currently stored."""
        return len(self._messages)

    @property
    def token_count(self) -> int:
        """Running token total from ``stats``."""
        return self.stats.total_tokens

    def __len__(self) -> int:
        return len(self._messages)

    def __repr__(self) -> str:
        return f"<Conversation messages={len(self)} tokens={self.token_count} branches={len(self._branches)}>"

331 

332 

# Public API of this module.
__all__ = [
    "ConversationManager",
    "ConversationConfig",
    "ConversationStats",
    "ConversationSnapshot",
    "Message",
    "MessageRole",
    "TrimStrategy",
]