Coverage for agentos/memory/consolidation.py: 37%
265 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
1"""
2AgentOS v1.14.1 — 长期记忆巩固系统 (Memory Consolidation)。
4受 Letta/MemGPT 三层记忆体系启发,在虚拟内存分页器之上增加主动记忆巩固层。
5核心机制:
6- Reflection: 定期分析对话历史,提取关键事实、模式、用户偏好
7- Consolidation: 将 Reflection 结果写入长期向量存储
8- Retrieval: 智能检索历史记忆,注入后续对话上下文
10与 memory/pager.py 的关系:
11- pager.py: 被动分页(上下文窗口溢出时 page_out / keyword search page_in)
12- consolidation.py: 主动巩固(定期分析 → 提取 → 向量化存储)
13"""
from __future__ import annotations

import asyncio
import inspect
import json
import math
import time
import uuid
import zlib
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any, Callable, Dict, List, Optional, Set, Tuple, Union,
)
28# ── Memory Data Models ──────────────────────
class MemoryType(str, Enum):
    """Category of a consolidated memory fragment."""

    # Factual information.
    FACT = "fact"
    # A user preference.
    PREFERENCE = "preference"
    # A behavioural pattern.
    PATTERN = "pattern"
    # A recorded decision.
    DECISION = "decision"
    # A lesson learned.
    LESSON = "lesson"
    # A context summary.
    CONTEXT = "context"
class MemoryImportance(str, Enum):
    """Importance level of a memory fragment, from least to most important."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
@dataclass
class MemoryFragment:
    """A memory fragment: one atomic fact extracted from a conversation.

    ``embedding`` is not part of :meth:`to_dict` / :meth:`from_dict`; it has
    to be persisted separately when needed.
    """

    memory_id: str = field(default_factory=lambda: f"mem-{uuid.uuid4().hex[:12]}")
    memory_type: MemoryType = MemoryType.FACT
    content: str = ""
    importance: MemoryImportance = MemoryImportance.MEDIUM
    # Indices of the messages this fragment was derived from.
    source_messages: List[int] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)
    # Confidence, 0.0 ~ 1.0.
    confidence: float = 1.0
    created_at: float = field(default_factory=time.time)
    last_accessed: float = field(default_factory=time.time)
    access_count: int = 0
    # Embedding vector (computed lazily).
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict (``embedding`` excluded).

        List/dict fields are shallow-copied so that mutating the returned
        dict (e.g. a persisted snapshot) cannot mutate this fragment.
        """
        return {
            "memory_id": self.memory_id,
            "memory_type": self.memory_type.value,
            "content": self.content,
            "importance": self.importance.value,
            "source_messages": list(self.source_messages),
            "tags": list(self.tags),
            "confidence": self.confidence,
            "created_at": self.created_at,
            "last_accessed": self.last_accessed,
            "access_count": self.access_count,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, d: dict) -> "MemoryFragment":
        """Rebuild a fragment from a dict shaped like :meth:`to_dict` output.

        Missing keys fall back to defaults. A missing or empty
        ``memory_id`` gets a freshly generated id via the field's default
        factory (an empty string would make all such fragments share one
        id). Container values are copied so the fragment never aliases the
        input dict; ``None`` containers are treated as empty.

        Raises:
            ValueError: if ``memory_type`` or ``importance`` is not a valid
                enum value.
        """
        now = time.time()
        kwargs: Dict[str, Any] = {
            "memory_type": MemoryType(d.get("memory_type", "fact")),
            "content": d.get("content", ""),
            "importance": MemoryImportance(d.get("importance", "medium")),
            "source_messages": list(d.get("source_messages") or []),
            "tags": list(d.get("tags") or []),
            "confidence": d.get("confidence", 1.0),
            "created_at": d.get("created_at", now),
            "last_accessed": d.get("last_accessed", now),
            "access_count": d.get("access_count", 0),
            "metadata": dict(d.get("metadata") or {}),
        }
        memory_id = d.get("memory_id")
        if memory_id:
            kwargs["memory_id"] = memory_id
        return cls(**kwargs)

    def touch(self) -> None:
        """Record an access: refresh ``last_accessed`` and bump ``access_count``."""
        self.last_accessed = time.time()
        self.access_count += 1
