Coverage for agentos/rag/store.py: 31%

71 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

"""
Vector store abstraction layer — a ChromaDB wrapper.

Supports creating/loading collections, adding documents, and semantic search.
"""

6 

from __future__ import annotations

import os
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

14 

15 

# Default on-disk location for persisted Chroma data: ~/.agentos/chroma.
# NOTE(review): nothing in this module's visible code reads this constant;
# presumably callers pass it as ``persist_dir`` — confirm before removing.
DEFAULT_PERSIST_DIR = Path.home() / ".agentos" / "chroma"

17 

18 

@dataclass
class SearchResult:
    """A single hit returned by a vector-store search."""

    # Text of the matched document.
    content: str
    # Relevance score assigned by the store that produced this result.
    score: float
    # Metadata stored alongside the document; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Origin label for the document; empty string when not provided.
    source: str = ""

26 

27 

class VectorStore(ABC):
    """Abstract base class for vector stores.

    Concrete stores must implement all four operations below.
    """

    @abstractmethod
    def add(
        self,
        texts: list[str],
        metadatas: list[dict] | None = None,
        ids: list[str] | None = None,
    ) -> None:
        """Add documents to the store.

        Args:
            texts: Document texts to add.
            metadatas: Optional per-document metadata.
            ids: Optional per-document identifiers.
        """
        ...

    @abstractmethod
    def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
        """Search the store for documents matching ``query``.

        Args:
            query: Query text.
            top_k: Number of results requested.

        Returns:
            The matching documents as ``SearchResult`` objects.
        """
        ...

    @abstractmethod
    def count(self) -> int:
        """Return the number of documents in the store."""
        ...

    @abstractmethod
    def clear(self) -> None:
        """Remove the store's contents."""
        ...

46 

47 

class ChromaStore(VectorStore):
    """ChromaDB-backed vector store.

    The Chroma client and collection are created lazily on first use, so
    constructing a store does not import ``chromadb``.

    Args:
        collection_name: Name of the collection to create or load.
        persist_dir: Directory for on-disk persistence; ``None`` (or any
            falsy value) selects the in-memory client.
        embedding_model: Model name passed to the sentence-transformers
            embedding function (unused when the fallback embedding
            function is selected).
    """

    def __init__(
        self,
        collection_name: str = "default",
        persist_dir: str | os.PathLike[str] | None = None,
        embedding_model: str = "all-MiniLM-L6-v2",
    ):
        self._collection_name = collection_name
        self._persist_dir = persist_dir
        self._embedding_model = embedding_model
        self._client = None
        self._ef = None
        self._collection = None
        self._initialized = False

    def _get_embedding_function(self):
        """Return the embedding function to attach to the collection.

        Prefers the sentence-transformers embedding function; falls back to
        Chroma's default embedding function when that path raises
        ``ImportError``.
        """
        from chromadb.utils import embedding_functions

        try:
            import sentence_transformers  # noqa: F401

            return embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=self._embedding_model,
            )
        except ImportError:
            return embedding_functions.DefaultEmbeddingFunction()

    def _ensure_init(self):
        """Create the client and collection on first use.

        Raises:
            ImportError: If ``chromadb`` cannot be imported.
        """
        if self._initialized:
            return

        # Only the import itself is guarded, so an unrelated ImportError
        # raised later is not mislabelled as "chromadb is not installed".
        try:
            import chromadb
        except ImportError as exc:
            raise ImportError(
                "chromadb 未安装。运行: pip install chromadb sentence-transformers"
            ) from exc

        if self._persist_dir:
            # Accept both str and Path-like values.
            path = os.fspath(self._persist_dir)
            os.makedirs(path, exist_ok=True)
            client = chromadb.PersistentClient(path=path)
        else:
            client = chromadb.Client()

        embedding_fn = self._get_embedding_function()
        collection = client.get_or_create_collection(
            name=self._collection_name,
            embedding_function=embedding_fn,
        )

        # Publish state only after every step succeeded.
        self._client = client
        self._ef = embedding_fn
        self._collection = collection
        self._initialized = True

    def add(
        self,
        texts: list[str],
        metadatas: list[dict] | None = None,
        ids: list[str] | None = None,
    ) -> None:
        """Add documents to the collection.

        Args:
            texts: Document texts; an empty list is a no-op.
            metadatas: Optional per-document metadata.
            ids: Optional per-document ids. When omitted, random UUID hex
                strings are generated.
        """
        if not texts:
            return
        self._ensure_init()
        if ids is None:
            # Count-based ids can collide with caller-supplied ids such as
            # "0" or "1"; UUIDs cannot.
            ids = [uuid.uuid4().hex for _ in texts]
        self._collection.add(documents=texts, metadatas=metadatas or None, ids=ids)

    def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
        """Return up to ``top_k`` documents for ``query``.

        Each score is ``1 / (1 + distance)``, so a smaller distance yields a
        higher score. A score of ``0.0`` is used when the query result
        carries no distances.

        Args:
            query: Query text.
            top_k: Maximum number of results; non-positive values yield ``[]``.

        Returns:
            Results in the order returned by the collection query.
        """
        if top_k <= 0:
            return []
        self._ensure_init()

        available = self._collection.count()
        if available == 0:
            return []

        # NOTE(review): some ChromaDB releases are understood to reject or
        # warn about n_results larger than the collection — confirm against
        # the pinned version. Clamping is harmless either way.
        results = self._collection.query(
            query_texts=[query],
            n_results=min(top_k, available),
        )

        def first_row(key: str):
            # Results are nested one list per query; we sent one query.
            rows = results.get(key)
            return rows[0] if rows and rows[0] else None

        documents = first_row("documents")
        if not documents:
            return []
        distances = first_row("distances")
        metadatas = first_row("metadatas")

        hits: list[SearchResult] = []
        for i, doc in enumerate(documents):
            score = 1.0 / (1.0 + float(distances[i])) if distances else 0.0
            meta = (metadatas[i] or {}) if metadatas else {}
            hits.append(SearchResult(content=doc or "", score=score, metadata=meta))
        return hits

    def count(self) -> int:
        """Return the number of items in the collection."""
        self._ensure_init()
        return self._collection.count()

    def clear(self) -> None:
        """Delete the collection; it is recreated empty on next use."""
        self._ensure_init()
        self._client.delete_collection(self._collection_name)
        # Drop the handle to the deleted collection.
        self._collection = None
        self._initialized = False