Coverage for agentos/rag/pipeline.py: 27%

56 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2RAG Pipeline — 检索增强生成管道。 

3 

4将文档加载 → 向量化 → 检索 → LLM 生成 串联为一条端到端管道。 

5""" 

6 

7from __future__ import annotations 

8 

9from agentos.rag.store import VectorStore, ChromaStore 

10from agentos.rag.loader import DocumentLoader, Document 

11from agentos.agent.tool_agent import ToolAgent, AgentConfig 

12 

13 

14class RAGPipeline: 

15 """检索增强生成管道。 

16 

17 使用流程:: 

18 

19 rag = RAGPipeline(agent) 

20 rag.ingest_file("docs/report.pdf") # 加载文档到向量库 

21 rag.ingest_directory("docs/project/") # 批量加载 

22 answer = rag.query("Q3 收入是多少?") # 检索 + 生成 

23 

24 Args: 

25 agent: 用于生成答案的 ToolAgent 

26 vector_store: 向量存储(默认 ChromaDB 内存模式) 

27 top_k: 检索返回片段数 

28 """ 

29 

30 def __init__( 

31 self, 

32 agent: ToolAgent, 

33 vector_store: VectorStore | None = None, 

34 top_k: int = 5, 

35 ): 

36 self._agent = agent 

37 self._store = vector_store or ChromaStore() 

38 self._top_k = top_k 

39 self._loader = DocumentLoader() 

40 

41 @property 

42 def store(self) -> VectorStore: 

43 return self._store 

44 

45 @property 

46 def doc_count(self) -> int: 

47 return self._store.count() 

48 

49 def ingest_file(self, path: str) -> int: 

50 """加载单个文件到向量库。返回添加的块数。""" 

51 docs = self._loader.load_file(path) 

52 if not docs: 

53 return 0 

54 texts = [d.content for d in docs] 

55 metadatas = [{"source": d.source, "page": d.page} for d in docs] 

56 self._store.add(texts, metadatas) 

57 return len(texts) 

58 

59 def ingest_directory(self, dir_path: str, recursive: bool = True) -> int: 

60 """加载目录到向量库。返回添加的块数。""" 

61 docs = self._loader.load_directory(dir_path, recursive=recursive) 

62 if not docs: 

63 return 0 

64 texts = [d.content for d in docs] 

65 metadatas = [{"source": d.source} for d in docs] 

66 self._store.add(texts, metadatas) 

67 return len(texts) 

68 

69 def ingest_texts(self, texts: list[str], metadatas: list[dict] | None = None) -> int: 

70 """直接添加文本列表到向量库。""" 

71 self._store.add(texts, metadatas) 

72 return len(texts) 

73 

74 def query(self, question: str, top_k: int | None = None) -> str: 

75 """检索 + 生成答案。 

76 

77 Args: 

78 question: 用户问题 

79 top_k: 检索片段数(覆盖默认值) 

80 

81 Returns: 

82 LLM 生成的答案 

83 """ 

84 k = top_k or self._top_k 

85 

86 if self._store.count() == 0: 

87 return self._agent.run(question).final_answer 

88 

89 # 检索 

90 results = self._store.search(question, top_k=k) 

91 if not results: 

92 return self._agent.run(question).final_answer 

93 

94 # 构建上下文 

95 context_parts = [] 

96 sources = set() 

97 for i, r in enumerate(results, 1): 

98 context_parts.append(f"[{i}] {r.content}") 

99 if r.metadata.get("source"): 

100 sources.add(r.metadata["source"]) 

101 

102 context = "\n\n".join(context_parts) 

103 source_list = "\n".join(f"- {s}" for s in sorted(sources)) 

104 

105 # 构建 RAG prompt 

106 rag_task = f"""请基于以下参考资料回答用户的问题。 

107如果参考资料不足以回答问题,请明确说明,不要编造信息。 

108 

109## 参考资料 

110{context} 

111 

112## 来源文件 

113{source_list} 

114 

115## 用户问题 

116{question} 

117 

118请用中文回答。回答时引用具体的资料编号(如 [1])。""" 

119 

120 return self._agent.run(rag_task).final_answer 

121 

122 def retrieve(self, question: str, top_k: int | None = None) -> list[str]: 

123 """仅检索,不生成。返回匹配的文本片段。""" 

124 k = top_k or self._top_k 

125 results = self._store.search(question, top_k=k) 

126 return [r.content for r in results]