@dataclass
class ReflectionResult:
    """Everything produced by a single reflection pass."""

    # Newly extracted memory fragments.
    fragments: List[MemoryFragment] = field(default_factory=list)
    # Session-level summary.
    summary: str = ""
    # Contradictions between new and old memories.
    contradictions: List[Tuple[str, str]] = field(default_factory=list)
    # IDs of old memories that should be retired.
    deprecated_ids: List[str] = field(default_factory=list)
    user_profile_update: Dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)

    @property
    def total_fragments(self) -> int:
        """How many fragments were extracted."""
        fragment_list = self.fragments
        return len(fragment_list)

    @property
    def has_insights(self) -> bool:
        """True if any fragment, summary, or profile update is present."""
        return any((self.fragments, self.summary, self.user_profile_update))
123# ── Vector Store Interface ──────────────────
class VectorStoreBackend:
    """Interface every vector-store backend implements.

    Candidate backends include in-memory, FAISS, Chroma and Pinecone stores.
    Each coroutine below only raises ``NotImplementedError``; concrete
    backends override all of them.
    """

    async def add(
        self, fragments: List[MemoryFragment], embeddings: List[List[float]]
    ) -> List[str]:
        """Insert fragments together with their embeddings; return their memory_ids."""
        raise NotImplementedError

    async def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
        filter_types: Optional[List[MemoryType]] = None,
        min_importance: MemoryImportance = MemoryImportance.LOW,
    ) -> List[Tuple[MemoryFragment, float]]:
        """Similarity search; yields ``(fragment, score)`` pairs."""
        raise NotImplementedError

    async def delete(self, memory_ids: List[str]) -> int:
        """Remove the given memories and report how many were removed."""
        raise NotImplementedError

    async def count(self) -> int:
        """Total number of stored memories."""
        raise NotImplementedError
