Coverage for agentos/rag/pipeline.py: 27%
56 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2RAG Pipeline — 检索增强生成管道。
4将文档加载 → 向量化 → 检索 → LLM 生成 串联为一条端到端管道。
5"""
7from __future__ import annotations
9from agentos.rag.store import VectorStore, ChromaStore
10from agentos.rag.loader import DocumentLoader, Document
11from agentos.agent.tool_agent import ToolAgent, AgentConfig
class RAGPipeline:
    """Retrieval-augmented generation (RAG) pipeline.

    Chains document loading, vector-store ingestion, retrieval and LLM
    answer generation into a single end-to-end flow.

    Usage::

        rag = RAGPipeline(agent)
        rag.ingest_file("docs/report.pdf")       # load one document into the store
        rag.ingest_directory("docs/project/")    # bulk load a directory
        answer = rag.query("What was Q3 revenue?")  # retrieve + generate

    Args:
        agent: ToolAgent used to generate the answers.
        vector_store: Vector store backend. When ``None``, a
            default-constructed ``ChromaStore`` is used (described by the
            original author as ChromaDB in-memory mode).
        top_k: Default number of snippets retrieved per question.
    """

    def __init__(
        self,
        agent: ToolAgent,
        vector_store: VectorStore | None = None,
        top_k: int = 5,
    ):
        self._agent = agent
        # Compare against None explicitly: a caller-supplied store that
        # happens to be falsy (e.g. an empty store defining __len__) must
        # not be silently replaced by a fresh ChromaStore.
        self._store = vector_store if vector_store is not None else ChromaStore()
        self._top_k = top_k
        self._loader = DocumentLoader()

    @property
    def store(self) -> VectorStore:
        """The vector store backing this pipeline."""
        return self._store

    @property
    def doc_count(self) -> int:
        """Number of entries reported by the vector store's ``count()``."""
        return self._store.count()

    def ingest_file(self, path: str) -> int:
        """Load a single file into the vector store.

        Each loaded chunk is stored with ``source`` and ``page`` metadata.

        Returns:
            The number of chunks added (0 when the loader yields nothing).
        """
        return self._add_documents(self._loader.load_file(path), with_page=True)

    def ingest_directory(self, dir_path: str, recursive: bool = True) -> int:
        """Load a directory into the vector store.

        Each loaded chunk is stored with ``source`` metadata only.

        Returns:
            The number of chunks added (0 when the loader yields nothing).
        """
        docs = self._loader.load_directory(dir_path, recursive=recursive)
        return self._add_documents(docs, with_page=False)

    def ingest_texts(self, texts: list[str], metadatas: list[dict] | None = None) -> int:
        """Add raw texts directly to the vector store.

        Args:
            texts: Text chunks to store.
            metadatas: Optional per-text metadata; when given it must have
                exactly one entry per text.

        Returns:
            The number of texts added (0 for an empty input, in which case
            the store is not touched).

        Raises:
            ValueError: If ``metadatas`` is given and its length differs
                from ``texts``.
        """
        # Same guard as ingest_file / ingest_directory: never hand an empty
        # batch to the store.
        if not texts:
            return 0
        if metadatas is not None and len(metadatas) != len(texts):
            raise ValueError(
                f"metadatas has {len(metadatas)} entries but texts has {len(texts)}"
            )
        self._store.add(texts, metadatas)
        return len(texts)

    def query(self, question: str, top_k: int | None = None) -> str:
        """Retrieve relevant snippets and generate an answer.

        When the store is empty or the search returns nothing, the bare
        question is sent to the agent without any retrieval context.

        Args:
            question: The user's question.
            top_k: Number of snippets to retrieve (overrides the default).

        Returns:
            The ``final_answer`` of the agent run.
        """
        k = self._resolve_top_k(top_k)

        # Skip the search entirely when the store reports no entries.
        results = self._store.search(question, top_k=k) if self._store.count() != 0 else []
        if not results:
            return self._answer(question)

        # Numbered context snippets, so the answer can cite them as [n].
        context = "\n\n".join(
            f"[{i}] {r.content}" for i, r in enumerate(results, 1)
        )

        # Distinct, non-empty sources; tolerate hits whose metadata is None
        # (texts ingested without metadata).
        sources = set()
        for r in results:
            source = (r.metadata or {}).get("source")
            if source:
                sources.add(source)
        source_list = "\n".join(f"- {s}" for s in sorted(sources))

        # RAG prompt (runtime text, intentionally left in Chinese).
        rag_task = (
            "请基于以下参考资料回答用户的问题。\n"
            "如果参考资料不足以回答问题,请明确说明,不要编造信息。\n"
            "\n"
            "## 参考资料\n"
            f"{context}\n"
            "\n"
            "## 来源文件\n"
            f"{source_list}\n"
            "\n"
            "## 用户问题\n"
            f"{question}\n"
            "\n"
            "请用中文回答。回答时引用具体的资料编号(如 [1])。"
        )
        return self._answer(rag_task)

    def retrieve(self, question: str, top_k: int | None = None) -> list[str]:
        """Retrieve only, without generation; return the matching snippet texts."""
        results = self._store.search(question, top_k=self._resolve_top_k(top_k))
        return [r.content for r in results]

    def _resolve_top_k(self, top_k: int | None) -> int:
        """Return the per-call override, or the pipeline default when it is None or 0."""
        return top_k if top_k else self._top_k

    def _answer(self, task: str) -> str:
        """Run the agent on ``task`` and return its final answer."""
        return self._agent.run(task).final_answer

    def _add_documents(self, docs: list[Document], *, with_page: bool) -> int:
        """Store loaded documents with their metadata; return the number added."""
        if not docs:
            return 0
        texts = [d.content for d in docs]
        if with_page:
            metadatas = [{"source": d.source, "page": d.page} for d in docs]
        else:
            metadatas = [{"source": d.source} for d in docs]
        self._store.add(texts, metadatas)
        return len(texts)