Coverage for agentos/memory/long_term.py: 27%
103 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:01 +0800
"""
AgentOS v0.20 long-term memory system.
Dual memory: RAG retrieval + knowledge graph.
"""
6from __future__ import annotations
8from dataclasses import dataclass, field
9from typing import Any
@dataclass
class MemoryEntry:
    """A single long-term memory record.

    Attributes:
        id: Key the record is stored and indexed under.
        content: Text of the memory; its words feed the keyword index.
        embedding: Optional vector used by similarity search.
        metadata: Free-form extra data (a fresh dict per instance).
        created_at: Creation timestamp; the smallest value is evicted first.
    """

    id: str
    content: str
    embedding: list[float] | None = field(default=None)
    metadata: dict[str, Any] = field(default_factory=dict)
    created_at: float = field(default=0.0)
class LongTermMemory:
    """
    Long-term memory: keyword/vector retrieval plus an entity-relation graph.

    Provided by this class:
    - Semantic retrieval (cosine similarity over stored embeddings)
    - Keyword retrieval (inverted index)
    - Entity-relation graph (knowledge-graph triples)
    - Capacity bound: when ``max_entries`` is reached, the entry with the
      smallest ``created_at`` is evicted before a new id is added

    Listed in the original design but NOT implemented here: time-weighted
    memory decay and automatic summary compression.
    """

    def __init__(self, embedding_dim: int = 1536, max_entries: int = 100000):
        self._entries: dict[str, MemoryEntry] = {}
        # keyword -> ids of entries whose content contains that keyword
        self._keyword_index: dict[str, set[str]] = {}
        # entry id -> keywords it is indexed under, so removal touches only those sets
        self._entry_keywords: dict[str, set[str]] = {}
        # entity -> {(relation, other_entity)}; reverse edges carry a "_reverse" suffix
        self._entity_graph: dict[str, set[tuple[str, str]]] = {}
        # Stored and persisted; not used to validate embedding lengths in this class.
        self._embedding_dim = embedding_dim
        self._max_entries = max_entries

    def add(self, entry: MemoryEntry) -> None:
        """Add *entry*, replacing any existing entry with the same id.

        Adding a NEW id while at capacity first evicts the oldest entry.
        Replacing an existing id never evicts, and the old keywords are dropped.
        """
        if entry.id not in self._entries and len(self._entries) >= self._max_entries:
            self._evict_oldest()
        self._entries[entry.id] = entry
        self._index_keywords(entry)

    def search_by_keyword(self, query: str, top_k: int = 10) -> list[MemoryEntry]:
        """Keyword retrieval.

        The query is tokenized exactly like indexed content (lowercased,
        non-alphanumerics stripped, words shorter than 2 chars dropped). Each
        entry scores one point per distinct query keyword it contains. Results
        are ordered by score descending, ties broken by entry id, and capped at
        *top_k* (``top_k <= 0`` returns ``[]``).
        """
        if top_k <= 0:
            return []
        scores: dict[str, int] = {}
        for kw in dict.fromkeys(self._tokenize(query)):  # distinct, order-preserving
            for entry_id in self._keyword_index.get(kw, ()):
                scores[entry_id] = scores.get(entry_id, 0) + 1
        ranked = sorted(scores, key=lambda eid: (-scores[eid], eid))
        # Filter before slicing so a missing id cannot shrink the result below top_k.
        hits = [self._entries[eid] for eid in ranked if eid in self._entries]
        return hits[:top_k]

    def search_by_vector(self, query_embedding: list[float], top_k: int = 10) -> list[MemoryEntry]:
        """Vector similarity retrieval (cosine similarity).

        Entries without an embedding (``None`` or empty) are skipped. Results are
        ordered by similarity descending and capped at *top_k*
        (``top_k <= 0`` returns ``[]``).
        """
        if top_k <= 0:
            return []
        query = list(query_embedding)
        # Loop-invariant: the query norm is the same for every entry.
        q_norm = sum(x * x for x in query) ** 0.5
        scored = []
        for entry in self._entries.values():
            emb = entry.embedding
            if not emb:
                continue
            # NOTE(review): zip truncates if lengths differ; embedding_dim is not enforced.
            dot = sum(x * y for x, y in zip(query, emb))
            e_norm = sum(x * x for x in emb) ** 0.5
            # 1e-8 guards against division by zero for all-zero vectors.
            scored.append((dot / (q_norm * e_norm + 1e-8), entry))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [entry for _, entry in scored[:top_k]]

    def add_relation(self, entity_a: str, relation: str, entity_b: str) -> None:
        """Add a knowledge-graph triple plus its "_reverse" edge from *entity_b*."""
        self._entity_graph.setdefault(entity_a, set()).add((relation, entity_b))
        self._entity_graph.setdefault(entity_b, set()).add((relation + "_reverse", entity_a))

    def query_relations(self, entity: str, depth: int = 1, limit: int = 50) -> list[tuple[str, str]]:
        """Return ``(relation, target)`` edges reachable from *entity*.

        ``depth=1`` (default; values < 1 are treated as 1) returns the direct
        edges of *entity*. Each additional hop also returns the edges of entities
        discovered in the previous hop, excluding edges that lead back to
        entities from earlier hops. The source of each edge is not returned.
        Edges are deduplicated, sorted within each entity, and capped at *limit*.
        """
        if limit <= 0:
            return []
        results: list[tuple[str, str]] = []
        seen: set[tuple[str, str]] = set()
        visited = {entity}
        frontier = [entity]
        for hop in range(max(depth, 1)):
            discovered: dict[str, None] = {}  # ordered set of next-hop entities
            for node in frontier:
                for pair in sorted(self._entity_graph.get(node, ())):
                    target = pair[1]
                    if hop and target in visited:
                        continue  # don't walk back toward earlier hops
                    if pair not in seen:
                        seen.add(pair)
                        results.append(pair)
                        if len(results) >= limit:
                            return results
                    if target not in visited:
                        discovered[target] = None
            if not discovered:
                break
            visited.update(discovered)
            frontier = list(discovered)
        return results

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Split *text* into lowercase alphanumeric tokens of length >= 2.

        Shared by indexing and querying so both sides normalize words identically.
        """
        tokens = []
        for word in text.lower().split():
            clean = "".join(c for c in word if c.isalnum())
            if len(clean) > 1:
                tokens.append(clean)
        return tokens

    def _index_keywords(self, entry: MemoryEntry) -> None:
        """(Re)index *entry*'s content, replacing any prior keywords for its id."""
        self._unindex(entry.id)
        tokens = set(self._tokenize(entry.content))
        self._entry_keywords[entry.id] = tokens
        for token in tokens:
            self._keyword_index.setdefault(token, set()).add(entry.id)

    def _unindex(self, entry_id: str) -> None:
        """Remove *entry_id* from its keyword sets, deleting sets that become empty."""
        for token in self._entry_keywords.pop(entry_id, ()):
            ids = self._keyword_index.get(token)
            if ids is None:
                continue
            ids.discard(entry_id)
            if not ids:
                del self._keyword_index[token]

    def _evict_oldest(self) -> None:
        """Evict the entry with the smallest ``created_at``; no-op when empty."""
        if not self._entries:
            return
        oldest_id = min(self._entries, key=lambda eid: self._entries[eid].created_at)
        del self._entries[oldest_id]
        self._unindex(oldest_id)

    # ── Persistence (v1.14.9) ────────────────

    def get_state(self) -> dict[str, Any]:
        """Export LongTermMemory state for persistence.

        Embedding lists and metadata dicts are the live objects, not copies.
        Graph edges are emitted as sorted lists of ``(relation, target)`` tuples.
        """
        return {
            "embedding_dim": self._embedding_dim,
            "max_entries": self._max_entries,
            "entries": {
                eid: {
                    "id": entry.id,
                    "content": entry.content,
                    "embedding": entry.embedding,
                    "metadata": entry.metadata,
                    "created_at": entry.created_at,
                }
                for eid, entry in self._entries.items()
            },
            "entity_graph": {
                entity: sorted(relations)
                for entity, relations in self._entity_graph.items()
            },
        }

    def restore_state(self, state: dict[str, Any]) -> None:
        """Restore LongTermMemory from a persisted snapshot.

        Replaces all current entries, indexes and graph edges. Entries are keyed
        by their ``id`` field (falling back to the snapshot key), as ``add`` does.
        """
        self._embedding_dim = state.get("embedding_dim", self._embedding_dim)
        self._max_entries = state.get("max_entries", self._max_entries)
        self._entries.clear()
        self._keyword_index.clear()
        self._entry_keywords.clear()
        self._entity_graph.clear()

        for eid, entry_data in state.get("entries", {}).items():
            entry = MemoryEntry(
                id=entry_data.get("id", eid),
                content=entry_data.get("content", ""),
                embedding=entry_data.get("embedding"),
                metadata=entry_data.get("metadata", {}),
                created_at=entry_data.get("created_at", 0.0),
            )
            # Key by entry.id so eviction and keyword lookups always agree.
            self._entries[entry.id] = entry
            self._index_keywords(entry)

        for entity, relations in state.get("entity_graph", {}).items():
            # JSON round-trips turn tuples into lists; store tuples again.
            for rel, target in relations:
                self._entity_graph.setdefault(entity, set()).add((rel, target))
