Coverage for agentos/vectorstore/db.py: 23%
155 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2AgentOS v0.30 向量数据库集成 — Chroma + FAISS。
3语义记忆检索、知识库索引。
4"""
6from dataclasses import dataclass, field
7from typing import Optional
8import os
9import pickle
10import json
11import uuid
@dataclass
class VectorEntry:
    """One stored text plus the details a search attaches to it.

    Attributes:
        id: Identifier of the entry inside its vector store.
        text: The stored document text.
        metadata: Caller-supplied key/value data; every instance gets its
            own fresh empty dict when none is given.
        score: Score attached to the entry by a search; 0.0 by default.
    """

    id: str
    text: str
    metadata: dict = field(default_factory=dict)
    score: float = field(default=0.0)
23class BaseVectorStore:
24 """向量存储基类。"""
25 def add(self, texts: list[str], metadatas: list[dict] | None = None, ids: list[str] | None = None) -> list[str]: ...
26 def search(self, query: str, top_k: int = 5) -> list[VectorEntry]: ...
27 def delete(self, ids: list[str]): ...
28 def count(self) -> int: ...
class FAISSVectorStore(BaseVectorStore):
    """Lightweight vector store backed by FAISS, with a pure-Python fallback.

    ``_store`` (id -> (embedding, text, metadata)) is the source of truth.
    The FAISS index is only a derived accelerator: ``_row_ids`` records which
    id sits at each index row, and the index is rebuilt from ``_store``
    whenever its rows can no longer be trusted (after a delete, an overwrite
    of an existing id, or a load from disk). When ``faiss`` is not installed,
    searches fall back to a brute-force cosine scan over ``_store``.

    Args:
        dim: Width of the hash-based placeholder embeddings, and the default
            width of a FAISS index created without an explicit dimension.
        index_path: Optional pickle file used for persistence. If it exists
            it is loaded on construction; if set, the store is written back
            when the object is garbage collected.
    """

    def __init__(self, dim: int = 768, index_path: str = ""):
        self.dim = dim
        self.index_path = index_path
        self._index = None
        self._store: dict[str, tuple[list[float], str, dict]] = {}
        self._next_id = 0
        # FAISS returns row positions, not ids; this maps row -> entry id.
        self._row_ids: list[str] = []
        # True when the FAISS index no longer mirrors ``_store``.
        self._index_stale = False
        # Remember a failed ``import faiss`` so it is not retried on every call.
        self._faiss_missing = False
        # None = not resolved yet, False = sentence_transformers unavailable,
        # otherwise the cached SentenceTransformer model.
        self._encoder = None
        # Only set once construction succeeded; ``__del__`` checks it so a
        # failed load never overwrites the persisted file with an empty store.
        self._initialized = False
        if index_path and os.path.exists(index_path):
            self._load()
        self._initialized = True

    def _init_index(self, dim: int | None = None):
        """Create an empty inner-product index, or leave it ``None`` without faiss.

        Args:
            dim: Index width; defaults to ``self.dim``.
        """
        self._row_ids = []
        self._index = None
        if self._faiss_missing:
            return
        try:
            import faiss
            self._index = faiss.IndexFlatIP(self.dim if dim is None else dim)
        except ImportError:
            self._faiss_missing = True

    @staticmethod
    def _unit_rows(vectors):
        """Return ``vectors`` as a float32 matrix with L2-normalised rows.

        Normalising makes the inner product computed by the index equal to
        the cosine similarity used by the fallback search. All-zero rows are
        left as zeros.
        """
        import numpy as np
        matrix = np.array(vectors, dtype=np.float32)
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0.0] = 1.0
        return matrix / norms

    def _rebuild_index(self):
        """Recreate the FAISS index so that its rows match ``_store`` exactly."""
        self._index = None
        self._row_ids = []
        self._index_stale = False
        if not self._store or self._faiss_missing:
            return
        row_ids = list(self._store)
        vectors = [self._store[rid][0] for rid in row_ids]
        width = len(vectors[0])
        if width == 0 or any(len(vec) != width for vec in vectors):
            # Vectors of mixed width cannot share one index; searches will
            # use the fallback scan instead.
            return
        # Size the index from the stored vectors, not ``self.dim``: the
        # sentence-transformers model emits a different width than the default.
        self._init_index(width)
        if self._index is None:
            return
        self._index.add(self._unit_rows(vectors))
        self._row_ids = row_ids

    def add(self, texts: list[str], metadatas: list[dict] | None = None, ids: list[str] | None = None) -> list[str]:
        """Embed and store ``texts``.

        Args:
            texts: Documents to store.
            metadatas: Optional per-text metadata, aligned with ``texts``.
            ids: Optional per-text ids, aligned with ``texts``. An existing id
                is overwritten. Missing ids are generated as ``v<counter>``.

        Returns:
            The ids of the stored entries, in input order.

        Raises:
            ValueError: If ``metadatas`` or ``ids`` is given but has fewer
                entries than ``texts``. Raised before anything is stored.
        """
        if metadatas and len(metadatas) < len(texts):
            raise ValueError("metadatas must provide one entry per text")
        if ids and len(ids) < len(texts):
            raise ValueError("ids must provide one entry per text")
        if not texts:
            return []

        embeddings = self._embed(texts)
        res_ids = []
        overwrote = False
        for i, text in enumerate(texts):
            rid = ids[i] if ids else f"v{self._next_id}"
            self._next_id += 1
            if rid in self._store:
                overwrote = True
            self._store[rid] = (embeddings[i], text, metadatas[i] if metadatas else {})
            res_ids.append(rid)

        can_append = (
            self._index is not None
            and not self._index_stale
            and not overwrote
            and all(len(vec) == self._index.d for vec in embeddings)
        )
        if can_append:
            # Fast path: the index mirrors the store, so just append rows.
            self._index.add(self._unit_rows(embeddings))
            self._row_ids.extend(res_ids)
        else:
            # Rebuilt lazily by the next search.
            self._index_stale = True
        return res_ids

    def _fallback_search(self, q_vec, top_k):
        """Brute-force cosine-similarity search, used when faiss is unavailable."""
        import math
        if top_k <= 0:
            return []
        q_norm = math.sqrt(sum(a * a for a in q_vec))
        scored = []
        for rid, (vec, text, meta) in self._store.items():
            denom = q_norm * math.sqrt(sum(b * b for b in vec))
            dot = sum(a * b for a, b in zip(q_vec, vec))
            sim = dot / denom if denom > 0 else 0.0
            scored.append((sim, rid, text, meta))
        scored.sort(key=lambda item: item[0], reverse=True)
        return [
            VectorEntry(id=rid, text=text, metadata=meta, score=float(sim))
            for sim, rid, text, meta in scored[:top_k]
        ]

    def search(self, query: str, top_k: int = 5) -> list[VectorEntry]:
        """Return up to ``top_k`` stored entries most similar to ``query``.

        Scores are cosine similarities on both the FAISS and fallback paths.
        """
        if not self._store or top_k <= 0:
            return []
        q_vec = self._embed([query])[0]
        if self._index_stale:
            self._rebuild_index()
        if self._index is None or not self._row_ids or len(q_vec) != self._index.d:
            return self._fallback_search(q_vec, top_k)

        limit = min(top_k, len(self._row_ids))
        scores, rows = self._index.search(self._unit_rows([q_vec]), limit)
        results = []
        for score, row in zip(scores[0], rows[0]):
            if row < 0:  # FAISS pads missing neighbours with -1
                continue
            rid = self._row_ids[int(row)]
            stored = self._store.get(rid)
            if stored is None:
                continue
            _, text, meta = stored
            results.append(VectorEntry(id=rid, text=text, metadata=meta, score=float(score)))
        return results

    def delete(self, ids: list[str]):
        """Remove the entries with the given ids; unknown ids are ignored."""
        removed = False
        for rid in ids:
            if self._store.pop(rid, None) is not None:
                removed = True
        if removed:
            # IndexFlatIP rows cannot be trusted after a removal; rebuild lazily.
            self._index_stale = True

    def count(self) -> int:
        """Return the number of stored entries."""
        return len(self._store)

    def _embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` with all-MiniLM-L6-v2, or the hash placeholder.

        The model is loaded once and cached on the instance. If
        ``sentence_transformers`` cannot be imported, the placeholder
        embedding is used for the lifetime of the instance.
        """
        if self._encoder is None:
            try:
                from sentence_transformers import SentenceTransformer
                self._encoder = SentenceTransformer("all-MiniLM-L6-v2")
            except ImportError:
                self._encoder = False
        if self._encoder is False:
            return self._tfidf_embed(texts)
        return self._encoder.encode(texts, normalize_embeddings=True).tolist()

    def _tfidf_embed(self, texts: list[str]) -> list[list[float]]:
        """Placeholder embedding derived from a SHA-256 digest.

        Despite the name this is not TF-IDF: each text becomes the first
        ``dim`` digest bytes scaled to [0, 1], zero-padded to ``dim``. It
        only makes identical texts match and carries no semantic meaning.
        """
        import hashlib
        dim = self.dim
        result = []
        for text in texts:
            digest = hashlib.sha256(text.encode()).digest()
            vec = [byte / 255.0 for byte in digest[:dim]]
            vec += [0.0] * (dim - len(vec))
            result.append(vec)
        return result

    def _save(self):
        """Persist the store to ``index_path`` (no-op when no path is set)."""
        if not self.index_path:
            return
        os.makedirs(os.path.dirname(self.index_path) or ".", exist_ok=True)
        # Write to a sibling file and swap it in, so an interrupted write
        # cannot leave a truncated pickle behind.
        tmp_path = f"{self.index_path}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump({"store": self._store, "next_id": self._next_id}, f)
        os.replace(tmp_path, self.index_path)

    def _load(self):
        """Load the store from ``index_path``.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only point
        ``index_path`` at files this process (or a trusted party) wrote.
        """
        with open(self.index_path, "rb") as f:
            data = pickle.load(f)
        self._store = data["store"]
        self._next_id = data["next_id"]
        # The FAISS index is not persisted; rebuild it on the next search.
        self._index = None
        self._row_ids = []
        self._index_stale = bool(self._store)

    def __del__(self):
        # Skip half-constructed instances (e.g. ``_load`` raised in
        # ``__init__``) so a failed load cannot clobber the file on disk.
        if getattr(self, "_initialized", False) and self.index_path:
            self._save()