class InMemoryVectorStore(VectorStoreBackend):
    """In-memory vector store (development / testing).

    Fragments and embeddings are kept in two dicts keyed by ``memory_id``;
    search is a linear scan ranked by cosine similarity.
    """

    # Ordinal rank of each importance level, used for ``min_importance``.
    _IMPORTANCE_RANK = {
        MemoryImportance.LOW: 0,
        MemoryImportance.MEDIUM: 1,
        MemoryImportance.HIGH: 2,
        MemoryImportance.CRITICAL: 3,
    }

    def __init__(self):
        self._fragments: Dict[str, MemoryFragment] = {}
        self._embeddings: Dict[str, List[float]] = {}

    async def add(
        self,
        fragments: List[MemoryFragment],
        embeddings: List[List[float]],
    ) -> List[str]:
        """Store each fragment with its embedding (same position) by memory_id.

        Returns:
            The memory_ids of the stored fragments, in input order.

        Raises:
            ValueError: if ``fragments`` and ``embeddings`` differ in length
                (``zip`` would otherwise silently drop the extras).
        """
        fragments = list(fragments)
        embeddings = list(embeddings)
        if len(fragments) != len(embeddings):
            raise ValueError(
                f"fragments ({len(fragments)}) and embeddings "
                f"({len(embeddings)}) must have the same length"
            )
        ids: List[str] = []
        for frag, emb in zip(fragments, embeddings):
            self._fragments[frag.memory_id] = frag
            self._embeddings[frag.memory_id] = emb
            ids.append(frag.memory_id)
        return ids

    async def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
        filter_types: Optional[List[MemoryType]] = None,
        min_importance: MemoryImportance = MemoryImportance.LOW,
    ) -> List[Tuple[MemoryFragment, float]]:
        """Return up to ``top_k`` ``(fragment, cosine_score)`` pairs, best first.

        Fragments whose type is not in ``filter_types`` (when given and
        non-empty) or whose importance ranks below ``min_importance`` are
        excluded. A non-positive ``top_k`` returns an empty list.
        """
        if top_k <= 0:
            return []
        rank = self._IMPORTANCE_RANK
        min_rank = rank[min_importance]
        results: List[Tuple[MemoryFragment, float]] = []
        for mid, emb in self._embeddings.items():
            frag = self._fragments.get(mid)
            if frag is None:
                # Orphan embedding (e.g. from an inconsistent snapshot).
                continue
            if filter_types and frag.memory_type not in filter_types:
                continue
            if rank[frag.importance] < min_rank:
                continue
            results.append((frag, self._cosine_similarity(query_embedding, emb)))

        results.sort(key=lambda pair: pair[1], reverse=True)
        return results[:top_k]

    async def delete(self, memory_ids: List[str]) -> int:
        """Delete the given memories; return how many fragments were removed.

        Any embedding stored under a given id is dropped as well, even when
        its fragment is already gone, so no orphans are left behind.
        """
        removed = 0
        for mid in memory_ids:
            self._embeddings.pop(mid, None)
            if self._fragments.pop(mid, None) is not None:
                removed += 1
        return removed

    async def count(self) -> int:
        """Number of stored fragments."""
        return len(self._fragments)

    @staticmethod
    def _cosine_similarity(a: List[float], b: List[float]) -> float:
        """Cosine similarity of ``a`` and ``b``; 0.0 if either has zero norm."""
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)
231# ── Embedding Provider ──────────────────────
class EmbeddingProvider:
    """Abstract text-embedding generator.

    Subclasses implement :meth:`embed`; :meth:`embed_single` is a thin
    convenience wrapper around it.
    """

    async def embed(self, texts: List[str]) -> List[List[float]]:
        """Produce one embedding vector per input text."""
        raise NotImplementedError

    async def embed_single(self, text: str) -> List[float]:
        """Embed one text by delegating to :meth:`embed` with a one-item batch."""
        return (await self.embed([text]))[0]
class SimpleHashEmbedding(EmbeddingProvider):
    """Character-trigram hashing embedding (development / testing only).

    Builds L2-normalised pseudo-embeddings from hashed character 3-grams,
    which gives basic lexical similarity but no semantics. Production
    deployments should swap in a real embedding model (OpenAI, Cohere, ...).

    Grams are hashed with ``zlib.crc32`` instead of the builtin ``hash``:
    ``str`` hashes are salted per process (PYTHONHASHSEED), so builtin-hash
    vectors would change between runs and persisted embeddings would no
    longer be comparable with freshly computed query vectors.
    """

    def __init__(self, dim: int = 128):
        """
        Args:
            dim: vector dimensionality (number of hash buckets).

        Raises:
            ValueError: if ``dim`` is not positive.
        """
        if dim <= 0:
            raise ValueError(f"dim must be positive, got {dim!r}")
        self.dim = dim

    async def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one L2-normalised vector of length ``dim`` per text."""
        return [self._embed_one(text) for text in texts]

    def _embed_one(self, text: str) -> List[float]:
        """Hash ``text``'s trigrams into a bucket-count vector and normalise it."""
        vec = [0.0] * self.dim
        grams = [text[i:i + 3] for i in range(len(text) - 2)]
        if not grams and text:
            # Texts shorter than one trigram are hashed whole, so they still
            # get a non-zero vector (otherwise their similarity is always 0).
            grams = [text]
        for gram in grams:
            # surrogatepass: lone surrogates must not make encoding fail.
            bucket = zlib.crc32(gram.encode("utf-8", "surrogatepass")) % self.dim
            vec[bucket] += 1.0
        norm = sum(v * v for v in vec) ** 0.5
        if norm > 0:
            vec = [v / norm for v in vec]
        return vec
