Coverage for memory / session.py: 87%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-29 02:55 +0800

1import json 

2import uuid 

3from datetime import datetime 

4from pathlib import Path 

5from qrclaw.logger import get_logger 

6from qrclaw.config import OPENAI_MODEL 

7 

8logger = get_logger("qrclaw.memory.session") 

9 

10# 初始化 tiktoken encoder 

11import tiktoken 

12try: 

13 _encoding = tiktoken.encoding_for_model(OPENAI_MODEL) 

14except KeyError: 

15 _encoding = tiktoken.get_encoding("cl100k_base") 

16 logger.debug(f"模型 {OPENAI_MODEL} 无对应 encoder,使用 cl100k_base") 

17 

18 

19def list_sessions(sessions_dir: Path) -> list[dict]: 

20 """ 

21 列出指定目录下所有已保存的会话。 

22 

23 Args: 

24 sessions_dir: 会话文件目录(由 Workspace 提供) 

25 Returns: 

26 list[dict]: 每项包含 id、message_count、updated_at 

27 """ 

28 sessions_dir.mkdir(parents=True, exist_ok=True) 

29 sessions = [] 

30 for path in sorted(sessions_dir.glob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True): 

31 try: 

32 data = json.loads(path.read_text(encoding="utf-8")) 

33 msg_count = len([m for m in data if m.get("role") != "system"]) 

34 mtime = datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d %H:%M") 

35 sessions.append({ 

36 "id": path.stem, 

37 "message_count": msg_count, 

38 "updated_at": mtime, 

39 }) 

40 except Exception: 

41 pass 

42 return sessions 

43 

44 

45def get_last_session_id(sessions_dir: Path) -> str | None: 

46 """ 

47 获取最近使用的会话 ID(按修改时间排序,取最新的)。 

48 

49 Args: 

50 sessions_dir: 会话文件目录 

51 Returns: 

52 str | None: 会话 ID,如果没有会话则返回 None 

53 """ 

54 sessions_dir.mkdir(parents=True, exist_ok=True) 

55 sessions = sorted(sessions_dir.glob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True) 

56 if sessions: 

57 return sessions[0].stem 

58 return None 

59 

60 

61def delete_session(session_id: str, sessions_dir: Path) -> bool: 

62 """删除指定会话文件,返回是否成功。""" 

63 path = sessions_dir / f"{session_id}.json" 

64 if path.exists(): 

65 path.unlink() 

66 logger.info(f"删除会话: {session_id}") 

67 return True 

68 return False 

69 

70 

71def count_tokens(messages: list[dict]) -> int: 

72 """ 

73 精确计算消息列表的 token 数(使用 tiktoken)。 

74 

75 Args: 

76 messages: OpenAI 格式的消息列表 

77 Returns: 

78 int: token 总数 

79 """ 

80 tokens = 0 

81 for msg in messages: 

82 # 每条消息有固定开销 

83 tokens += 4 # {"role": "...", "content": "..."} 格式开销 

84 for key, value in msg.items(): 

85 if value is not None: 

86 tokens += len(_encoding.encode(str(value))) 

87 tokens += 2 # 对话开销 

88 return tokens 

89 

90 

91class Session: 

92 def __init__(self, sessions_dir: Path, session_id: str = None, resume: bool = True): 

93 """ 

94 初始化会话。 

95 

96 Args: 

97 sessions_dir: 会话文件目录(由 Workspace 提供) 

98 session_id: 指定会话 ID,为 None 时自动选择 

99 resume: 是否恢复最近的会话(仅在 session_id 为 None 时生效) 

100 """ 

101 # 如果指定了 session_id,直接使用 

102 if session_id is not None: 

103 self.session_id = session_id 

104 # 否则根据 resume 决定是恢复最近会话还是创建新会话 

105 elif resume: 

106 # 尝试恢复最近的会话 

107 last_id = get_last_session_id(sessions_dir) 

108 if last_id: 

109 self.session_id = last_id 

110 logger.info(f"恢复最近的会话: {last_id}") 

111 else: 

112 # 没有历史会话,创建新的 

113 short = uuid.uuid4().hex[:8] 

114 self.session_id = f"{datetime.now().strftime('%Y%m%d')}-{short}" 

115 logger.info(f"创建新会话: {self.session_id}") 