class MemoryStore:
    """Unified entry point for the three memory tiers.

    - ``working``: latest value per key.
    - ``short_term``: append-only log of ``{"key": ..., "value": ...}`` records.
    - ``long_term``: a LongTermMemory holding searchable string entries.
    """

    def __init__(self, long_term: LongTermMemory | None = None):
        self.working: dict[str, Any] = {}
        self.short_term: list[dict] = []
        # Explicit None check: a caller-supplied store is always kept as-is.
        self.long_term = long_term if long_term is not None else LongTermMemory()

    def remember(self, key: str, value: Any, long_term: bool = False) -> None:
        """Store *value* under *key*.

        With ``long_term=True`` the value is stringified into a MemoryEntry
        stamped with the current time and added to long-term memory ONLY.
        Otherwise it replaces the working-memory value for *key* and is appended
        to the short-term log.
        """
        if long_term:
            import time

            entry = MemoryEntry(id=key, content=str(value), created_at=time.time())
            self.long_term.add(entry)
            return
        self.working[key] = value
        self.short_term.append({"key": key, "value": value})

    def recall(self, query: str, use_long_term: bool = True) -> list[Any] | None:
        """Retrieve memories matching *query*.

        Order: the working-memory value whose key equals *query* exactly, then
        short-term values whose key contains *query* (case-insensitive). Long-term
        keyword search (contents as strings) is consulted only if nothing matched
        so far.

        Returns:
            The list of matches, or ``None`` when nothing matched (kept for
            backward compatibility with callers that test for ``None``).
        """
        results: list[Any] = []
        missing = object()
        current = self.working.get(query, missing)
        if current is not missing:
            results.append(current)

        needle = query.lower()
        for item in self.short_term:
            key = item["key"]
            if needle not in key.lower():
                continue
            # remember() writes each value to BOTH tiers; skip the short-term echo
            # of the value already returned from working memory (identity match,
            # so values that were restored from a snapshot may still repeat).
            if key == query and item["value"] is current:
                continue
            results.append(item["value"])

        if use_long_term and not results:
            results.extend(entry.content for entry in self.long_term.search_by_keyword(query))
        return results or None

    def clear_short_term(self) -> None:
        """Empty the short-term log (working and long-term memory are untouched)."""
        self.short_term.clear()

    # ── Persistence (v1.14.9) ────────────────

    def get_state(self) -> dict[str, Any]:
        """Export MemoryStore state for persistence.

        Returns copies of the containers, so mutating the snapshot does not
        affect the store. Stored values themselves are not deep-copied.
        NOTE(review): long_term is not included here. LongTermMemory has its own
        get_state/restore_state and is presumably persisted separately — verify
        with callers.
        """
        return {
            "working": dict(self.working),
            "short_term": [dict(item) for item in self.short_term],
        }

    def restore_state(self, state: dict[str, Any]) -> None:
        """Restore MemoryStore from a persisted snapshot (long_term is untouched)."""
        self.working = dict(state.get("working", {}))
        self.short_term = list(state.get("short_term", []))