Coverage for agentos/memory/consolidation.py: 37%

265 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:36 +0800

1""" 

2AgentOS v1.14.1 — 长期记忆巩固系统 (Memory Consolidation)。 

3 

4受 Letta/MemGPT 三层记忆体系启发,在虚拟内存分页器之上增加主动记忆巩固层。 

5核心机制: 

6- Reflection: 定期分析对话历史,提取关键事实、模式、用户偏好 

7- Consolidation: 将 Reflection 结果写入长期向量存储 

8- Retrieval: 智能检索历史记忆,注入后续对话上下文 

9 

10与 memory/pager.py 的关系: 

11- pager.py: 被动分页(上下文窗口溢出时 page_out / keyword search page_in) 

12- consolidation.py: 主动巩固(定期分析 → 提取 → 向量化存储) 

13""" 

14 

from __future__ import annotations

import asyncio
import inspect
import json
import time
import uuid
import zlib
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any, Callable, Dict, List, Optional, Set, Tuple, Union,
)

26 

27 

28# ── Memory Data Models ────────────────────── 

29 

30 

class MemoryType(str, Enum):
    """Category assigned to a stored memory.

    The members are also plain strings, so they compare equal to their
    serialized values (for example ``MemoryType.FACT == "fact"``).
    """

    FACT = "fact"              # factual information
    PREFERENCE = "preference"  # user preference
    PATTERN = "pattern"        # behavioural pattern
    DECISION = "decision"      # decision record
    LESSON = "lesson"          # lesson learned
    CONTEXT = "context"        # context summary

39 

40 

class MemoryImportance(str, Enum):
    """Importance level attached to a stored memory.

    Members are declared from least to most important and are also plain
    strings equal to their serialized values.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

47 

48 

@dataclass
class MemoryFragment:
    """Memory fragment: an atomic piece of information extracted from a conversation."""

    memory_id: str = field(default_factory=lambda: f"mem-{uuid.uuid4().hex[:12]}")
    memory_type: MemoryType = MemoryType.FACT
    content: str = ""
    importance: MemoryImportance = MemoryImportance.MEDIUM
    source_messages: List[int] = field(default_factory=list)  # indices of the source messages
    tags: List[str] = field(default_factory=list)
    confidence: float = 1.0  # confidence in the range 0.0 to 1.0
    created_at: float = field(default_factory=time.time)
    last_accessed: float = field(default_factory=time.time)
    access_count: int = 0
    embedding: Optional[List[float]] = None  # embedding vector (computed lazily)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of the fragment.

        Enum fields are written as their string values. The ``embedding``
        field is not part of the snapshot. Container fields are copied so
        that mutating the snapshot does not mutate the fragment.
        """
        return {
            "memory_id": self.memory_id,
            "memory_type": self.memory_type.value,
            "content": self.content,
            "importance": self.importance.value,
            "source_messages": list(self.source_messages),
            "tags": list(self.tags),
            "confidence": self.confidence,
            "created_at": self.created_at,
            "last_accessed": self.last_accessed,
            "access_count": self.access_count,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, d: dict) -> "MemoryFragment":
        """Build a fragment from a dict shaped like the output of ``to_dict``.

        Missing keys fall back to the field defaults. A missing or empty
        ``memory_id`` produces a freshly generated id rather than an empty
        string, so fragments restored from incomplete data do not collide
        on the same key.

        Raises:
            ValueError: if ``memory_type`` or ``importance`` is not a valid
                enum value.
        """
        now = time.time()
        kwargs: Dict[str, Any] = {
            "memory_type": MemoryType(d.get("memory_type", "fact")),
            "content": d.get("content", ""),
            "importance": MemoryImportance(d.get("importance", "medium")),
            # Copy containers so the fragment never aliases the caller's data.
            "source_messages": list(d.get("source_messages") or []),
            "tags": list(d.get("tags") or []),
            "confidence": d.get("confidence", 1.0),
            "created_at": d.get("created_at", now),
            "last_accessed": d.get("last_accessed", now),
            "access_count": d.get("access_count", 0),
            "metadata": dict(d.get("metadata") or {}),
        }
        memory_id = d.get("memory_id")
        if memory_id:
            kwargs["memory_id"] = memory_id
        return cls(**kwargs)

    def touch(self) -> None:
        """Record an access: refresh ``last_accessed`` and bump ``access_count``."""
        self.last_accessed = time.time()
        self.access_count += 1

