Coverage for agentos/memory/conversation.py: 42%

146 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:01 +0800

1""" 

2Conversation Memory with sliding window management. 

3 

4Manages multi-turn conversations with configurable window strategies: 

5- Sliding window (FIFO with max turns) 

6- Token-aware window (trim by token count) 

7- Importance-weighted (keep high-importance turns, evict low) 

8- Hybrid (combine token budget + importance scoring) 

9""" 

10 

11from __future__ import annotations 

12 

13from dataclasses import dataclass, field 

14from enum import Enum 

15from typing import Any, Optional 

16 

17 

class WindowStrategy(Enum):
    """Eviction policy applied by ConversationMemory after each new turn."""

    # FIFO: keep the last N turns, dropping the oldest first.
    SLIDING = "sliding"
    # Keep as many turns as fit inside the token budget.
    TOKEN_AWARE = "token_aware"
    # Keep high-importance turns; evict the lowest scores.
    IMPORTANCE = "importance"
    # Token budget combined with importance scoring.
    HYBRID = "hybrid"

30 

31 

@dataclass
class ConversationTurn:
    """One message in a conversation plus the bookkeeping used for eviction.

    None of the fields are validated on construction.
    """

    # Speaker role, conventionally 'user', 'assistant', 'system' or 'tool'.
    role: str
    # Message text; also the input for token estimation when token_count <= 0.
    content: str
    # Caller-supplied timestamp (units not fixed here).
    timestamp: float = field(default=0.0)
    # Token cost of ``content``; 0 (or less) means "unknown, estimate it".
    token_count: int = field(default=0)
    # Eviction priority: 0.0 = least important, 1.0 = most important.
    importance: float = field(default=0.5)
    # Free-form extra data; every instance gets its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)

46 

47 

@dataclass
class WindowConfig:
    """Tuning knobs for how ConversationMemory trims its history.

    Each limit is consulted only by the strategies listed next to it.
    """

    # Eviction policy run after every added turn.
    strategy: WindowStrategy = field(default=WindowStrategy.SLIDING)
    # Maximum number of turns to keep (SLIDING).
    max_turns: int = field(default=20)
    # Maximum total estimated tokens (TOKEN_AWARE and HYBRID).
    max_tokens: int = field(default=8000)
    # Turns scoring below this are eviction candidates (IMPORTANCE and HYBRID).
    importance_threshold: float = field(default=0.3)
    # Optional system prompt emitted first by get_messages(); it is not stored
    # as a turn and does not count toward the token budget.
    system_prompt: Optional[str] = field(default=None)
    # The newest N turns are never evicted, whatever the strategy.
    preserve_last_n: int = field(default=2)

68 

69 