146class ChromaVectorStore(BaseVectorStore):
147 """Chroma 向量存储。"""
149 def __init__(self, collection_name: str = "agentos", persist_dir: str = "./chroma_data"):
150 self.collection_name = collection_name
151 self.persist_dir = persist_dir
152 self._client = None
153 self._collection = None
154 self._init()
156 def _init(self):
157 try:
158 import chromadb
159 self._client = chromadb.PersistentClient(path=self.persist_dir)
160 self._collection = self._client.get_or_create_collection(self.collection_name)
161 except ImportError:
162 self._collection = None
164 def add(self, texts: list[str], metadatas: list[dict] | None = None, ids: list[str] | None = None) -> list[str]:
165 if not self._collection:
166 ids = ids or [f"v{len(self._fallback_store)}-{i}" for i in range(len(texts))]
167 for i, t in enumerate(texts):
168 self._fallback_store[ids[i]] = {"text": t, "metadata": metadatas[i] if metadatas else {}}
169 return ids
171 ids = ids or [str(uuid.uuid4())[:8] for _ in texts]
172 self._collection.add(documents=texts, metadatas=metadatas or [{}] * len(texts), ids=ids)
173 return ids
175 def search(self, query: str, top_k: int = 5) -> list[VectorEntry]:
176 if not self._collection:
177 if self._fallback_store:
178 return [
179 VectorEntry(id=k, text=v["text"], metadata=v["metadata"], score=0.5)
180 for k, v in list(self._fallback_store.items())[:top_k]
181 ]
182 return []
183 results = self._collection.query(query_texts=[query], n_results=top_k)
184 entries = []
185 for i, rid in enumerate(results.get("ids", [[]])[0]):
186 entries.append(
187 VectorEntry(
188 id=rid,
189 text=results["documents"][0][i] if results.get("documents") else "",
190 metadata=results["metadatas"][0][i] if results.get("metadatas") else {},
191 score=1.0 - results["distances"][0][i] if results.get("distances") else 0.0,
192 )
193 )
194 return entries
196 def delete(self, ids: list[str]):
197 if self._collection:
198 self._collection.delete(ids=ids)
199 else:
200 for rid in ids:
201 self._fallback_store.pop(rid, None)
203 def count(self) -> int:
204 if self._collection:
205 return self._collection.count()
206 return len(self._fallback_store)
208 @property
209 def _fallback_store(self) -> dict:
210 if not hasattr(self, "_fb"):
211 self._fb = {}
212 return self._fb