101 

102 

@dataclass
class ReflectionResult:
    """Output produced by a single reflection pass."""

    fragments: List[MemoryFragment] = field(default_factory=list)
    summary: str = ""  # session-level summary
    # contradictions between new and old memories
    contradictions: List[Tuple[str, str]] = field(default_factory=list)
    # ids of old memories that should be retired
    deprecated_ids: List[str] = field(default_factory=list)
    user_profile_update: Dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)

    @property
    def total_fragments(self) -> int:
        """Number of fragments extracted by this reflection."""
        extracted = self.fragments
        return len(extracted)

    @property
    def has_insights(self) -> bool:
        """True when there is a fragment, a summary or a profile update."""
        if self.fragments:
            return True
        if self.summary:
            return True
        return bool(self.user_profile_update)

121 

122 

123# ── Vector Store Interface ────────────────── 

124 

125 

class VectorStoreBackend:
    """Abstract interface for a vector store backend.

    Intended to be implemented by different backends (in-memory, FAISS,
    Chroma, Pinecone, ...). Every method here raises
    ``NotImplementedError``.
    """

    async def add(
        self,
        fragments: List[MemoryFragment],
        embeddings: List[List[float]],
    ) -> List[str]:
        """Add a batch of fragments with their embeddings; return the memory ids."""
        raise NotImplementedError()

    async def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
        filter_types: Optional[List[MemoryType]] = None,
        min_importance: MemoryImportance = MemoryImportance.LOW,
    ) -> List[Tuple[MemoryFragment, float]]:
        """Run a vector similarity search; return ``(fragment, score)`` pairs."""
        raise NotImplementedError()

    async def delete(self, memory_ids: List[str]) -> int:
        """Delete the given memories; return how many were deleted."""
        raise NotImplementedError()

    async def count(self) -> int:
        """Return the total number of stored memories."""
        raise NotImplementedError()

157 

158 

class InMemoryVectorStore(VectorStoreBackend):
    """In-memory vector store (for development and testing).

    Fragments and embeddings are kept in two dicts keyed by memory id.
    """

    # Ordering used by the ``min_importance`` filter in ``search``.
    _IMPORTANCE_RANK = {
        MemoryImportance.LOW: 0,
        MemoryImportance.MEDIUM: 1,
        MemoryImportance.HIGH: 2,
        MemoryImportance.CRITICAL: 3,
    }

    def __init__(self):
        self._fragments: Dict[str, MemoryFragment] = {}
        self._embeddings: Dict[str, List[float]] = {}

    async def add(
        self,
        fragments: List[MemoryFragment],
        embeddings: List[List[float]],
    ) -> List[str]:
        """Store fragments together with their embeddings.

        An existing entry with the same memory id is overwritten.

        Returns:
            The memory ids, in input order.

        Raises:
            ValueError: if the two lists differ in length. Pairing them
                with a truncating ``zip`` would silently drop data.
        """
        if len(fragments) != len(embeddings):
            raise ValueError(
                f"fragments and embeddings length mismatch: "
                f"{len(fragments)} != {len(embeddings)}"
            )
        ids = []
        for frag, emb in zip(fragments, embeddings):
            self._fragments[frag.memory_id] = frag
            self._embeddings[frag.memory_id] = emb
            ids.append(frag.memory_id)
        return ids

    async def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
        filter_types: Optional[List[MemoryType]] = None,
        min_importance: MemoryImportance = MemoryImportance.LOW,
    ) -> List[Tuple[MemoryFragment, float]]:
        """Return the ``top_k`` most similar fragments by cosine similarity.

        Args:
            query_embedding: Vector to compare the stored embeddings with.
            top_k: Maximum number of results; a value <= 0 yields ``[]``.
            filter_types: When non-empty, only these memory types match.
            min_importance: Lowest importance level that is still returned.

        Returns:
            ``(fragment, score)`` pairs sorted by descending score.
        """
        if top_k <= 0:
            # A negative slice bound would otherwise drop results from the tail.
            return []

        rank = self._IMPORTANCE_RANK
        min_rank = rank[min_importance]
        results = []
        for mid, emb in self._embeddings.items():
            frag = self._fragments.get(mid)
            if frag is None:
                # Embedding without a fragment: nothing to return for it.
                continue
            if filter_types and frag.memory_type not in filter_types:
                continue
            if rank[frag.importance] < min_rank:
                continue
            score = self._cosine_similarity(query_embedding, emb)
            results.append((frag, score))

        results.sort(key=lambda pair: pair[1], reverse=True)
        return results[:top_k]

    async def delete(self, memory_ids: List[str]) -> int:
        """Remove the given memories; return how many were actually present."""
        count = 0
        for mid in memory_ids:
            if mid in self._fragments:
                del self._fragments[mid]
                self._embeddings.pop(mid, None)
                count += 1
        return count

    async def count(self) -> int:
        """Return the number of stored fragments."""
        return len(self._fragments)

    @staticmethod
    def _cosine_similarity(a: List[float], b: List[float]) -> float:
        """Cosine similarity of two vectors; 0.0 when either has zero norm."""
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

