Coverage for agentos/vectorstore/db.py: 23%

155 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v0.30 向量数据库集成 — Chroma + FAISS。 

3语义记忆检索、知识库索引。 

4""" 

5 

6from dataclasses import dataclass, field 

7from typing import Optional 

8import os 

9import pickle 

10import json 

11import uuid 

12 

13 

@dataclass
class VectorEntry:
    """A single stored text together with its id, metadata and match score."""

    # Identifier of the entry inside its store.
    id: str
    # The stored text itself.
    text: str
    # Free-form metadata; every instance gets its own fresh dict by default.
    metadata: dict = field(default_factory=dict)
    # Score assigned by a search; 0.0 when the entry did not come from one.
    score: float = 0.0

21 

22 

class BaseVectorStore:
    """Common interface shared by the vector store backends.

    Every method here is an empty placeholder that returns ``None``;
    the concrete stores override all four of them.
    """

    def add(
        self,
        texts: list[str],
        metadatas: list[dict] | None = None,
        ids: list[str] | None = None,
    ) -> list[str]:
        """Store ``texts`` with optional metadata and ids; return the ids."""

    def search(self, query: str, top_k: int = 5) -> list[VectorEntry]:
        """Return up to ``top_k`` entries matching ``query``."""

    def delete(self, ids: list[str]):
        """Remove the entries with the given ids."""

    def count(self) -> int:
        """Return the number of stored entries."""

29 

30 

class FAISSVectorStore(BaseVectorStore):
    """Lightweight vector store using FAISS, with a pure-Python fallback.

    ``_store`` (id -> (vector, text, metadata)) is the source of truth.
    The FAISS index is a derived cache that is rebuilt lazily from
    ``_store`` whenever entries were added, deleted or loaded, so FAISS row
    numbers always map back to the right ids via ``_row_ids``.

    When ``index_path`` is set, the store is loaded from that file on
    construction (if it exists) and written back when the object is
    garbage-collected.
    """

    # Sentence-transformers model shared by all instances so the weights are
    # loaded at most once per process instead of on every ``_embed`` call.
    _shared_encoder = None

    def __init__(self, dim: int = 768, index_path: str = ""):
        """Create a store.

        Args:
            dim: Vector width used by the hash-based fallback embedding and
                by ``_init_index`` when no explicit width is given.
            index_path: Optional pickle file used for persistence.
        """
        self.dim = dim
        self.index_path = index_path
        self._index = None
        # FAISS row number -> entry id, valid only while the index is fresh.
        self._row_ids: list[str] = []
        self._index_dirty = False
        self._store: dict[str, tuple[list[float], str, dict]] = {}
        self._next_id = 0
        if index_path and os.path.exists(index_path):
            self._load()

    def _init_index(self, dim: int | None = None):
        """Create an empty inner-product index, or ``None`` without faiss.

        Args:
            dim: Width of the index; defaults to ``self.dim``.
        """
        self._row_ids = []
        try:
            import faiss
            self._index = faiss.IndexFlatIP(self.dim if dim is None else dim)
        except ImportError:
            self._index = None

    @staticmethod
    def _unit_rows(matrix):
        """Return ``matrix`` as contiguous float32 with L2-normalised rows.

        With unit-length rows the inner product computed by the index equals
        the cosine similarity used by ``_fallback_search``, so both search
        paths rank and score entries the same way.
        """
        import numpy as np
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0.0] = 1.0  # leave all-zero rows untouched
        return np.ascontiguousarray(matrix / norms, dtype=np.float32)

    def _sync_index(self):
        """Rebuild the FAISS index from ``_store`` if it is out of date.

        Leaves ``_index`` as ``None`` (pure-Python search) when faiss is not
        installed, the store is empty, or the stored vectors do not all have
        the same width.
        """
        # Instances created without __init__ have no flag yet: treat as stale.
        if not getattr(self, "_index_dirty", True):
            return
        self._index = None
        self._row_ids = []
        self._index_dirty = False
        if not self._store:
            return
        row_ids = list(self._store)
        vectors = [self._store[rid][0] for rid in row_ids]
        # Size the index from the real embeddings rather than ``self.dim``:
        # the sentence-transformers model emits vectors whose width need not
        # match the configured default.
        width = len(vectors[0])
        if width == 0 or any(len(vec) != width for vec in vectors):
            return
        self._init_index(width)
        if self._index is None:
            return
        import numpy as np
        self._index.add(self._unit_rows(np.asarray(vectors, dtype=np.float32)))
        self._row_ids = row_ids

    def add(self, texts: list[str], metadatas: list[dict] | None = None, ids: list[str] | None = None) -> list[str]:
        """Embed and store ``texts``.

        Args:
            texts: Texts to store.
            metadatas: Optional per-text metadata dicts.
            ids: Optional per-text ids; ``v<counter>`` ids are generated
                when omitted. Re-using an existing id replaces that entry.

        Returns:
            The ids of the stored entries, in input order.

        Raises:
            ValueError: If ``metadatas`` or ``ids`` is given but has fewer
                items than ``texts``.
        """
        if not texts:
            return []
        # Validate up front so a short list cannot leave a half-applied batch.
        for label, extra in (("metadatas", metadatas), ("ids", ids)):
            if extra and len(extra) < len(texts):
                raise ValueError(
                    f"{label} has {len(extra)} items but {len(texts)} texts were given"
                )
        embeddings = self._embed(texts)

        res_ids = []
        for i, (text, vec) in enumerate(zip(texts, embeddings)):
            rid = ids[i] if ids else f"v{self._next_id}"
            self._next_id += 1
            self._store[rid] = (vec, text, metadatas[i] if metadatas else {})
            res_ids.append(rid)
        # The index is refreshed lazily on the next search.
        self._index_dirty = True
        return res_ids

    def _fallback_search(self, q_vec, top_k):
        """Cosine-similarity search in pure Python (used without faiss).

        Args:
            q_vec: Query embedding.
            top_k: Maximum number of entries to return.

        Returns:
            Entries sorted by descending cosine similarity.
        """
        import math
        q_norm = math.sqrt(sum(a * a for a in q_vec))
        scored = []
        for rid, (vec, text, meta) in self._store.items():
            v_norm = math.sqrt(sum(b * b for b in vec))
            denom = q_norm * v_norm
            dot = sum(a * b for a, b in zip(q_vec, vec))
            scored.append((dot / denom if denom > 0 else 0.0, rid, text, meta))
        scored.sort(key=lambda item: item[0], reverse=True)
        return [
            VectorEntry(id=rid, text=text, metadata=meta, score=float(sim))
            for sim, rid, text, meta in scored[:top_k]
        ]

    def search(self, query: str, top_k: int = 5) -> list[VectorEntry]:
        """Return up to ``top_k`` entries most similar to ``query``.

        Uses the FAISS index when it is available and compatible with the
        query embedding, otherwise the pure-Python cosine search.
        """
        if not self._store or top_k <= 0:
            return []
        q_vec = self._embed([query])[0]
        self._sync_index()
        index = self._index
        # A query of a different width cannot be searched in the index.
        if index is None or len(q_vec) != index.d:
            return self._fallback_search(q_vec, top_k)
        import numpy as np
        q_mat = self._unit_rows(np.asarray([q_vec], dtype=np.float32))
        scores, rows = index.search(q_mat, min(top_k, len(self._row_ids)))
        results = []
        for score, row in zip(scores[0], rows[0]):
            if row < 0:  # FAISS pads missing results with -1
                continue
            rid = self._row_ids[row]
            entry = self._store.get(rid)
            if entry is None:
                continue
            _, text, meta = entry
            results.append(VectorEntry(id=rid, text=text, metadata=meta, score=float(score)))
        return results

    def delete(self, ids: list[str]):
        """Remove the entries with the given ids; unknown ids are ignored."""
        removed = False
        for rid in ids:
            if self._store.pop(rid, None) is not None:
                removed = True
        if removed:
            # Drop the stale vectors from the index on the next search.
            self._index_dirty = True

    def count(self) -> int:
        """Return the number of stored entries."""
        return len(self._store)

    def _embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` with all-MiniLM-L6-v2, or the hash fallback.

        The fallback is used when ``sentence_transformers`` cannot be
        imported.
        """
        try:
            from sentence_transformers import SentenceTransformer
            cls = type(self)
            if cls._shared_encoder is None:
                cls._shared_encoder = SentenceTransformer("all-MiniLM-L6-v2")
            embeddings = cls._shared_encoder.encode(texts, normalize_embeddings=True)
            return embeddings.tolist()
        except ImportError:
            return self._tfidf_embed(texts)

    def _tfidf_embed(self, texts: list[str]) -> list[list[float]]:
        """Placeholder embedding derived from the SHA-256 digest of each text.

        Despite the name this is not TF-IDF: each vector holds the digest
        bytes scaled to [0, 1], zero-padded (or truncated) to ``self.dim``.
        It is deterministic but carries no semantic meaning.
        """
        import hashlib
        dim = self.dim
        result = []
        for t in texts:
            digest = hashlib.sha256(t.encode()).digest()
            vec = [byte / 255.0 for byte in digest[:max(dim, 0)]]
            vec += [0.0] * (dim - len(vec))
            result.append(vec)
        return result

    def _save(self):
        """Pickle the store to ``index_path``; no-op when no path is set."""
        if not self.index_path:
            return
        os.makedirs(os.path.dirname(self.index_path) or ".", exist_ok=True)
        # Write to a sibling file and swap it in, so an interrupted write
        # cannot leave a truncated store behind.
        tmp_path = f"{self.index_path}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump({"store": self._store, "next_id": self._next_id}, f)
        os.replace(tmp_path, self.index_path)

    def _load(self):
        """Load the pickled store from ``index_path``."""
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only point index_path at files this application wrote itself.
        with open(self.index_path, "rb") as f:
            data = pickle.load(f)
        self._store = data["store"]
        self._next_id = data["next_id"]
        # Only the raw entries are persisted; the index is rebuilt on demand.
        self._index_dirty = True

    def __del__(self):
        """Persist the store on garbage collection when a path is set."""
        # getattr: __init__ may have failed before index_path was assigned.
        if getattr(self, "index_path", ""):
            self._save()