274# ── Reflection Engine ───────────────────────
class ReflectionConfig:
    """Thresholds that control when a reflection pass is triggered.

    Args:
        min_messages_since_last: new messages needed since the last reflection.
        min_seconds_since_last: seconds needed since the last reflection
            (default 300 s, i.e. 5 minutes).
        max_conversation_turns: kept on the instance; no code visible in
            this module reads it.
        auto_reflect: when False, ``ReflectionEngine.should_reflect``
            always answers False.
    """

    def __init__(
        self,
        min_messages_since_last: int = 10,
        min_seconds_since_last: float = 300.0,
        max_conversation_turns: int = 50,
        auto_reflect: bool = True,
    ):
        # Assigned left to right, in the same order as the parameters.
        (
            self.min_messages_since_last,
            self.min_seconds_since_last,
            self.max_conversation_turns,
            self.auto_reflect,
        ) = (
            min_messages_since_last,
            min_seconds_since_last,
            max_conversation_turns,
            auto_reflect,
        )
class ReflectionEngine:
    """Memory reflection engine.

    Periodically analyses the conversation history and extracts:

    - Facts: concrete information the user mentioned
    - Preferences: user preferences and habits
    - Patterns: recurring behaviour patterns
    - Lessons: lessons learned from mistakes

    Extracted fragments are embedded and written to the vector store.

    Usage:
        engine = ReflectionEngine(llm_reflect_fn, vector_store, embedding_provider)
        # call periodically from the agent loop
        should_reflect = engine.should_reflect(message_count)
        if should_reflect:
            result = await engine.reflect(messages_history)
    """

    # Before the first reflection, trigger once this many messages exist.
    _FIRST_REFLECTION_MIN_MESSAGES = 5
    # Each fragment is attributed to (at most) this many trailing messages.
    _SOURCE_MESSAGE_WINDOW = 20
    # At most this many existing memories are echoed into the prompt.
    _MAX_EXISTING_IN_PROMPT = 20

    def __init__(
        self,
        llm_reflect_fn: Optional[Callable[[List[dict], str], Any]] = None,
        vector_store: Optional[Any] = None,
        embedding_provider: Optional[Any] = None,
        config: Optional[ReflectionConfig] = None,
    ):
        """
        Args:
            llm_reflect_fn: LLM callable ``(messages, prompt) -> reflection_text``.
                May be a coroutine function or a plain function; an
                awaitable return value is awaited.
            vector_store: vector-store backend (``add``/``search``/``delete``).
            embedding_provider: embedding generator (``embed``/``embed_single``).
            config: trigger configuration; ``ReflectionConfig()`` when omitted.
        """
        self._llm_reflect = llm_reflect_fn
        self._vector_store = vector_store
        self._embedding_provider = embedding_provider
        self.config = config or ReflectionConfig()
        self._last_reflection_time: float = 0.0
        self._message_count_since_reflection: int = 0
        self._reflection_count: int = 0

    @staticmethod
    def _require(backend: Any, name: str) -> Any:
        """Return ``backend`` or raise a clear error if it is not configured."""
        if backend is None:
            raise RuntimeError(f"ReflectionEngine requires a configured {name}")
        return backend

    def should_reflect(self, current_message_count: int) -> bool:
        """Decide whether a reflection pass should run now.

        Args:
            current_message_count: message count considered for the very
                first reflection.

        Returns:
            False when ``config.auto_reflect`` is off. Before the first
            reflection: True once ``current_message_count`` reaches 5 (or
            enough messages were recorded). Afterwards: True when
            ``min_messages_since_last`` messages were recorded or
            ``min_seconds_since_last`` seconds elapsed.
        """
        if not self.config.auto_reflect:
            return False
        msg_check = (
            self._message_count_since_reflection
            >= self.config.min_messages_since_last
        )
        if self._reflection_count == 0:
            # No previous reflection means no interval to measure: the 0.0
            # initial timestamp would make the time check pass immediately
            # and fire on an empty history.
            return (
                current_message_count >= self._FIRST_REFLECTION_MIN_MESSAGES
                or msg_check
            )
        time_check = (
            time.time() - self._last_reflection_time
            >= self.config.min_seconds_since_last
        )
        return msg_check or time_check

    async def reflect(
        self,
        messages: List[dict],
        existing_fragments: Optional[List[MemoryFragment]] = None,
    ) -> ReflectionResult:
        """Run one reflection pass.

        Args:
            messages: conversation history (dicts with role/content).
            existing_fragments: known memories; up to 20 are listed in the
                prompt (used for contradiction detection).

        Returns:
            ReflectionResult with the newly extracted fragments, already
            embedded and added to the vector store.

        Raises:
            RuntimeError: if ``llm_reflect_fn`` is not configured, or if
                fragments must be stored/deleted but the embedding provider
                or vector store is missing.
        """
        llm_reflect = self._require(self._llm_reflect, "llm_reflect_fn")
        self._last_reflection_time = time.time()
        self._reflection_count += 1

        # 1. Build the prompt and ask the LLM to extract memories.
        prompt = self._build_reflection_prompt(messages, existing_fragments)
        reflection_text = llm_reflect(messages, prompt)
        if inspect.isawaitable(reflection_text):
            reflection_text = await reflection_text

        # 2. Parse the LLM output.
        result = self._parse_reflection_output(reflection_text, len(messages))

        # 3. Embed and store the new fragments.
        if result.fragments:
            embedder = self._require(self._embedding_provider, "embedding_provider")
            store = self._require(self._vector_store, "vector_store")
            embeddings = await embedder.embed([f.content for f in result.fragments])
            for frag, emb in zip(result.fragments, embeddings):
                frag.embedding = emb
            await store.add(result.fragments, embeddings)

        # 4. Retire deprecated memories.
        if result.deprecated_ids:
            store = self._require(self._vector_store, "vector_store")
            await store.delete(result.deprecated_ids)

        self._message_count_since_reflection = 0
        return result

    def record_message(self) -> None:
        """Count one new message toward the message-based trigger."""
        self._message_count_since_reflection += 1

    async def retrieve_relevant(
        self,
        query: str,
        top_k: int = 5,
        filter_types: Optional[List[MemoryType]] = None,
    ) -> List[MemoryFragment]:
        """Retrieve memories relevant to ``query``.

        Args:
            query: query text.
            top_k: maximum number of fragments returned.
            filter_types: restrict results to these memory types.

        Returns:
            Matching fragments in the store's ranking order; each one is
            ``touch()``-ed to record the access.

        Raises:
            RuntimeError: if the embedding provider or vector store is missing.
        """
        embedder = self._require(self._embedding_provider, "embedding_provider")
        store = self._require(self._vector_store, "vector_store")
        query_embedding = await embedder.embed_single(query)
        results = await store.search(
            query_embedding,
            top_k=top_k,
            filter_types=filter_types,
        )
        fragments = []
        for frag, _score in results:
            frag.touch()
            fragments.append(frag)
        return fragments

    def _build_reflection_prompt(
        self,
        messages: List[dict],
        existing_fragments: Optional[List[MemoryFragment]] = None,
    ) -> str:
        """Build the extraction prompt, listing up to 20 existing memories."""
        existing_str = ""
        if existing_fragments:
            existing_items = [
                f"- [{f.memory_type.value}] {f.content}"
                for f in existing_fragments[:self._MAX_EXISTING_IN_PROMPT]
            ]
            existing_str = (
                "\n\nExisting memories:\n" + "\n".join(existing_items)
            )

        return f"""You are a memory consolidation system. Analyze the conversation and extract:

1. FACTS: Specific information mentioned (names, dates, numbers, tools used, decisions made)
2. PREFERENCES: User preferences, likes, dislikes, habits
3. PATTERNS: Repeated behaviors, common workflows, recurring topics
4. LESSONS: What went wrong, what worked, what to avoid next time

For each extracted item, assign:
- type: "fact" | "preference" | "pattern" | "decision" | "lesson"
- importance: "low" | "medium" | "high" | "critical"
- confidence: 0.0 to 1.0

{existing_str}

Output JSON array only:
[{{"type": "...", "content": "...", "importance": "...", "confidence": 0.9, "tags": ["..."]}}]

If nothing significant to extract, output empty array: []"""

    def _parse_reflection_output(
        self,
        text: str,
        source_msg_count: int,
    ) -> ReflectionResult:
        """Parse the LLM's JSON-array output into a ReflectionResult.

        Best-effort: unparseable output yields an empty result, and each
        malformed item is skipped individually instead of discarding the
        whole batch. Every fragment is attributed to the last 20 messages.
        """
        result = ReflectionResult()
        source_messages = list(
            range(
                max(0, source_msg_count - self._SOURCE_MESSAGE_WINDOW),
                source_msg_count,
            )
        )
        for item in self._extract_json_array(text):
            frag = self._fragment_from_item(item, source_messages)
            if frag is not None:
                result.fragments.append(frag)
        return result

    @staticmethod
    def _extract_json_array(text: Any) -> List[Any]:
        """Return the first JSON array in ``text`` that looks like an item list.

        Every ``[`` is tried as the start of a JSON value, so prose or
        bracketed notes around the array do not break parsing. An array is
        accepted when it is empty or contains at least one object. Returns
        ``[]`` if ``text`` is not a string or no such array is found.
        """
        if not isinstance(text, str):
            return []
        decoder = json.JSONDecoder()
        start = text.find("[")
        while start != -1:
            try:
                value, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                value = None
            if isinstance(value, list) and (
                not value or any(isinstance(item, dict) for item in value)
            ):
                return value
            start = text.find("[", start + 1)
        return []

    @staticmethod
    def _fragment_from_item(
        item: Any,
        source_messages: List[int],
    ) -> Optional[MemoryFragment]:
        """Convert one parsed JSON item to a MemoryFragment, or None if unusable."""
        if not isinstance(item, dict):
            return None
        content = item.get("content", "")
        if not isinstance(content, str) or not content.strip():
            return None
        try:
            # Case/whitespace-insensitive so "Fact " still maps to FACT.
            memory_type = MemoryType(str(item.get("type", "fact")).strip().lower())
            importance = MemoryImportance(
                str(item.get("importance", "medium")).strip().lower()
            )
            confidence = float(item.get("confidence", 1.0))
        except (TypeError, ValueError, OverflowError):
            return None
        if math.isnan(confidence):
            return None
        tags = item.get("tags", [])
        if isinstance(tags, str):
            tags = [tags]
        elif not isinstance(tags, list):
            tags = []
        return MemoryFragment(
            memory_type=memory_type,
            content=content,
            importance=importance,
            # Clamp to the documented 0.0 ~ 1.0 confidence range.
            confidence=min(max(confidence, 0.0), 1.0),
            tags=[str(tag) for tag in tags],
            source_messages=list(source_messages),
        )

    @property
    def stats(self) -> Dict[str, Any]:
        """Reflection counters for monitoring."""
        return {
            "reflection_count": self._reflection_count,
            "last_reflection_time": self._last_reflection_time,
            "messages_since_last": self._message_count_since_reflection,
        }

    # ── Persistence (v1.14.9) ────────────────

    def _store_attr(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the vector store's internal dict ``name``, or None if absent."""
        if self._vector_store is None:
            return None
        return getattr(self._vector_store, name, None)

    def get_state(self) -> Dict[str, Any]:
        """Export ReflectionEngine state (and in-memory store contents) for persistence.

        Fragments/embeddings are exported only when the vector store exposes
        ``_fragments`` / ``_embeddings`` dicts; otherwise they are empty.
        """
        fragments = self._store_attr("_fragments") or {}
        embeddings = self._store_attr("_embeddings") or {}
        return {
            "reflection_count": self._reflection_count,
            "last_reflection_time": self._last_reflection_time,
            "message_count_since_reflection": self._message_count_since_reflection,
            "vector_store_fragments": {
                mid: frag.to_dict() for mid, frag in fragments.items()
            },
            "vector_store_embeddings": {
                mid: list(emb) if emb else [] for mid, emb in embeddings.items()
            },
        }

    def restore_state(self, state: Dict[str, Any]) -> None:
        """Restore ReflectionEngine from a snapshot produced by :meth:`get_state`.

        Store contents are replaced only when the vector store exposes both
        ``_fragments`` and ``_embeddings``; this avoids clearing one dict and
        then failing on the other. Embeddings without a matching fragment are
        dropped.
        """
        self._reflection_count = state.get("reflection_count", 0)
        self._last_reflection_time = state.get("last_reflection_time", 0.0)
        self._message_count_since_reflection = state.get("message_count_since_reflection", 0)

        fragments = self._store_attr("_fragments")
        embeddings = self._store_attr("_embeddings")
        if fragments is None or embeddings is None:
            return
        fragments.clear()
        embeddings.clear()
        for mid, frag_data in (state.get("vector_store_fragments") or {}).items():
            fragments[mid] = MemoryFragment.from_dict(frag_data)
        for mid, emb in (state.get("vector_store_embeddings") or {}).items():
            frag = fragments.get(mid)
            if frag is None:
                # An orphan embedding would break lookups keyed by memory_id.
                continue
            embeddings[mid] = emb
            frag.embedding = emb or None
542# ── Memory Context Injector ─────────────────
class MemoryContextInjector:
    """Memory context injector.

    At the start of each agent turn, retrieves relevant past memories and
    renders them as text to inject into the system prompt or context.

    Usage:
        injector = MemoryContextInjector(reflection_engine)
        context = await injector.build_context("user query here")
        messages.insert(0, {"role": "system", "content": context})
    """

    # Marker appended when rendered context is truncated.
    _ELLIPSIS = "..."

    def __init__(
        self,
        reflection_engine: ReflectionEngine,
        max_context_length: int = 2000,
        max_fragments: int = 5,
    ):
        """
        Args:
            reflection_engine: engine used to retrieve memories.
            max_context_length: upper bound on the length of rendered text.
            max_fragments: ``top_k`` passed to retrieval.
        """
        self._engine = reflection_engine
        self.max_context_length = max_context_length
        self.max_fragments = max_fragments

    def _truncate(self, text: str) -> str:
        """Clip ``text`` so the result never exceeds ``max_context_length``.

        Clipped text ends with "..." when the limit leaves room for it.
        """
        limit = self.max_context_length
        if len(text) <= limit:
            return text
        marker = self._ELLIPSIS
        if limit <= len(marker):
            return text[:max(limit, 0)]
        return text[:limit - len(marker)] + marker

    async def build_context(
        self,
        query: str,
        include_types: Optional[List[MemoryType]] = None,
    ) -> str:
        """Render relevant memories as a "[Relevant Memories]" block.

        Args:
            query: text used to retrieve relevant memories.
            include_types: restrict retrieval to these memory types.

        Returns:
            "" when nothing relevant is found; otherwise one line per
            fragment with its type and confidence, clipped to
            ``max_context_length`` characters.
        """
        fragments = await self._engine.retrieve_relevant(
            query,
            top_k=self.max_fragments,
            filter_types=include_types,
        )
        if not fragments:
            return ""

        lines = ["[Relevant Memories]"]
        lines.extend(
            f"- [{frag.memory_type.value}] {frag.content}"
            f" (confidence: {frag.confidence:.0%})"
            for frag in fragments
        )
        return self._truncate("\n".join(lines))

    async def build_condensed_context(
        self,
        query: str,
    ) -> str:
        """Render a compact "[Key Context]" block of HIGH/CRITICAL memories.

        At most 3 fragments are listed and the text is clipped to
        ``max_context_length``. The importance filter runs on the top
        ``max_fragments`` retrieval results only, so important memories
        ranked lower are not considered.
        """
        fragments = await self._engine.retrieve_relevant(
            query,
            top_k=self.max_fragments,
        )
        important = [
            f for f in fragments
            if f.importance in (MemoryImportance.HIGH, MemoryImportance.CRITICAL)
        ]
        if not important:
            return ""

        lines = ["[Key Context]"]
        lines.extend(f"- {frag.content}" for frag in important[:3])
        return self._truncate("\n".join(lines))
619# ── Memory Consolidation Pipeline ───────────
class MemoryConsolidationPipeline:
    """Memory consolidation pipeline (one-stop integration).

    Combines ReflectionEngine + MemoryContextInjector into a ready-to-use
    memory system. Defaults to InMemoryVectorStore and
    SimpleHashEmbedding(128) when no backends are supplied.

    Usage:
        pipeline = MemoryConsolidationPipeline(llm_fn)
        # in the agent loop:
        pipeline.record_message()
        if pipeline.should_reflect():
            await pipeline.reflect(messages)
        context = await pipeline.get_context(user_query)
    """

    def __init__(
        self,
        llm_reflect_fn: Callable,
        vector_store: Optional[VectorStoreBackend] = None,
        embedding_provider: Optional[EmbeddingProvider] = None,
        config: Optional[ReflectionConfig] = None,
    ):
        self._vector_store = vector_store or InMemoryVectorStore()
        self._embedding_provider = embedding_provider or SimpleHashEmbedding(128)
        self._reflection_engine = ReflectionEngine(
            llm_reflect_fn=llm_reflect_fn,
            vector_store=self._vector_store,
            embedding_provider=self._embedding_provider,
            config=config,
        )
        self._injector = MemoryContextInjector(self._reflection_engine)

    def record_message(self) -> None:
        """Count one new message toward the reflection trigger."""
        self._reflection_engine.record_message()

    def should_reflect(self) -> bool:
        """Whether a reflection pass is due, based on recorded messages."""
        return self._reflection_engine.should_reflect(
            self._reflection_engine._message_count_since_reflection
        )

    async def reflect(self, messages: List[dict]) -> ReflectionResult:
        """Run one reflection pass over ``messages``."""
        return await self._reflection_engine.reflect(messages)

    async def get_context(self, query: str) -> str:
        """Relevant-memories context block for ``query``."""
        return await self._injector.build_context(query)

    async def get_condensed_context(self, query: str) -> str:
        """Compact high-importance context block for ``query``."""
        return await self._injector.build_condensed_context(query)

    async def count_memories(self) -> int:
        """Number of stored memories; safe to call from async code."""
        return await self._vector_store.count()

    def _count_memories_sync(self) -> int:
        """Count stored memories from synchronous code.

        Uses the store's ``_fragments`` dict when present; otherwise drives
        ``count()`` with ``asyncio.run`` if no event loop is running in this
        thread. Inside a running loop a coroutine cannot be blocked on, so 0
        is reported (await :meth:`count_memories` instead).
        """
        fragments = getattr(self._vector_store, "_fragments", None)
        if isinstance(fragments, dict):
            return len(fragments)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self._vector_store.count())
        return 0

    @property
    def stats(self) -> Dict[str, Any]:
        """Reflection counters plus the current number of stored memories."""
        return {
            "reflection": self._reflection_engine.stats,
            "total_memories": self._count_memories_sync(),
        }

    # ── Persistence (v1.14.9) ────────────────

    def get_state(self) -> Dict[str, Any]:
        """Export consolidation pipeline state for persistence. Delegates to ReflectionEngine."""
        return self._reflection_engine.get_state()

    def restore_state(self, state: Dict[str, Any]) -> None:
        """Restore consolidation pipeline from a persisted snapshot."""
        self._reflection_engine.restore_state(state)