Coverage for agentos/memory/long_term.py: 27%

103 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:01 +0800

1""" 

2AgentOS v0.20 长期记忆系统。 

3RAG检索 + 知识图谱双重记忆。 

4""" 

5 

6from __future__ import annotations 

7 

8from dataclasses import dataclass, field 

9from typing import Any 

10 

11 

@dataclass
class MemoryEntry:
    """A single long-term memory record.

    Only ``id`` and ``content`` are required; every other field has a
    default so an entry can be built from a bare key/text pair.
    """

    # Identifier; LongTermMemory uses it as the key of its entry table and
    # as the value stored in the keyword index.
    id: str
    # Raw text; LongTermMemory tokenises it to build the keyword index.
    content: str
    # Optional embedding vector; entries without one are skipped by
    # LongTermMemory.search_by_vector.
    embedding: list[float] | None = None
    # Free-form metadata; default_factory gives each instance its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Timestamp compared during eviction: the smallest value goes first.
    created_at: float = 0.0

20 

21 

class LongTermMemory:
    """Long-term memory: keyword/vector retrieval plus an entity graph.

    Implemented in this class:
      - semantic search (cosine similarity over stored embeddings)
      - keyword search (inverted index over whitespace-separated tokens)
      - entity relation graph (triples stored in both directions)
      - capacity-bound eviction of the entry with the smallest
        ``created_at``

    NOTE(review): the original feature list also named time-weighted decay
    and automatic summary compression; neither is implemented here.
    """

    # Upper bound on the number of relations query_relations returns.
    _MAX_RELATIONS = 50

    def __init__(self, embedding_dim: int = 1536, max_entries: int = 100000):
        """Create an empty store.

        Args:
            embedding_dim: Recorded and persisted, but not enforced on
                stored or queried vectors.
            max_entries: Capacity; adding a new id at capacity evicts the
                entry with the smallest ``created_at`` first.
        """
        self._entries: dict[str, MemoryEntry] = {}
        self._keyword_index: dict[str, set[str]] = {}
        self._entity_graph: dict[str, set[tuple[str, str]]] = {}
        self._embedding_dim = embedding_dim
        self._max_entries = max_entries

    def add(self, entry: MemoryEntry):
        """Store *entry*, replacing any existing entry with the same id."""
        if entry.id in self._entries:
            # Overwrite: drop the old content's keywords so stale terms stop
            # matching. The entry count does not grow, so nothing is evicted.
            self._unindex(entry.id)
        elif len(self._entries) >= self._max_entries:
            self._evict_oldest()
        self._entries[entry.id] = entry
        self._index_keywords(entry)

    def search_by_keyword(self, query: str, top_k: int = 10) -> list[MemoryEntry]:
        """Return up to *top_k* entries ranked by number of matching query tokens.

        The query is tokenised exactly like indexed content (lower-cased,
        punctuation stripped, single-character tokens dropped), so
        ``"hello,"`` matches an entry containing ``"Hello"``.
        """
        scores: dict[str, int] = {}
        for kw in self._tokenize(query):
            for entry_id in self._keyword_index.get(kw, ()):
                # Ignore ids no longer stored so they cannot use up a
                # top_k slot.
                if entry_id in self._entries:
                    scores[entry_id] = scores.get(entry_id, 0) + 1
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return [self._entries[eid] for eid, _ in ranked]

    def search_by_vector(self, query_embedding: list[float], top_k: int = 10) -> list[MemoryEntry]:
        """Return up to *top_k* entries ranked by cosine similarity.

        Entries without an embedding are skipped. Vectors of different
        lengths are compared over their common prefix (``zip`` truncates).
        """
        # The query norm is the same for every entry; compute it once.
        query_norm = sum(x * x for x in query_embedding) ** 0.5
        scored = []
        for entry in self._entries.values():
            embedding = entry.embedding
            if not embedding:
                continue
            dot = sum(x * y for x, y in zip(query_embedding, embedding))
            entry_norm = sum(x * x for x in embedding) ** 0.5
            # The epsilon keeps an all-zero vector from dividing by zero.
            scored.append((dot / (query_norm * entry_norm + 1e-8), entry))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [entry for _, entry in scored[:top_k]]

    def add_relation(self, entity_a: str, relation: str, entity_b: str):
        """Record the triple (entity_a, relation, entity_b).

        The inverse edge is stored on *entity_b* under the relation name
        with ``"_reverse"`` appended, so either endpoint can be queried.
        """
        self._entity_graph.setdefault(entity_a, set()).add((relation, entity_b))
        self._entity_graph.setdefault(entity_b, set()).add((relation + "_reverse", entity_a))

    def query_relations(self, entity: str, depth: int = 1) -> list[tuple[str, str]]:
        """Return ``(relation, target)`` pairs reachable from *entity*.

        Args:
            entity: Starting node.
            depth: Number of hops to follow. Values below 1 are treated
                as 1, which returns only the entity's direct relations.

        Returns:
            At most 50 pairs, nearest hops first. Pairs do not carry their
            source node, so identical pairs from different nodes appear once.
        """
        limit = self._MAX_RELATIONS
        results: list[tuple[str, str]] = []
        seen_edges: set[tuple[str, str]] = set()
        visited = {entity}
        frontier = [entity]
        for _ in range(max(depth, 1)):
            next_frontier = []
            for node in frontier:
                for edge in self._entity_graph.get(node, ()):
                    if edge not in seen_edges:
                        seen_edges.add(edge)
                        results.append(edge)
                        if len(results) >= limit:
                            return results
                    target = edge[1]
                    if target not in visited:
                        visited.add(target)
                        next_frontier.append(target)
            if not next_frontier:
                break
            frontier = next_frontier
        return results

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Split *text* into index tokens.

        Lower-cases, splits on whitespace, strips non-alphanumeric
        characters and drops tokens shorter than two characters.
        """
        tokens = []
        for word in text.lower().split():
            clean = "".join(c for c in word if c.isalnum())
            if len(clean) > 1:
                tokens.append(clean)
        return tokens

    def _index_keywords(self, entry: MemoryEntry):
        """Add every token of ``entry.content`` to the inverted index."""
        for token in self._tokenize(entry.content):
            self._keyword_index.setdefault(token, set()).add(entry.id)

    def _unindex(self, entry_id: str) -> None:
        """Remove *entry_id* from the index and drop keywords left empty."""
        emptied = []
        for keyword, ids in self._keyword_index.items():
            ids.discard(entry_id)
            if not ids:
                emptied.append(keyword)
        # Delete after the loop: the dict must not change size while iterated.
        for keyword in emptied:
            del self._keyword_index[keyword]

    def _evict_oldest(self):
        """Drop the entry with the smallest ``created_at``; no-op when empty."""
        if not self._entries:
            return
        # Select by dict key, not entry.id, so a mismatch cannot raise KeyError.
        oldest_id = min(self._entries, key=lambda eid: self._entries[eid].created_at)
        del self._entries[oldest_id]
        self._unindex(oldest_id)

    # ── Persistence (v1.14.9) ────────────────

    def get_state(self) -> dict[str, Any]:
        """Export LongTermMemory state for persistence.

        Embeddings and metadata are shallow-copied so later changes to the
        live entries do not alter a snapshot already taken. The keyword
        index is not exported; ``restore_state`` rebuilds it.
        """
        entries = {}
        for eid, entry in self._entries.items():
            entries[eid] = {
                "id": entry.id,
                "content": entry.content,
                "embedding": None if entry.embedding is None else list(entry.embedding),
                "metadata": None if entry.metadata is None else dict(entry.metadata),
                "created_at": entry.created_at,
            }
        return {
            "embedding_dim": self._embedding_dim,
            "max_entries": self._max_entries,
            "entries": entries,
            "entity_graph": {
                entity: list(relations)
                for entity, relations in self._entity_graph.items()
            },
        }

    def restore_state(self, state: dict[str, Any]) -> None:
        """Restore LongTermMemory from a persisted snapshot.

        Current contents are discarded. Entries are stored under their own
        ``id`` so the entry table and keyword index always agree.
        Relations may be tuples or two-item lists.
        """
        self._embedding_dim = state.get("embedding_dim", self._embedding_dim)
        self._max_entries = state.get("max_entries", self._max_entries)
        self._entries.clear()
        self._keyword_index.clear()
        self._entity_graph.clear()

        for eid, entry_data in (state.get("entries") or {}).items():
            embedding = entry_data.get("embedding")
            metadata = entry_data.get("metadata")
            entry = MemoryEntry(
                id=entry_data.get("id", eid),
                content=entry_data.get("content", ""),
                # Copy so the restored entry does not alias the snapshot.
                embedding=None if embedding is None else list(embedding),
                metadata={} if metadata is None else dict(metadata),
                created_at=entry_data.get("created_at", 0.0),
            )
            if entry.id in self._entries:
                # Two snapshot records share an id: the later one wins, and
                # the earlier one's keywords must not linger.
                self._unindex(entry.id)
            self._entries[entry.id] = entry
            self._index_keywords(entry)

        for entity, relations in (state.get("entity_graph") or {}).items():
            for rel, target in relations:
                self._entity_graph.setdefault(entity, set()).add((rel, target))

141 

142 

class MemoryStore:
    """Unified entry point for the working, short-term and long-term tiers."""

    def __init__(self, long_term: LongTermMemory | None = None):
        """Create a store, optionally around an existing long-term backend.

        Args:
            long_term: Backend to use; a fresh ``LongTermMemory`` is created
                only when this is ``None``.
        """
        self.working: dict[str, Any] = {}
        self.short_term: list[dict] = []
        # Test for None explicitly so a falsy but valid backend is kept.
        self.long_term = long_term if long_term is not None else LongTermMemory()

    def remember(self, key: str, value: Any, long_term: bool = False):
        """Store *value* under *key*.

        With ``long_term=True`` the value is stringified and added to the
        long-term backend with the current time as ``created_at``.
        Otherwise it is written to working memory and also appended to the
        short-term log.
        """
        if long_term:
            import time  # function-scope import: only this branch needs it

            entry = MemoryEntry(id=key, content=str(value), created_at=time.time())
            self.long_term.add(entry)
            return
        self.working[key] = value
        self.short_term.append({"key": key, "value": value})

    def recall(self, query: str, use_long_term: bool = True) -> list[Any] | None:
        """Look up *query* across the tiers.

        Order of results: an exact working-memory key hit, then every
        short-term item whose key contains *query* (case-insensitive).
        Long-term keyword search runs only when both of those are empty
        and *use_long_term* is true.

        Returns:
            The collected values, or ``None`` (not an empty list) when
            nothing matched.

        NOTE(review): ``remember`` writes to both working and short-term
        memory, so a key found in working memory usually appears again via
        the short-term scan. Left unchanged because callers may rely on it.
        """
        results = []
        # Working memory first.
        if query in self.working:
            results.append(self.working[query])
        # Short-term memory: substring match on the key.
        needle = query.lower()
        for item in self.short_term:
            if needle in item["key"].lower():
                results.append(item["value"])
        # Long-term memory, only as a fallback.
        if use_long_term and not results:
            long_results = self.long_term.search_by_keyword(query)
            results.extend(e.content for e in long_results)
        return results if results else None

    def clear_short_term(self):
        """Empty the short-term log; working and long-term memory are untouched."""
        self.short_term.clear()

    # ── Persistence (v1.14.9) ────────────────

    def get_state(self) -> dict[str, Any]:
        """Export MemoryStore state for persistence.

        Returns shallow copies so later ``remember`` calls do not change a
        snapshot already taken. The long-term tier is not included.
        """
        return {
            "working": dict(self.working),
            "short_term": list(self.short_term),
        }

    def restore_state(self, state: dict[str, Any]) -> None:
        """Restore MemoryStore from a persisted snapshot.

        Missing keys restore as empty containers; the long-term tier is
        left as it is.
        """
        self.working = dict(state.get("working", {}))
        self.short_term = list(state.get("short_term", []))