144 

145 

class ChromaVectorStore(BaseVectorStore):
    """Vector store backed by a persistent Chroma collection.

    When ``chromadb`` cannot be imported, entries are kept in an in-memory
    dict instead. That fallback performs no similarity ranking: ``search``
    simply returns the first ``top_k`` stored entries with a score of 0.5.
    """

    def __init__(self, collection_name: str = "agentos", persist_dir: str = "./chroma_data"):
        """Open (or create) ``collection_name`` under ``persist_dir``."""
        self.collection_name = collection_name
        self.persist_dir = persist_dir
        self._client = None
        self._collection = None
        self._fb: dict = {}
        # Number of entries ever written to the fallback store; used to
        # build auto-ids that stay unique after deletions.
        self._fb_seq = 0
        self._init()

    def _init(self):
        """Connect to Chroma; leave ``_collection`` as None if unavailable."""
        try:
            import chromadb
            self._client = chromadb.PersistentClient(path=self.persist_dir)
            self._collection = self._client.get_or_create_collection(self.collection_name)
        except ImportError:
            self._collection = None

    def add(self, texts: list[str], metadatas: list[dict] | None = None, ids: list[str] | None = None) -> list[str]:
        """Store ``texts`` and return their ids.

        Args:
            texts: Documents to store.
            metadatas: Optional per-document metadata dicts.
            ids: Optional per-document ids; generated when omitted.

        Raises:
            ValueError: If ``metadatas`` or ``ids`` is given but has fewer
                items than ``texts``.
        """
        if not texts:
            return []
        for label, extra in (("metadatas", metadatas), ("ids", ids)):
            if extra and len(extra) < len(texts):
                raise ValueError(
                    f"{label} has {len(extra)} items but {len(texts)} texts were given"
                )

        if self._collection is None:
            store = self._fallback_store
            seq = getattr(self, "_fb_seq", len(store))
            if not ids:
                # A monotonically increasing prefix cannot repeat, unlike the
                # current store size, which shrinks after a delete.
                ids = [f"v{seq}-{i}" for i in range(len(texts))]
            for i, text in enumerate(texts):
                store[ids[i]] = {"text": text, "metadata": metadatas[i] if metadatas else {}}
            self._fb_seq = seq + len(texts)
            return ids

        if not ids:
            # Full-length UUIDs: truncated ones risk collisions.
            ids = [uuid.uuid4().hex for _ in texts]
        # Pass None rather than a list of empty dicts when no metadata was
        # supplied; Chroma validates metadata dicts as non-empty.
        self._collection.add(
            documents=texts,
            metadatas=metadatas if metadatas else None,
            ids=ids,
        )
        return ids

    def search(self, query: str, top_k: int = 5) -> list[VectorEntry]:
        """Return up to ``top_k`` entries for ``query``.

        With Chroma, each score is ``1.0 - distance``. In fallback mode the
        query is ignored and every entry gets a score of 0.5.
        """
        if top_k <= 0:
            return []
        if self._collection is None:
            from itertools import islice
            return [
                VectorEntry(id=k, text=v["text"], metadata=v["metadata"], score=0.5)
                for k, v in islice(self._fallback_store.items(), top_k)
            ]

        available = self._collection.count()
        if available == 0:
            return []
        # Never ask for more results than the collection holds.
        results = self._collection.query(query_texts=[query], n_results=min(top_k, available))
        hit_ids = (results.get("ids") or [[]])[0]

        def column(key):
            """Return the first-query column for ``key``, or Nones if absent."""
            values = results.get(key)
            return values[0] if values else [None] * len(hit_ids)

        entries = []
        for rid, doc, meta, dist in zip(
            hit_ids, column("documents"), column("metadatas"), column("distances")
        ):
            entries.append(
                VectorEntry(
                    id=rid,
                    # Missing documents/metadata come back as None.
                    text=doc or "",
                    metadata=meta or {},
                    score=0.0 if dist is None else 1.0 - dist,
                )
            )
        return entries

    def delete(self, ids: list[str]):
        """Remove the entries with the given ids."""
        if not ids:
            # Nothing to delete; avoids sending an empty id list to Chroma.
            return
        if self._collection is not None:
            self._collection.delete(ids=ids)
        else:
            for rid in ids:
                self._fallback_store.pop(rid, None)

    def count(self) -> int:
        """Return the number of stored entries."""
        if self._collection is not None:
            return self._collection.count()
        return len(self._fallback_store)

    @property
    def _fallback_store(self) -> dict:
        """In-memory id -> {"text", "metadata"} dict used without Chroma."""
        # Created lazily so instances built without __init__ still work.
        if not hasattr(self, "_fb"):
            self._fb = {}
        return self._fb