Coverage for memory / session.py: 87%
126 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-29 02:55 +0800
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-29 02:55 +0800
1import json
2import uuid
3from datetime import datetime
4from pathlib import Path
5from qrclaw.logger import get_logger
6from qrclaw.config import OPENAI_MODEL
8logger = get_logger("qrclaw.memory.session")
10# 初始化 tiktoken encoder
11import tiktoken
12try:
13 _encoding = tiktoken.encoding_for_model(OPENAI_MODEL)
14except KeyError:
15 _encoding = tiktoken.get_encoding("cl100k_base")
16 logger.debug(f"模型 {OPENAI_MODEL} 无对应 encoder,使用 cl100k_base")
19def list_sessions(sessions_dir: Path) -> list[dict]:
20 """
21 列出指定目录下所有已保存的会话。
23 Args:
24 sessions_dir: 会话文件目录(由 Workspace 提供)
25 Returns:
26 list[dict]: 每项包含 id、message_count、updated_at
27 """
28 sessions_dir.mkdir(parents=True, exist_ok=True)
29 sessions = []
30 for path in sorted(sessions_dir.glob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True):
31 try:
32 data = json.loads(path.read_text(encoding="utf-8"))
33 msg_count = len([m for m in data if m.get("role") != "system"])
34 mtime = datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d %H:%M")
35 sessions.append({
36 "id": path.stem,
37 "message_count": msg_count,
38 "updated_at": mtime,
39 })
40 except Exception:
41 pass
42 return sessions
45def get_last_session_id(sessions_dir: Path) -> str | None:
46 """
47 获取最近使用的会话 ID(按修改时间排序,取最新的)。
49 Args:
50 sessions_dir: 会话文件目录
51 Returns:
52 str | None: 会话 ID,如果没有会话则返回 None
53 """
54 sessions_dir.mkdir(parents=True, exist_ok=True)
55 sessions = sorted(sessions_dir.glob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True)
56 if sessions:
57 return sessions[0].stem
58 return None
61def delete_session(session_id: str, sessions_dir: Path) -> bool:
62 """删除指定会话文件,返回是否成功。"""
63 path = sessions_dir / f"{session_id}.json"
64 if path.exists():
65 path.unlink()
66 logger.info(f"删除会话: {session_id}")
67 return True
68 return False
71def count_tokens(messages: list[dict]) -> int:
72 """
73 精确计算消息列表的 token 数(使用 tiktoken)。
75 Args:
76 messages: OpenAI 格式的消息列表
77 Returns:
78 int: token 总数
79 """
80 tokens = 0
81 for msg in messages:
82 # 每条消息有固定开销
83 tokens += 4 # {"role": "...", "content": "..."} 格式开销
84 for key, value in msg.items():
85 if value is not None:
86 tokens += len(_encoding.encode(str(value)))
87 tokens += 2 # 对话开销
88 return tokens
91class Session:
92 def __init__(self, sessions_dir: Path, session_id: str = None, resume: bool = True):
93 """
94 初始化会话。
96 Args:
97 sessions_dir: 会话文件目录(由 Workspace 提供)
98 session_id: 指定会话 ID,为 None 时自动选择
99 resume: 是否恢复最近的会话(仅在 session_id 为 None 时生效)
100 """
101 # 如果指定了 session_id,直接使用
102 if session_id is not None:
103 self.session_id = session_id
104 # 否则根据 resume 决定是恢复最近会话还是创建新会话
105 elif resume:
106 # 尝试恢复最近的会话
107 last_id = get_last_session_id(sessions_dir)
108 if last_id:
109 self.session_id = last_id
110 logger.info(f"恢复最近的会话: {last_id}")
111 else:
112 # 没有历史会话,创建新的
113 short = uuid.uuid4().hex[:8]
114 self.session_id = f"{datetime.now().strftime('%Y%m%d')}-{short}"
115 logger.info(f"创建新会话: {self.session_id}")
116 else:
117 # 不恢复,创建新会话
118 short = uuid.uuid4().hex[:8]
119 self.session_id = f"{datetime.now().strftime('%Y%m%d')}-{short}"
120 logger.info(f"创建新会话: {self.session_id}")
122 self.messages: list[dict] = []
124 # 上下文使用情况
125 self.prompt_tokens = 0
126 self.completion_tokens = 0
127 self.total_tokens = 0
129 # 当前活跃计划
130 self.active_plan: dict | None = None
132 # 会话文件路径(由 Workspace 提供的目录决定)
133 sessions_dir.mkdir(parents=True, exist_ok=True)
134 self._path = sessions_dir / f"{self.session_id}.json"
136 logger.debug(f"初始化会话: {self.session_id}, 路径: {self._path}")
138 # 启动时加载历史
139 self._load()
141 def add(self, message: dict):
142 """追加一条消息,并立即存盘"""
143 self.messages.append(message)
144 self._save()
145 logger.debug(f"添加消息: {message.get('role', 'unknown')}, 当前会话消息数: {len(self.messages)}")
147 def update_tokens(self, prompt_tokens: int, completion_tokens: int, total_tokens: int):
148 """更新 token 使用情况"""
149 self.prompt_tokens = prompt_tokens
150 self.completion_tokens = completion_tokens
151 self.total_tokens = total_tokens
152 logger.debug(f"更新 token: prompt={prompt_tokens}, completion={completion_tokens}, total={total_tokens}")
154 def set_plan(self, goal: str, steps: list[dict]):
155 """设置当前活跃计划"""
156 self.active_plan = {
157 "goal": goal,
158 "steps": [{"id": s["id"], "description": s["description"], "done": False} for s in steps],
159 }
160 logger.info(f"设置执行计划: {goal}, 共 {len(steps)} 步")
162 def complete_step(self, step_id: int) -> bool:
163 """标记某步骤为已完成,返回是否全部完成"""
164 if not self.active_plan:
165 return False
166 for step in self.active_plan["steps"]:
167 if step["id"] == step_id:
168 step["done"] = True
169 logger.info(f"计划步骤 {step_id} 已完成")
170 break
171 all_done = all(s["done"] for s in self.active_plan["steps"])
172 if all_done:
173 logger.info("所有计划步骤已完成,清空计划")
174 self.active_plan = None
175 return all_done
177 def clear_plan(self):
178 """清空当前计划"""
179 self.active_plan = None
180 logger.info("计划已清空")
182 def clear(self):
183 """清空当前会话"""
184 logger.info(f"清除会话: {self.session_id}")
185 self.messages = []
186 self.prompt_tokens = 0
187 self.completion_tokens = 0
188 self.total_tokens = 0
189 if self._path.exists():
190 self._path.unlink()
191 logger.debug(f"删除会话文件: {self._path}")
193 def _save(self):
194 try:
195 self._path.write_text(
196 json.dumps(self.messages, ensure_ascii=False, indent=2),
197 encoding="utf-8"
198 )
199 logger.debug(f"会话保存成功: {self._path}")
200 except Exception as e:
201 logger.error(f"会话保存失败: {e}", exc_info=True)
202 raise
204 def _load(self):
205 if self._path.exists():
206 try:
207 data = json.loads(self._path.read_text(encoding="utf-8"))
208 # 过滤掉旧历史里的 system 消息,system prompt 由 agent 实时生成
209 self.messages = [m for m in data if m.get("role") != "system"]
210 # 精确计算已加载消息的 token 数
211 self.prompt_tokens = count_tokens(self.messages)
212 logger.info(f"加载历史会话: {len(self.messages)} 条消息, {self.prompt_tokens} tokens")
213 except Exception as e:
214 logger.error(f"加载会话失败: {e}", exc_info=True)
215 self.messages = []
216 else:
217 logger.debug("未找到历史会话,创建新会话")