Coverage for agentos/conversation/conversation.py: 39%
205 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""AgentOS v1.3.10 - Conversation Manager module.

Multi-turn conversation context management: sliding window, automatic
summaries, conversation branching and token-aware trimming.
Intended for long-running sessions: keeps the context from overflowing while
retaining key information.
"""
7from __future__ import annotations
9import hashlib
10import time
11from dataclasses import dataclass, field
12from enum import Enum, auto
13from typing import Any, Callable
class MessageRole(str, Enum):
    """Author of a conversation message.

    The member value is the plain role string emitted in the ``role`` field
    of the context dicts built by the conversation manager.
    """

    SYSTEM = "system"  # system prompt / instructions
    USER = "user"  # end-user turn
    ASSISTANT = "assistant"  # model reply
    TOOL = "tool"  # tool / function output
class TrimStrategy(Enum):
    """How the conversation manager chooses messages to drop when a limit is exceeded."""

    # Drop the oldest message that is not protected from trimming.
    FIFO = auto()
    # Trim like FIFO, then refresh the conversation summary when summarizing is enabled.
    SUMMARIZE = auto()
    # Drop the unprotected message with the lowest ``importance`` value.
    IMPORTANCE_WEIGHTED = auto()
    # Drop the oldest unprotected messages until a token budget is met.
    TOKEN_BUDGET = auto()
@dataclass
class Message:
    """A single conversation message.

    If ``message_id`` is left empty, a 12-hex-character id is derived from the
    MD5 digest of ``"<role value>:<first 50 chars of content>:<timestamp>"``.
    """

    role: MessageRole
    content: str
    timestamp: float = field(default_factory=time.time)
    token_count: int = 0
    # Relative weight; the importance-weighted trimmer drops the lowest value first.
    importance: float = 1.0
    metadata: dict = field(default_factory=dict)
    message_id: str = ""

    def __post_init__(self):
        # An explicitly supplied id always wins.
        if self.message_id:
            return
        seed = ":".join((str(self.role.value), self.content[:50], str(self.timestamp)))
        digest = hashlib.md5(seed.encode())
        self.message_id = digest.hexdigest()[:12]
@dataclass
class ConversationConfig:
    """Settings that drive the conversation manager's trimming and summaries."""

    # Trimming starts once the number of held messages exceeds this ...
    max_messages: int = 50
    # ... or once the running token total exceeds this.
    max_tokens: int = 8000
    # How a message is chosen for removal when a limit is exceeded.
    trim_strategy: TrimStrategy = TrimStrategy.FIFO
    # When True, SYSTEM messages are never trimmed.
    preserve_system: bool = True
    # The newest messages (roughly the last N) are protected from trimming.
    preserve_last_n: int = 4
    # A non-empty value enables summary refreshes under TrimStrategy.SUMMARIZE.
    summary_prompt: str = ""
    # NOTE(review): in the visible manager code this is only used as the
    # fraction of max_tokens that TrimStrategy.TOKEN_BUDGET trims down to.
    auto_summarize_threshold: float = 0.75
    # Custom token counter; when None the manager estimates len(text) // 3.
    token_counter: Callable[[str], int] | None = None
@dataclass
class ConversationStats:
    """Counters describing a conversation's history."""

    # Incremented for every message added or merged in; trimming does not decrement it.
    total_messages: int = 0
    # Running token total; trimming subtracts the removed messages' tokens.
    total_tokens: int = 0
    # Number of messages removed by trimming.
    trim_count: int = 0
    # Number of summary refreshes.
    summarize_count: int = 0
    # Number of forks created.
    branch_count: int = 0
    # Message timestamps (0.0 until the first message arrives).
    oldest_timestamp: float = 0.0
    newest_timestamp: float = 0.0
@dataclass
class ConversationSnapshot:
    """Point-in-time copy of a conversation, used to fork and restore branches."""

    messages: list[Message]  # shallow copy of the message list taken at fork time
    stats: ConversationStats  # stats captured at fork time
    snapshot_id: str  # key the manager stores this branch under
    created_at: float = field(default_factory=time.time)
    label: str = ""  # human-readable branch name