229 

230 

231# ── Embedding Provider ────────────────────── 

232 

233 

class EmbeddingProvider:
    """Abstract embedding generator."""

    async def embed(self, texts: List[str]) -> List[List[float]]:
        """Produce one embedding vector per input text (must be overridden)."""
        raise NotImplementedError()

    async def embed_single(self, text: str) -> List[float]:
        """Embed one text by delegating to ``embed`` with a one-item batch."""
        batch = await self.embed([text])
        first_vector = batch[0]
        return first_vector

245 

246 

class SimpleHashEmbedding(EmbeddingProvider):
    """Simple hashed n-gram embedding (development/testing, not semantic).

    Character 3-grams are hashed into ``dim`` buckets and the bucket counts
    are L2-normalised, which gives a rough surface-level similarity.
    Production setups should replace this with a real embedding model
    (OpenAI, Cohere, ...).

    The bucket index is derived from CRC-32 instead of the built-in
    ``hash()``: string hashes are salted per interpreter process, so
    vectors computed with ``hash()`` cannot be compared with vectors that
    were persisted by an earlier process.
    """

    _NGRAM_SIZE = 3

    def __init__(self, dim: int = 128):
        """
        Args:
            dim: Number of hash buckets, i.e. the embedding dimension.

        Raises:
            ValueError: if ``dim`` is not positive.
        """
        if dim <= 0:
            raise ValueError(f"dim must be a positive integer, got {dim!r}")
        self.dim = dim

    async def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one L2-normalised vector of length ``dim`` per text."""
        return [self._embed_one(text) for text in texts]

    def _embed_one(self, text: str) -> List[float]:
        """Hash the character n-grams of ``text`` into a normalised vector."""
        size = self._NGRAM_SIZE
        if len(text) >= size:
            grams = [text[i:i + size] for i in range(len(text) - size + 1)]
        elif text:
            # Too short for a full n-gram: hash the text itself so that
            # short inputs do not all collapse to the zero vector.
            grams = [text]
        else:
            grams = []

        vec = [0.0] * self.dim
        for gram in grams:
            bucket = zlib.crc32(gram.encode("utf-8", "surrogatepass")) % self.dim
            vec[bucket] += 1.0

        norm = sum(v * v for v in vec) ** 0.5
        if norm > 0:
            vec = [v / norm for v in vec]
        return vec

272 

273 

274# ── Reflection Engine ─────────────────────── 

275 

276 

class ReflectionConfig:
    """Settings that control when a reflection is triggered.

    The values are stored on the instance unchanged, under the same names
    as the constructor parameters.
    """

    def __init__(
        self,
        min_messages_since_last: int = 10,
        min_seconds_since_last: float = 300.0,  # 5 minutes
        max_conversation_turns: int = 50,
        auto_reflect: bool = True,
    ):
        # Message-count and elapsed-time thresholds.
        self.min_messages_since_last = min_messages_since_last
        self.min_seconds_since_last = min_seconds_since_last
        # Conversation size setting and the automatic-reflection switch.
        self.max_conversation_turns = max_conversation_turns
        self.auto_reflect = auto_reflect

291 

292 

class ReflectionEngine:
    """Memory reflection engine.

    Periodically analyses the conversation history and extracts:
    - Facts: concrete information mentioned by the user
    - Preferences: user preferences and habits
    - Patterns: recurring behaviour
    - Lessons: experience learned from mistakes

    Usage:
        engine = ReflectionEngine(llm_reflect_fn, vector_store, embedding_provider)
        # call periodically inside the agent loop
        should_reflect = engine.should_reflect(message_count)
        if should_reflect:
            result = await engine.reflect(messages_history)
    """

    def __init__(
        self,
        llm_reflect_fn: Optional[Callable[[List[dict], str], Any]] = None,
        vector_store: Optional[Any] = None,
        embedding_provider: Optional[Any] = None,
        config: Optional[ReflectionConfig] = None,
    ):
        """
        Args:
            llm_reflect_fn: LLM call, signature (messages, prompt) -> reflection_text.
                May be a plain function or a coroutine function.
            vector_store: Vector store backend.
            embedding_provider: Embedding generator.
            config: Trigger configuration; defaults to ``ReflectionConfig()``.
        """
        self._llm_reflect = llm_reflect_fn
        self._vector_store = vector_store
        self._embedding_provider = embedding_provider
        self.config = config or ReflectionConfig()
        self._last_reflection_time: float = 0.0
        self._message_count_since_reflection: int = 0
        self._reflection_count: int = 0

    def should_reflect(self, current_message_count: int) -> bool:
        """Decide whether a reflection should be triggered now.

        Rules, in order:
        - never when ``config.auto_reflect`` is off;
        - before the first reflection, as soon as ``current_message_count``
          reaches 5;
        - otherwise when enough messages were recorded since the last
          reflection, or enough time has passed since it.
        """
        if not self.config.auto_reflect:
            return False
        if self._reflection_count == 0 and current_message_count >= 5:
            return True  # first reflection fires after 5 messages
        msg_check = (
            self._message_count_since_reflection
            >= self.config.min_messages_since_last
        )
        # The timestamp starts at 0.0, so "seconds since last reflection" is
        # only meaningful once a reflection has run. Without this guard the
        # time trigger is always true on a fresh engine and the 5-message
        # rule above never gets a chance to hold the first reflection back.
        time_check = self._last_reflection_time > 0 and (
            time.time() - self._last_reflection_time
            >= self.config.min_seconds_since_last
        )
        return msg_check or time_check

    async def reflect(
        self,
        messages: List[dict],
        existing_fragments: Optional[List[MemoryFragment]] = None,
    ) -> ReflectionResult:
        """Run one reflection pass.

        Args:
            messages: Conversation history (list of dicts with role/content).
            existing_fragments: Known fragments, listed in the prompt so the
                LLM can see them.

        Returns:
            ReflectionResult holding the newly extracted fragments. The
            fragments are embedded and written to the vector store before
            returning.

        Raises:
            TypeError: if no callable ``llm_reflect_fn`` was configured.
        """
        if not callable(self._llm_reflect):
            # Fail before touching any counters so a misconfigured engine
            # does not record a reflection that never ran.
            raise TypeError(
                "ReflectionEngine.reflect() requires a callable llm_reflect_fn"
            )

        self._last_reflection_time = time.time()
        self._reflection_count += 1

        # 1. Build the reflection prompt.
        prompt = self._build_reflection_prompt(messages, existing_fragments)

        # 2. Ask the LLM; accept both sync and async callables.
        reflection_text = self._llm_reflect(messages, prompt)
        if inspect.isawaitable(reflection_text):
            reflection_text = await reflection_text

        # 3. Parse the LLM output.
        result = self._parse_reflection_output(reflection_text, len(messages))

        # 4 + 5. Embed the new fragments and write them to the vector store.
        if result.fragments:
            texts = [f.content for f in result.fragments]
            embeddings = await self._embedding_provider.embed(texts)
            for frag, emb in zip(result.fragments, embeddings):
                frag.embedding = emb
            await self._vector_store.add(result.fragments, embeddings)

        # 6. Retire old memories.
        if result.deprecated_ids:
            await self._vector_store.delete(result.deprecated_ids)

        # Reset the message counter.
        self._message_count_since_reflection = 0

        return result

    def record_message(self) -> None:
        """Record one new message (feeds the message-count trigger)."""
        self._message_count_since_reflection += 1

    async def retrieve_relevant(
        self,
        query: str,
        top_k: int = 5,
        filter_types: Optional[List[MemoryType]] = None,
    ) -> List[MemoryFragment]:
        """Retrieve the memories most relevant to a query.

        Every returned fragment is marked as accessed via ``touch()``.

        Args:
            query: Query text.
            top_k: Maximum number of results.
            filter_types: Restrict results to these memory types.

        Returns:
            Relevant fragments, in the order returned by the vector store.
        """
        query_embedding = await self._embedding_provider.embed_single(query)
        results = await self._vector_store.search(
            query_embedding,
            top_k=top_k,
            filter_types=filter_types,
        )
        fragments = []
        for frag, _score in results:
            frag.touch()
            fragments.append(frag)
        return fragments

    def _build_reflection_prompt(
        self,
        messages: List[dict],
        existing_fragments: Optional[List[MemoryFragment]] = None,
    ) -> str:
        """Build the reflection prompt.

        At most the first 20 existing fragments are listed in the prompt.
        ``messages`` is not embedded in the prompt text; it is passed to the
        LLM callable separately by ``reflect``.
        """
        existing_str = ""
        if existing_fragments:
            existing_items = [
                f"- [{f.memory_type.value}] {f.content}"
                for f in existing_fragments[:20]
            ]
            existing_str = (
                "\n\nExisting memories:\n" + "\n".join(existing_items)
            )

        return f"""You are a memory consolidation system. Analyze the conversation and extract:

1. FACTS: Specific information mentioned (names, dates, numbers, tools used, decisions made)
2. PREFERENCES: User preferences, likes, dislikes, habits
3. PATTERNS: Repeated behaviors, common workflows, recurring topics
4. LESSONS: What went wrong, what worked, what to avoid next time

For each extracted item, assign:
- type: "fact" | "preference" | "pattern" | "decision" | "lesson"
- importance: "low" | "medium" | "high" | "critical"
- confidence: 0.0 to 1.0

{existing_str}

Output JSON array only:
[{{"type": "...", "content": "...", "importance": "...", "confidence": 0.9, "tags": ["..."]}}]

If nothing significant to extract, output empty array: []"""

    def _parse_reflection_output(
        self,
        text: str,
        source_msg_count: int,
    ) -> ReflectionResult:
        """Parse the JSON array produced by the LLM.

        The text between the first ``[`` and the last ``]`` is decoded as
        JSON. Output that cannot be decoded yields an empty result. Items
        are validated one by one, so a single malformed item is skipped
        without discarding the valid ones around it.
        """
        result = ReflectionResult()
        if not isinstance(text, str):
            return result

        start = text.find("[")
        end = text.rfind("]")
        if start < 0 or end <= start:
            return result

        try:
            items = json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            return result
        if not isinstance(items, list):
            return result

        # Every fragment is attributed to the last (up to) 20 messages.
        first_source = max(0, source_msg_count - 20)
        for item in items:
            frag = self._fragment_from_item(
                item, list(range(first_source, source_msg_count))
            )
            if frag is not None:
                result.fragments.append(frag)

        return result

    @staticmethod
    def _fragment_from_item(
        item: Any,
        source_messages: List[int],
    ) -> Optional[MemoryFragment]:
        """Turn one decoded JSON item into a fragment, or None if it is unusable."""
        if not isinstance(item, dict):
            return None
        content = item.get("content", "")
        if not isinstance(content, str) or not content.strip():
            return None
        try:
            memory_type = MemoryType(str(item.get("type", "fact")).strip().lower())
            importance = MemoryImportance(
                str(item.get("importance", "medium")).strip().lower()
            )
            confidence = float(item.get("confidence", 1.0))
        except (TypeError, ValueError):
            return None

        raw_tags = item.get("tags", [])
        if isinstance(raw_tags, str):
            tags = [raw_tags]
        elif isinstance(raw_tags, list):
            tags = [str(tag) for tag in raw_tags]
        else:
            tags = []

        return MemoryFragment(
            memory_type=memory_type,
            content=content,
            importance=importance,
            # Keep the confidence inside its documented 0.0-1.0 range.
            confidence=min(1.0, max(0.0, confidence)),
            tags=tags,
            source_messages=source_messages,
        )

    @property
    def stats(self) -> Dict[str, Any]:
        """Counters describing the reflection activity so far."""
        return {
            "reflection_count": self._reflection_count,
            "last_reflection_time": self._last_reflection_time,
            "messages_since_last": self._message_count_since_reflection,
        }

    # ── Persistence (v1.14.9) ────────────────

    def _store_mapping(self, attr: str) -> Optional[dict]:
        """Return the vector store's dict attribute ``attr``, or None if absent."""
        store = self._vector_store
        if not store:
            return None
        mapping = getattr(store, attr, None)
        return mapping if isinstance(mapping, dict) else None

    def get_state(self) -> Dict[str, Any]:
        """Export ReflectionEngine state for persistence.

        Store contents are included only when the vector store exposes
        ``_fragments`` / ``_embeddings`` dicts (as the in-memory store does);
        otherwise the corresponding entries are empty.
        """
        fragments = self._store_mapping("_fragments") or {}
        embeddings = self._store_mapping("_embeddings") or {}
        return {
            "reflection_count": self._reflection_count,
            "last_reflection_time": self._last_reflection_time,
            "message_count_since_reflection": self._message_count_since_reflection,
            "vector_store_fragments": {
                mid: frag.to_dict() for mid, frag in fragments.items()
            },
            "vector_store_embeddings": {
                mid: list(emb) if emb else [] for mid, emb in embeddings.items()
            },
        }

    def restore_state(self, state: Dict[str, Any]) -> None:
        """Restore ReflectionEngine from a persisted snapshot.

        Counters are always restored. Store contents are restored only when
        the vector store exposes both ``_fragments`` and ``_embeddings``.
        """
        self._reflection_count = state.get("reflection_count", 0)
        self._last_reflection_time = state.get("last_reflection_time", 0.0)
        self._message_count_since_reflection = state.get(
            "message_count_since_reflection", 0
        )

        fragments = self._store_mapping("_fragments")
        embeddings = self._store_mapping("_embeddings")
        if fragments is None or embeddings is None:
            return

        fragments.clear()
        embeddings.clear()
        for mid, frag_data in (state.get("vector_store_fragments") or {}).items():
            fragments[mid] = MemoryFragment.from_dict(frag_data)
        for mid, emb in (state.get("vector_store_embeddings") or {}).items():
            if mid not in fragments:
                # An embedding without its fragment cannot be searched.
                continue
            vector = list(emb) if emb else []
            embeddings[mid] = vector
            # Keep fragment and store in step, as reflect() does.
            fragments[mid].embedding = vector or None