class ConversationMemory:
    """
    Multi-turn conversation memory with sliding window strategies.

    Turns are stored oldest-first. After every ``add_turn`` the configured
    ``WindowStrategy`` evicts turns as needed. In every strategy, turns whose
    role is ``"system"`` are never evicted, and neither are the newest
    ``config.preserve_last_n`` turns.

    Example::

        mem = ConversationMemory(WindowConfig(strategy=WindowStrategy.HYBRID, max_tokens=4000))
        mem.add_turn(ConversationTurn(role="user", content="Hello"))
        mem.add_turn(ConversationTurn(role="assistant", content="Hi! How can I help?"))
        messages = mem.get_messages()  # [{"role": "user", "content": "Hello"}, ...]
    """

    def __init__(self, config: Optional[WindowConfig] = None):
        self.config = config or WindowConfig()
        self._turns: list[ConversationTurn] = []
        # Running sum of _turn_tokens() over self._turns, maintained by
        # add_turn/_evict_one so budget checks never rescan the history.
        self._token_count_cache: int = 0

    # ── Adding turns ─────────────────────────

    def add_turn(self, turn: ConversationTurn) -> None:
        """Append *turn*, update the token total, then apply window eviction.

        A ``turn.token_count`` of 0 or less is treated as unknown and the cost
        is estimated from ``turn.content``; the turn object is not modified.
        """
        self._turns.append(turn)
        self._token_count_cache += self._turn_tokens(turn)
        self._apply_window()

    def add_user_message(self, content: str, importance: float = 0.5) -> None:
        """Add a ``user`` turn with an estimated token count."""
        self.add_turn(ConversationTurn(
            role="user", content=content, importance=importance,
            token_count=self._estimate_tokens(content),
        ))

    def add_assistant_message(self, content: str, importance: float = 0.5) -> None:
        """Add an ``assistant`` turn with an estimated token count."""
        self.add_turn(ConversationTurn(
            role="assistant", content=content, importance=importance,
            token_count=self._estimate_tokens(content),
        ))

    def add_system_message(self, content: str) -> None:
        """Add a ``system`` turn (importance 1.0); system turns are never evicted."""
        self.add_turn(ConversationTurn(
            role="system", content=content, importance=1.0,
            token_count=self._estimate_tokens(content),
        ))

    # ── Eviction ─────────────────────────────

    def _turn_tokens(self, turn: ConversationTurn) -> int:
        """Token cost of *turn*: its explicit count, or an estimate if <= 0."""
        return turn.token_count if turn.token_count > 0 else self._estimate_tokens(turn.content)

    def _oldest_evictable_index(self) -> int:
        """Return the index of the oldest evictable turn, or -1 if none.

        A turn is evictable when it is not a system turn and lies outside the
        protected tail of the newest ``preserve_last_n`` turns.
        """
        limit = len(self._turns) - self.config.preserve_last_n
        for i, turn in enumerate(self._turns):
            if i >= limit:
                break
            if turn.role != "system":
                return i
        return -1

    def _lowest_importance_index(self) -> int:
        """Return the index of the least-important evictable turn below the
        importance threshold, or -1 if there is none.

        Ties resolve to the oldest such turn (``min`` keeps the first minimum).
        """
        limit = len(self._turns) - self.config.preserve_last_n
        threshold = self.config.importance_threshold
        candidates = [
            i
            for i, turn in enumerate(self._turns)
            if i < limit and turn.role != "system" and turn.importance < threshold
        ]
        if not candidates:
            return -1
        return min(candidates, key=lambda i: self._turns[i].importance)

    def _apply_window(self) -> None:
        """Apply the configured window strategy to evict excess turns."""
        strategy = self.config.strategy

        if strategy == WindowStrategy.SLIDING:
            self._evict_sliding()
        elif strategy == WindowStrategy.TOKEN_AWARE:
            self._evict_token_aware()
        elif strategy == WindowStrategy.IMPORTANCE:
            self._evict_by_importance()
        elif strategy == WindowStrategy.HYBRID:
            self._evict_hybrid()

    def _evict_sliding(self) -> None:
        """FIFO: drop the oldest evictable turns until at most max_turns remain.

        Stops early (leaving more than max_turns) if only system turns and
        the preserved tail are left.
        """
        while len(self._turns) > self.config.max_turns:
            idx = self._oldest_evictable_index()
            if idx == -1:
                break
            self._evict_one(idx)

    def _evict_token_aware(self) -> None:
        """Drop the oldest evictable turns until the token total fits max_tokens.

        Like the other strategies, this never removes system turns or the
        preserved tail, so the budget may remain exceeded.
        """
        while self._token_count_cache > self.config.max_tokens:
            idx = self._oldest_evictable_index()
            if idx == -1:
                break
            self._evict_one(idx)

    def _evict_by_importance(self) -> None:
        """Evict every evictable turn whose importance is below the threshold.

        Turns are removed least-important first. Note: this strategy does not
        consult ``max_turns`` or ``max_tokens``.
        """
        while True:
            idx = self._lowest_importance_index()
            if idx == -1:
                return
            self._evict_one(idx)

    def _evict_hybrid(self) -> None:
        """While over the token budget, evict the least-important turn below
        the threshold; if there is none, evict the oldest evictable turn."""
        while self._token_count_cache > self.config.max_tokens:
            idx = self._lowest_importance_index()
            if idx == -1:
                # No low-importance candidates left: fall back to FIFO.
                idx = self._oldest_evictable_index()
            if idx == -1:
                break
            self._evict_one(idx)

    def _evict_one(self, index: int) -> None:
        """Remove the turn at *index* (ignored if out of range) and update the
        token total, clamped at 0."""
        if 0 <= index < len(self._turns):
            turn = self._turns.pop(index)
            self._token_count_cache = max(0, self._token_count_cache - self._turn_tokens(turn))

    # ── Reading ──────────────────────────────

    def get_messages(self) -> list[dict[str, str]]:
        """Return conversation as list of dicts (OpenAI chat format).

        If ``config.system_prompt`` is non-empty it is emitted first, ahead of
        the stored turns.
        """
        msgs: list[dict[str, str]] = []
        if self.config.system_prompt:
            msgs.append({"role": "system", "content": self.config.system_prompt})
        for turn in self._turns:
            msgs.append({"role": turn.role, "content": turn.content})
        return msgs

    def get_turns(self) -> list[ConversationTurn]:
        """Return a shallow copy of the stored turns, oldest first."""
        return list(self._turns)

    @property
    def turn_count(self) -> int:
        """Number of stored turns."""
        return len(self._turns)

    @property
    def token_count(self) -> int:
        """Current (estimated) total token count of the stored turns."""
        return self._token_count_cache

    def clear(self) -> None:
        """Reset conversation memory."""
        self._turns.clear()
        self._token_count_cache = 0

    def to_summary(self) -> str:
        """Generate a brief summary of the conversation memory."""
        turns = self._turns
        if not turns:
            return "Empty conversation."

        lines = [
            f"Total turns: {len(turns)}",
            f"Total tokens (est.): {self._token_count_cache}",
            f"First turn: [{turns[0].role}] {turns[0].content[:80]}...",
        ]
        if len(turns) > 1:
            lines.append(f"Last turn: [{turns[-1].role}] {turns[-1].content[:80]}...")
        return "\n".join(lines)

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Rough token estimation: ~4 chars per token (minimum 1)."""
        return max(1, len(text) // 4)

    def __len__(self) -> int:
        return len(self._turns)

    def __repr__(self) -> str:
        return f"ConversationMemory(turns={len(self._turns)}, tokens={self._token_count_cache}, strategy={self.config.strategy.value})"

    # ── Persistence (v1.14.9) ────────────────

    def get_state(self) -> dict[str, Any]:
        """Export conversation memory state for persistence.

        Each turn's ``metadata`` dict is shallow-copied so that later changes
        to the live memory do not leak into the snapshot, and vice versa.
        """
        return {
            "config": {
                "strategy": self.config.strategy.value,
                "max_turns": self.config.max_turns,
                "max_tokens": self.config.max_tokens,
                "importance_threshold": self.config.importance_threshold,
                "system_prompt": self.config.system_prompt,
                "preserve_last_n": self.config.preserve_last_n,
            },
            "turns": [
                {
                    "role": turn.role,
                    "content": turn.content,
                    "timestamp": turn.timestamp,
                    "token_count": turn.token_count,
                    "importance": turn.importance,
                    "metadata": dict(turn.metadata),
                }
                for turn in self._turns
            ],
            "token_count_cache": self._token_count_cache,
        }

    def restore_state(self, state: dict[str, Any]) -> None:
        """Restore conversation memory from a persisted snapshot.

        Missing keys fall back to the ``WindowConfig``/``ConversationTurn``
        defaults. An unknown strategy string raises ``ValueError``. If the
        snapshot has no ``token_count_cache``, the total is recomputed from
        the restored turns. The window is not re-applied.
        """
        config_data = state.get("config") or {}
        self.config = WindowConfig(
            strategy=WindowStrategy(config_data.get("strategy", "sliding")),
            max_turns=config_data.get("max_turns", 20),
            max_tokens=config_data.get("max_tokens", 8000),
            importance_threshold=config_data.get("importance_threshold", 0.3),
            system_prompt=config_data.get("system_prompt"),
            preserve_last_n=config_data.get("preserve_last_n", 2),
        )
        self._turns = [
            ConversationTurn(
                role=turn_data.get("role", "user"),
                content=turn_data.get("content", ""),
                timestamp=turn_data.get("timestamp", 0.0),
                token_count=turn_data.get("token_count", 0),
                importance=turn_data.get("importance", 0.5),
                # Copy so the restored turns do not share dicts with *state*.
                metadata=dict(turn_data.get("metadata") or {}),
            )
            for turn_data in state.get("turns", [])
        ]
        cached = state.get("token_count_cache")
        if cached is None:
            cached = sum(self._turn_tokens(turn) for turn in self._turns)
        self._token_count_cache = cached