Coverage for agentos/rag/store.py: 31%
71 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
Vector store abstraction layer — a ChromaDB wrapper.

Supports creating/loading collections, adding documents, and semantic search.
"""
from __future__ import annotations

import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# Conventional on-disk location for persisted Chroma data: ~/.agentos/chroma.
# ChromaStore defaults to ``persist_dir=None`` (in-memory), so callers that
# want persistence have to pass a directory explicitly.
DEFAULT_PERSIST_DIR = Path.home() / ".agentos" / "chroma"
@dataclass
class SearchResult:
    """A single hit returned by a vector-store search."""

    # Text of the matched document.
    content: str
    # Relevance score assigned by the store that produced the hit.
    score: float
    # Metadata stored alongside the document; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Origin of the document; empty string when not provided.
    source: str = ""
class VectorStore(ABC):
    """Abstract interface every vector-store backend must implement."""

    @abstractmethod
    def add(
        self,
        texts: list[str],
        metadatas: list[dict] | None = None,
        ids: list[str] | None = None,
    ) -> None:
        """Store ``texts``, optionally with per-text metadata and explicit ids."""
        ...

    @abstractmethod
    def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
        """Return up to ``top_k`` results for ``query``."""
        ...

    @abstractmethod
    def count(self) -> int:
        """Return the number of stored documents."""
        ...

    @abstractmethod
    def clear(self) -> None:
        """Remove everything held by the store."""
        ...
class ChromaStore(VectorStore):
    """ChromaDB-backed vector store.

    The ChromaDB client and collection are created lazily on first use, so
    constructing a store never imports ``chromadb``.

    Args:
        collection_name: Name of the collection to create or load.
        persist_dir: Directory used for on-disk persistence. ``None`` (or an
            empty value) selects the in-memory client.
        embedding_model: Model name handed to the sentence-transformers
            embedding function; ignored when the fallback embedding
            function is used.
    """

    def __init__(
        self,
        collection_name: str = "default",
        persist_dir: str | os.PathLike[str] | None = None,
        embedding_model: str = "all-MiniLM-L6-v2",
    ):
        self._collection_name = collection_name
        self._persist_dir = persist_dir
        self._embedding_model = embedding_model
        self._client: Any = None
        self._collection: Any = None
        self._ef: Any = None
        self._initialized = False

    def _get_embedding_function(self):
        """Return the embedding function for the collection.

        Prefers the sentence-transformers embedding function; falls back to
        ChromaDB's default embedding function when an ImportError occurs.
        """
        from chromadb.utils import embedding_functions

        try:
            import sentence_transformers  # noqa: F401

            return embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=self._embedding_model,
            )
        except ImportError:
            return embedding_functions.DefaultEmbeddingFunction()

    def _ensure_init(self):
        """Create the client, embedding function and collection if needed.

        Raises:
            ImportError: If ``chromadb`` cannot be imported.
        """
        if self._initialized:
            return

        # Only the chromadb import itself is translated into the install
        # hint; any other ImportError raised during set-up propagates with
        # its own message instead of being misreported.
        try:
            import chromadb
        except ImportError as exc:
            raise ImportError(
                "chromadb 未安装。运行: pip install chromadb sentence-transformers"
            ) from exc

        # Reuse the client and embedding function across clear() cycles
        # rather than rebuilding them each time.
        client = self._client
        if client is None:
            if self._persist_dir:
                persist_path = os.fspath(self._persist_dir)
                os.makedirs(persist_path, exist_ok=True)
                client = chromadb.PersistentClient(path=persist_path)
            else:
                client = chromadb.Client()

        embedding_fn = self._ef
        if embedding_fn is None:
            embedding_fn = self._get_embedding_function()

        collection = client.get_or_create_collection(
            name=self._collection_name,
            embedding_function=embedding_fn,
        )

        # Publish state only after every step succeeded, so a failure above
        # never leaves the store half-initialised.
        self._client = client
        self._ef = embedding_fn
        self._collection = collection
        self._initialized = True

    def add(
        self,
        texts: list[str],
        metadatas: list[dict] | None = None,
        ids: list[str] | None = None,
    ) -> None:
        """Add documents to the collection.

        Args:
            texts: Documents to store. An empty list is a no-op.
            metadatas: Optional per-document metadata; an empty list is
                treated like ``None``.
            ids: Optional document ids. When omitted, ids are the decimal
                strings ``count``, ``count + 1``, ... based on the current
                collection size.
        """
        self._ensure_init()
        if not texts:
            # Nothing to insert; do not hand the backend an empty batch.
            return
        if ids is None:
            # Query the size once instead of once per generated id.
            # NOTE(review): count-based ids can coincide with ids supplied
            # explicitly by earlier callers — confirm whether that matters.
            start = self.count()
            ids = [str(start + offset) for offset in range(len(texts))]
        self._collection.add(documents=texts, metadatas=metadatas or None, ids=ids)

    def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
        """Return up to ``top_k`` results for ``query``.

        Each score is ``1 / (1 + distance)`` using the distance reported by
        the collection, or ``0.0`` when no distances are returned.

        Args:
            query: Query text.
            top_k: Maximum number of results; values ``<= 0`` yield ``[]``.

        Returns:
            Results in the order returned by the collection; empty when the
            collection holds no documents.
        """
        self._ensure_init()
        if top_k <= 0:
            return []
        available = self._collection.count()
        if available == 0:
            return []

        # Never request more results than the collection contains.
        results = self._collection.query(
            query_texts=[query],
            n_results=min(top_k, available),
        )

        doc_batches = results["documents"]
        if not doc_batches or not doc_batches[0]:
            return []

        # A single query was sent, so only the first batch is relevant.
        dist_batches = results.get("distances")
        distances = dist_batches[0] if dist_batches else None
        meta_batches = results.get("metadatas")
        metadatas = meta_batches[0] if meta_batches else None

        hits = []
        for idx, doc in enumerate(doc_batches[0]):
            score = 1.0 / (1.0 + float(distances[idx])) if distances else 0.0
            metadata = (metadatas[idx] or {}) if metadatas else {}
            hits.append(SearchResult(content=doc or "", score=score, metadata=metadata))
        return hits

    def count(self) -> int:
        """Return the number of documents in the collection."""
        self._ensure_init()
        return self._collection.count()

    def clear(self) -> None:
        """Delete the collection; it is recreated empty on next use."""
        self._ensure_init()
        self._client.delete_collection(self._collection_name)
        # Drop the stale handle so nothing can touch the deleted collection.
        self._collection = None
        self._initialized = False