540 

541 

542# ── Memory Context Injector ───────────────── 

543 

544 

class MemoryContextInjector:
    """Memory context injector.

    Retrieves relevant historical memories for a query and formats them as
    text that can be injected into the system prompt or context.

    Usage:
        injector = MemoryContextInjector(reflection_engine)
        context = await injector.build_context("user query here")
        messages.insert(0, {"role": "system", "content": context})
    """

    _TRUNCATION_MARKER = "..."

    def __init__(
        self,
        reflection_engine: ReflectionEngine,
        max_context_length: int = 2000,
        max_fragments: int = 5,
    ):
        """
        Args:
            reflection_engine: Engine used to retrieve relevant memories.
            max_context_length: Upper bound, in characters, on the text
                returned by ``build_context``.
            max_fragments: Number of memories requested per retrieval.
        """
        self._engine = reflection_engine
        self.max_context_length = max_context_length
        self.max_fragments = max_fragments

    async def build_context(
        self,
        query: str,
        include_types: Optional[List[MemoryType]] = None,
    ) -> str:
        """Build the context text to inject.

        Args:
            query: Text used to retrieve relevant memories.
            include_types: Restrict retrieval to these memory types.

        Returns:
            "" when nothing relevant is found; otherwise a
            "[Relevant Memories]" header followed by one line per memory.
            The text never exceeds ``max_context_length`` characters; when
            it has to be cut it ends with "...".
        """
        fragments = await self._engine.retrieve_relevant(
            query,
            top_k=self.max_fragments,
            filter_types=include_types,
        )

        if not fragments:
            return ""

        lines = ["[Relevant Memories]"]
        for frag in fragments:
            lines.append(
                f"- [{frag.memory_type.value}] {frag.content}"
                f" (confidence: {frag.confidence:.0%})"
            )

        context = "\n".join(lines)
        limit = max(self.max_context_length, 0)
        if len(context) > limit:
            # Reserve room for the marker so the result honours the limit.
            marker = self._TRUNCATION_MARKER
            keep = max(limit - len(marker), 0)
            context = (context[:keep] + marker)[:limit]

        return context

    async def build_condensed_context(
        self,
        query: str,
        max_items: int = 3,
    ) -> str:
        """Build a compact context containing only high-importance memories.

        The filter is applied to the ``max_fragments`` memories returned by
        retrieval, so high-importance memories ranked below that cut-off
        are not considered.

        Args:
            query: Text used to retrieve relevant memories.
            max_items: Maximum number of memories listed (default 3).

        Returns:
            "" when no HIGH/CRITICAL memory was retrieved; otherwise a
            "[Key Context]" header followed by one line per memory.
        """
        fragments = await self._engine.retrieve_relevant(
            query,
            top_k=self.max_fragments,
        )
        # Keep only HIGH / CRITICAL memories.
        important = [
            f for f in fragments
            if f.importance in (MemoryImportance.HIGH, MemoryImportance.CRITICAL)
        ]
        if not important:
            return ""

        lines = ["[Key Context]"]
        for frag in important[:max(max_items, 0)]:
            lines.append(f"- {frag.content}")

        return "\n".join(lines)