class ConversationManager:
    """Multi-turn conversation context manager.

    Core features:
    - Sliding window: trims history once ``max_messages`` / ``max_tokens`` is
      exceeded, using the configured ``TrimStrategy``.
    - Summaries: under ``TrimStrategy.SUMMARIZE`` the summary is refreshed
      after trimming, through the callback registered with ``set_summarizer``
      when one is set.
    - Branching: ``fork`` snapshots the conversation; snapshots can be switched
      to or merged back.
    - Token awareness: ``TrimStrategy.TOKEN_BUDGET`` trims down to a token budget.

    Limits are best-effort. System messages (when ``preserve_system``) and the
    newest ``preserve_last_n`` messages are never trimmed. When only such
    messages remain, the history is left above the limit instead of trimming
    forever.
    """

    def __init__(self, config: ConversationConfig | None = None):
        self.config = config or ConversationConfig()
        self._messages: list[Message] = []
        self._summary: str = ""
        self.stats = ConversationStats()
        self._branches: dict[str, ConversationSnapshot] = {}
        self._current_branch: str = "main"
        self._message_counter: int = 0
        # Summarizer registered via set_summarizer(); None means the built-in placeholder.
        self._summarizer: Callable[[list[Message]], str] | None = None

    # ── Message management ────────────────────────────────────

    def add(self, role: MessageRole | str, content: str, **meta) -> Message:
        """Append a message, then enforce the configured limits.

        Args:
            role: A MessageRole or its string value (e.g. ``"user"``).
            content: Message text; its token count is computed on insert.
            **meta: Stored as the message's ``metadata`` dict.

        Returns:
            The created Message. Depending on the configuration it may already
            have been trimmed from the history.

        Raises:
            ValueError: if ``role`` is a string that is not a MessageRole value.
        """
        if isinstance(role, str):
            role = MessageRole(role)
        token_count = self._count_tokens(content)
        msg = Message(
            role=role,
            content=content,
            token_count=token_count,
            message_id=self._next_id(),
            metadata=meta,
        )
        self._messages.append(msg)
        self.stats.total_messages += 1
        self.stats.total_tokens += token_count
        if not self.stats.oldest_timestamp:
            self.stats.oldest_timestamp = msg.timestamp
        self.stats.newest_timestamp = msg.timestamp
        self._enforce_limits()
        return msg

    def add_many(self, messages: list[tuple[str, str]]) -> list[Message]:
        """Add ``(role, content)`` pairs in order; returns the created messages."""
        return [self.add(role, content) for role, content in messages]

    def get_context(
        self, include_summary: bool = True, limit: int | None = None
    ) -> list[dict]:
        """Return the current context as OpenAI-style ``{"role", "content"}`` dicts.

        Args:
            include_summary: Prepend the conversation summary, if any, as a
                system entry.
            limit: Keep only the newest ``limit`` messages. ``None`` keeps all;
                ``0`` or a negative value keeps none. The summary entry is not
                counted.
        """
        result: list[dict] = []
        if include_summary and self._summary:
            result.append({"role": "system", "content": f"[对话摘要] {self._summary}"})
        if limit is None:
            msgs = self._messages
        elif limit > 0:
            msgs = self._messages[-limit:]
        else:
            # Previously 0 fell through to "all messages" because 0 is falsy.
            msgs = []
        result.extend({"role": msg.role.value, "content": msg.content} for msg in msgs)
        return result

    def get_system_prompt(self) -> str:
        """Return the content of the first SYSTEM message, or ``""`` if there is none."""
        return next(
            (msg.content for msg in self._messages if msg.role == MessageRole.SYSTEM),
            "",
        )

    # ── Trimming and summaries ────────────────────────────────

    def _enforce_limits(self):
        """Trim until both limits hold or nothing trimmable is left.

        Refreshes the summary afterwards under TrimStrategy.SUMMARIZE when
        summarizing is enabled (``summary_prompt`` set or a summarizer
        registered).
        """
        trimmed: list[Message] = []
        while (
            len(self._messages) > self.config.max_messages
            or self.stats.total_tokens > self.config.max_tokens
        ):
            removed = self._trim_one()
            if not removed:
                # Only protected messages remain; stop instead of spinning forever.
                break
            trimmed.extend(removed)
        if (
            trimmed
            and self.config.trim_strategy == TrimStrategy.SUMMARIZE
            and (self.config.summary_prompt or self._summarizer is not None)
        ):
            self._update_summary(trimmed)

    def _trim_one(self) -> list[Message]:
        """Apply the configured strategy once; returns the removed messages (may be empty)."""
        strategy = self.config.trim_strategy
        if strategy == TrimStrategy.IMPORTANCE_WEIGHTED:
            return self._trim_lowest_importance()
        if strategy == TrimStrategy.TOKEN_BUDGET:
            return self._trim_token_budget()
        # FIFO, SUMMARIZE and any other strategy trim oldest-first.
        return self._trim_fifo()

    def _oldest_trimmable_index(self) -> int | None:
        """Index of the oldest message outside the protected tail, or None.

        SYSTEM messages are skipped when ``preserve_system`` is set. The last
        ``preserve_last_n`` positions (system messages included) are protected.
        """
        keep_tail = self.config.preserve_last_n
        total = len(self._messages)
        for i, msg in enumerate(self._messages):
            if self.config.preserve_system and msg.role == MessageRole.SYSTEM:
                continue
            if total - i <= keep_tail:
                return None
            return i
        return None

    def _remove_at(self, index: int) -> Message:
        """Pop the message at ``index`` and update token and trim statistics."""
        msg = self._messages.pop(index)
        self.stats.total_tokens -= msg.token_count
        self.stats.trim_count += 1
        return msg

    def _trim_fifo(self) -> list[Message]:
        """First-in-first-out: remove the oldest unprotected message."""
        index = self._oldest_trimmable_index()
        if index is None:
            return []
        return [self._remove_at(index)]

    def _trim_lowest_importance(self) -> list[Message]:
        """Remove the unprotected message with the lowest ``importance`` (oldest on ties)."""
        keep_tail = max(self.config.preserve_last_n, 0)
        candidates = [
            (i, msg)
            for i, msg in enumerate(self._messages)
            if not (self.config.preserve_system and msg.role == MessageRole.SYSTEM)
        ]
        if len(candidates) <= keep_tail:
            return []
        if keep_tail:
            # Guarded: candidates[:-0] would be empty and min() would raise.
            candidates = candidates[:-keep_tail]
        index, _ = min(candidates, key=lambda pair: pair[1].importance)
        return [self._remove_at(index)]

    def _trim_token_budget(self) -> list[Message]:
        """Remove oldest unprotected messages until tokens fit the budget.

        The budget is ``max_tokens * auto_summarize_threshold``. At least one
        message is removed when possible. This method is also reached when the
        message-count limit is exceeded while tokens are already within budget;
        removing nothing then would make the caller loop forever.
        """
        budget = int(self.config.max_tokens * self.config.auto_summarize_threshold)
        protected = self.config.preserve_last_n
        if self.config.preserve_system:
            protected += sum(1 for m in self._messages if m.role == MessageRole.SYSTEM)
        removed: list[Message] = []
        while len(self._messages) > protected and (
            self.stats.total_tokens > budget or not removed
        ):
            index = self._oldest_trimmable_index()
            if index is None:
                break
            removed.append(self._remove_at(index))
        return removed

    def _update_summary(self, trimmed: list[Message] | None = None):
        """Refresh the conversation summary and bump ``summarize_count``.

        When a summarizer is registered, it is called with the messages just
        trimmed. If a summary already exists, it is passed first as a SYSTEM
        message so the summary can roll forward. Without a summarizer, a
        placeholder with the current message and token counts is stored.
        """
        if self._summarizer is not None:
            history: list[Message] = []
            if self._summary:
                history.append(
                    Message(role=MessageRole.SYSTEM, content=f"[对话摘要] {self._summary}")
                )
            history.extend(trimmed or [])
            self._summary = self._summarizer(history)
        else:
            self._summary = f"[共 {len(self._messages)} 条消息, {self.stats.total_tokens} tokens]"
        self.stats.summarize_count += 1

    def set_summarizer(self, callback: Callable[[list[Message]], str]):
        """Register the callback (e.g. an LLM call) that turns messages into a summary string."""
        self._summarizer = callback

    # ── Branching ─────────────────────────────────────────────

    @staticmethod
    def _copy_stats(stats: ConversationStats) -> ConversationStats:
        """Return an independent copy of ``stats``."""
        from dataclasses import replace

        return replace(stats)

    def _get_branch(self, snapshot_id: str) -> ConversationSnapshot:
        """Look up a branch snapshot.

        Raises:
            KeyError: if no branch with ``snapshot_id`` exists.
        """
        snapshot = self._branches.get(snapshot_id)
        if snapshot is None:
            raise KeyError(f"Branch '{snapshot_id}' not found")
        return snapshot

    def fork(self, label: str = "") -> ConversationSnapshot:
        """Snapshot the current messages and stats as a new branch and return it.

        The message list is copied shallowly; Message objects are shared with
        the live conversation.
        """
        import uuid

        sid = uuid.uuid4().hex[:8]
        while sid in self._branches:  # 8 hex chars can collide; never overwrite a branch
            sid = uuid.uuid4().hex[:8]
        snapshot = ConversationSnapshot(
            messages=list(self._messages),
            stats=self._copy_stats(self.stats),
            snapshot_id=sid,
            label=label or f"branch-{sid}",
        )
        self._branches[sid] = snapshot
        self.stats.branch_count += 1
        return snapshot

    def switch_branch(self, snapshot_id: str):
        """Replace the current messages and stats with those of a branch.

        ``branch_count`` keeps its live value, because every fork still exists
        in the branch table. The summary is not part of a snapshot and is left
        unchanged.

        Raises:
            KeyError: if the branch does not exist.
        """
        snapshot = self._get_branch(snapshot_id)
        live_branch_count = self.stats.branch_count
        self._messages = list(snapshot.messages)
        self.stats = self._copy_stats(snapshot.stats)
        self.stats.branch_count = live_branch_count
        self._current_branch = snapshot_id

    def merge_branch(self, snapshot_id: str, strategy: str = "append"):
        """Merge a branch into the current conversation, then enforce limits.

        Args:
            snapshot_id: Branch to merge.
            strategy: ``"append"`` adds branch messages whose ids are not
                already present; ``"replace"`` adopts the branch's message list.

        Raises:
            KeyError: if the branch does not exist.
            ValueError: if ``strategy`` is not recognised.
        """
        snapshot = self._get_branch(snapshot_id)
        if strategy == "append":
            existing_ids = {m.message_id for m in self._messages}
            for msg in snapshot.messages:
                if msg.message_id in existing_ids:
                    continue
                existing_ids.add(msg.message_id)
                self._messages.append(msg)
                self.stats.total_messages += 1
                self.stats.total_tokens += msg.token_count
        elif strategy == "replace":
            self._messages = list(snapshot.messages)
            # Limits are enforced on total_tokens, so it must match the new list.
            self.stats.total_tokens = sum(m.token_count for m in self._messages)
        else:
            raise ValueError(f"Unknown merge strategy: {strategy!r}")
        self._enforce_limits()

    def list_branches(self) -> dict[str, ConversationSnapshot]:
        """Return a copy of the branch table (snapshot id -> snapshot)."""
        return dict(self._branches)

    # ── Utilities ─────────────────────────────────────────────

    def _count_tokens(self, text: str) -> int:
        """Count tokens with the configured counter, else estimate ``len(text) // 3``."""
        if self.config.token_counter:
            return self.config.token_counter(text)
        return len(text) // 3

    def _next_id(self) -> str:
        """Return the next sequential message id (``msg_000001``, ...)."""
        self._message_counter += 1
        return f"msg_{self._message_counter:06d}"

    def clear(self, keep_system: bool = True):
        """Clear the history (optionally keeping SYSTEM messages) and the summary.

        Stats are rebuilt from the kept messages so token totals stay
        consistent. ``branch_count`` is kept because branches are not cleared.
        """
        kept = [m for m in self._messages if m.role == MessageRole.SYSTEM] if keep_system else []
        self._messages = kept
        self._summary = ""
        self.stats = ConversationStats(
            total_messages=len(kept),
            total_tokens=sum(m.token_count for m in kept),
            branch_count=self.stats.branch_count,
            oldest_timestamp=min((m.timestamp for m in kept), default=0.0),
            newest_timestamp=max((m.timestamp for m in kept), default=0.0),
        )

    @property
    def message_count(self) -> int:
        """Number of messages currently held."""
        return len(self._messages)

    @property
    def token_count(self) -> int:
        """Running token total of the held messages."""
        return self.stats.total_tokens

    def __len__(self) -> int:
        return len(self._messages)

    def __repr__(self) -> str:
        return f"<Conversation messages={len(self)} tokens={self.token_count} branches={len(self._branches)}>"
__all__ = [
    # Manager plus its configuration / state containers.
    "ConversationManager",
    "ConversationConfig",
    "ConversationStats",
    "ConversationSnapshot",
    # Message model and enums.
    "Message",
    "MessageRole",
    "TrimStrategy",
]