116 else: 

117 # 不恢复,创建新会话 

118 short = uuid.uuid4().hex[:8] 

119 self.session_id = f"{datetime.now().strftime('%Y%m%d')}-{short}" 

120 logger.info(f"创建新会话: {self.session_id}") 

121 

122 self.messages: list[dict] = [] 

123 

124 # 上下文使用情况 

125 self.prompt_tokens = 0 

126 self.completion_tokens = 0 

127 self.total_tokens = 0 

128 

129 # 当前活跃计划 

130 self.active_plan: dict | None = None 

131 

132 # 会话文件路径(由 Workspace 提供的目录决定) 

133 sessions_dir.mkdir(parents=True, exist_ok=True) 

134 self._path = sessions_dir / f"{self.session_id}.json" 

135 

136 logger.debug(f"初始化会话: {self.session_id}, 路径: {self._path}") 

137 

138 # 启动时加载历史 

139 self._load() 

140 

141 def add(self, message: dict): 

142 """追加一条消息,并立即存盘""" 

143 self.messages.append(message) 

144 self._save() 

145 logger.debug(f"添加消息: {message.get('role', 'unknown')}, 当前会话消息数: {len(self.messages)}") 

146 

147 def update_tokens(self, prompt_tokens: int, completion_tokens: int, total_tokens: int): 

148 """更新 token 使用情况""" 

149 self.prompt_tokens = prompt_tokens 

150 self.completion_tokens = completion_tokens 

151 self.total_tokens = total_tokens 

152 logger.debug(f"更新 token: prompt={prompt_tokens}, completion={completion_tokens}, total={total_tokens}") 

153 

154 def set_plan(self, goal: str, steps: list[dict]): 

155 """设置当前活跃计划""" 

156 self.active_plan = { 

157 "goal": goal, 

158 "steps": [{"id": s["id"], "description": s["description"], "done": False} for s in steps], 

159 } 

160 logger.info(f"设置执行计划: {goal}, 共 {len(steps)}") 

161 

162 def complete_step(self, step_id: int) -> bool: 

163 """标记某步骤为已完成,返回是否全部完成""" 

164 if not self.active_plan: 

165 return False 

166 for step in self.active_plan["steps"]: 

167 if step["id"] == step_id: 

168 step["done"] = True 

169 logger.info(f"计划步骤 {step_id} 已完成") 

170 break 

171 all_done = all(s["done"] for s in self.active_plan["steps"]) 

172 if all_done: 

173 logger.info("所有计划步骤已完成,清空计划") 

174 self.active_plan = None 

175 return all_done 

176 

177 def clear_plan(self): 

178 """清空当前计划""" 

179 self.active_plan = None 

180 logger.info("计划已清空") 

181 

182 def clear(self): 

183 """清空当前会话""" 

184 logger.info(f"清除会话: {self.session_id}") 

185 self.messages = [] 

186 self.prompt_tokens = 0 

187 self.completion_tokens = 0 

188 self.total_tokens = 0 

189 if self._path.exists(): 

190 self._path.unlink() 

191 logger.debug(f"删除会话文件: {self._path}") 

192 

193 def _save(self): 

194 try: 

195 self._path.write_text( 

196 json.dumps(self.messages, ensure_ascii=False, indent=2), 

197 encoding="utf-8" 

198 ) 

199 logger.debug(f"会话保存成功: {self._path}") 

200 except Exception as e: 

201 logger.error(f"会话保存失败: {e}", exc_info=True) 

202 raise 

203 

204 def _load(self): 

205 if self._path.exists(): 

206 try: 

207 data = json.loads(self._path.read_text(encoding="utf-8")) 

208 # 过滤掉旧历史里的 system 消息,system prompt 由 agent 实时生成 

209 self.messages = [m for m in data if m.get("role") != "system"] 

210 # 精确计算已加载消息的 token 数 

211 self.prompt_tokens = count_tokens(self.messages) 

212 logger.info(f"加载历史会话: {len(self.messages)} 条消息, {self.prompt_tokens} tokens") 

213 except Exception as e: 

214 logger.error(f"加载会话失败: {e}", exc_info=True) 

215 self.messages = [] 

216 else: 

217 logger.debug("未找到历史会话,创建新会话")