617 

618 

619# ── Memory Consolidation Pipeline ─────────── 

620 

621 

class MemoryConsolidationPipeline:
    """Memory consolidation pipeline (one-stop integration).

    Combines ReflectionEngine and MemoryContextInjector into a ready-to-use
    memory system.

    Usage:
        pipeline = MemoryConsolidationPipeline(llm_fn)
        # inside the agent loop:
        pipeline.record_message()
        if pipeline.should_reflect():
            await pipeline.reflect(messages)
        context = await pipeline.get_context(user_query)
    """

    def __init__(
        self,
        llm_reflect_fn: Callable,
        vector_store: Optional[VectorStoreBackend] = None,
        embedding_provider: Optional[EmbeddingProvider] = None,
        config: Optional[ReflectionConfig] = None,
    ):
        """
        Args:
            llm_reflect_fn: LLM call handed to the ReflectionEngine.
            vector_store: Store to use; defaults to an InMemoryVectorStore.
            embedding_provider: Embedding generator; defaults to
                SimpleHashEmbedding(128).
            config: Reflection trigger configuration.
        """
        self._vector_store = vector_store or InMemoryVectorStore()
        self._embedding_provider = embedding_provider or SimpleHashEmbedding(128)
        self._reflection_engine = ReflectionEngine(
            llm_reflect_fn=llm_reflect_fn,
            vector_store=self._vector_store,
            embedding_provider=self._embedding_provider,
            config=config,
        )
        self._injector = MemoryContextInjector(self._reflection_engine)

    def record_message(self) -> None:
        """Record one new message on the reflection engine."""
        self._reflection_engine.record_message()

    def should_reflect(self) -> bool:
        """Ask the engine whether to reflect, using its own message counter."""
        return self._reflection_engine.should_reflect(
            self._reflection_engine._message_count_since_reflection
        )

    async def reflect(self, messages: List[dict]) -> ReflectionResult:
        """Run one reflection pass over ``messages``."""
        return await self._reflection_engine.reflect(messages)

    async def get_context(self, query: str) -> str:
        """Return the full memory context for ``query``."""
        return await self._injector.build_context(query)

    async def get_condensed_context(self, query: str) -> str:
        """Return the compact, high-importance-only context for ``query``."""
        return await self._injector.build_condensed_context(query)

    async def count_memories(self) -> int:
        """Return the number of memories reported by the vector store."""
        return await self._vector_store.count()

    def _count_memories_sync(self) -> int:
        """Best-effort synchronous memory count for the ``stats`` property."""
        # Stores that keep their fragments in a local dict can be counted
        # without touching the event loop at all.
        fragments = getattr(self._vector_store, "_fragments", None)
        if fragments is not None:
            return len(fragments)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so it is safe to drive
            # the async count() to completion here.
            return asyncio.run(self._vector_store.count())
        # A loop is already running: blocking on it would raise. Report 0;
        # async callers should use ``await pipeline.count_memories()``.
        return 0

    @property
    def stats(self) -> Dict[str, Any]:
        """Reflection counters plus the total number of stored memories."""
        return {
            "reflection": self._reflection_engine.stats,
            "total_memories": self._count_memories_sync(),
        }

    # ── Persistence (v1.14.9) ────────────────

    def get_state(self) -> Dict[str, Any]:
        """Export consolidation pipeline state for persistence. Delegates to ReflectionEngine."""
        return self._reflection_engine.get_state()

    def restore_state(self, state: Dict[str, Any]) -> None:
        """Restore consolidation pipeline from a persisted snapshot."""
        self._reflection_engine.restore_state(state)