Coverage for src / documint_mcp / rag.py: 0%
79 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 22:30 -0400
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 22:30 -0400
"""RAG-based patch style learning using LanceDB Cloud.

Uses approved patches as few-shot examples to improve patch quality
and teach the AI the team's documentation style over time.

LanceDB Cloud provides full-text search (FTS) as the primary retrieval
method. When Voyage AI is available (optional), results are reranked
for improved relevance.

Integrate by calling, on the shared instance returned by get_rag():
    store_approved_patch(...) — in repository.py's patch-approval method
    get_few_shot_examples(...) — in ai.py before the sonnet step-3 prompt
        (render the result with format_few_shot_prompt(examples))
"""
14from __future__ import annotations
16import hashlib
17import logging
18import os
19import time
logger = logging.getLogger(__name__)

# Cold-start grace period: few-shot retrieval stays off until at least this
# many approved patches (not counting the schema seed row) have been stored.
MIN_EXAMPLES: int = 5

# LanceDB table that holds the approved patches.
_TABLE_NAME: str = "documint_approved_patches"
29class PatchRAG:
30 """Manages a LanceDB Cloud table of approved patches for style learning."""
32 def __init__(
33 self,
34 uri: str | None = None,
35 api_key: str | None = None,
36 region: str | None = None,
37 ) -> None:
38 self.available = False
39 self._db = None
40 self._table = None
41 try:
42 import lancedb # noqa: PLC0415
44 _uri = uri or os.getenv("LANCEDB_URI", "")
45 _key = api_key or os.getenv("LANCEDB_API_KEY", "")
46 _region = region or os.getenv("LANCEDB_REGION", "us-east-1")
48 if not _uri or not _key:
49 logger.debug("LanceDB URI or API key not set — RAG disabled")
50 return
52 self._db = lancedb.connect(_uri, api_key=_key, region=_region)
54 # Get or create the table
55 existing = self._db.table_names()
56 if _TABLE_NAME in existing:
57 self._table = self._db.open_table(_TABLE_NAME)
58 else:
59 # Seed row required to create schema — will be filtered from results
60 import pyarrow as pa # noqa: PLC0415
62 schema = pa.schema([
63 pa.field("id", pa.string()),
64 pa.field("document", pa.string()),
65 pa.field("artifact_id", pa.string()),
66 pa.field("project_id", pa.string()),
67 pa.field("patch_id", pa.string()),
68 pa.field("created_at", pa.float64()),
69 ])
70 self._table = self._db.create_table(
71 _TABLE_NAME,
72 data=[{
73 "id": "seed",
74 "document": "seed document for schema initialization",
75 "artifact_id": "",
76 "project_id": "",
77 "patch_id": "",
78 "created_at": 0.0,
79 }],
80 schema=schema,
81 )
82 # Build FTS index on the document field
83 self._table.create_fts_index("document")
85 self.available = True
86 logger.info("LanceDB RAG connected: %s", _uri)
87 except Exception as exc: # noqa: BLE001
88 logger.debug("LanceDB unavailable for RAG style learning: %s", exc)
90 def store_approved_patch(
91 self,
92 patch_id: str,
93 artifact_id: str,
94 project_id: str,
95 section_titles: list[str],
96 patch_content: str,
97 ) -> None:
98 """Store an approved patch for future few-shot retrieval."""
99 if not self.available or not self._table:
100 return
101 try:
102 doc = f"{artifact_id}::{', '.join(section_titles)}::{patch_content[:500]}"
103 doc_id = hashlib.sha256(doc.encode()).hexdigest()[:16]
104 self._table.add([{
105 "id": doc_id,
106 "document": doc,
107 "artifact_id": artifact_id,
108 "project_id": project_id,
109 "patch_id": patch_id,
110 "created_at": time.time(),
111 }])
112 except Exception as exc: # noqa: BLE001
113 logger.debug("Failed to store patch in LanceDB: %s", exc)
115 def get_few_shot_examples(
116 self,
117 artifact_id: str,
118 stale_sections: list[str],
119 n: int = 3,
120 ) -> list[str]:
121 """Return up to n similar approved patches as few-shot examples.
123 Uses LanceDB full-text search (FTS) as the primary retrieval method.
124 When Voyage AI is available, the FTS results are reranked for better
125 relevance. Falls back to raw FTS ordering if reranking is unavailable.
127 Returns empty list if LanceDB unavailable or too few examples exist.
128 """
129 if not self.available or not self._table:
130 return []
131 try:
132 count = self._table.count_rows()
133 # Subtract 1 for the seed row
134 real_count = max(0, count - 1)
135 if real_count < MIN_EXAMPLES:
136 return []
138 query = f"{artifact_id} {' '.join(stale_sections)}"
140 # Fetch a larger candidate pool so reranking has room to improve
141 candidate_limit = min(max(n * 3, 10), real_count)
142 results = (
143 self._table
144 .search(query, query_type="fts")
145 .where("id != 'seed'", prefilter=True)
146 .limit(candidate_limit)
147 .to_list()
148 )
149 documents = [r["document"] for r in results if r.get("document")]
150 if not documents:
151 return []
153 # ── Voyage reranking (optional enhancement) ──
154 from . import embeddings # noqa: PLC0415
156 reranked = embeddings.rerank(query, documents, top_k=n)
157 if reranked is not None:
158 return reranked
160 # Fall back to raw FTS order, trimmed to n
161 return documents[:n]
162 except Exception as exc: # noqa: BLE001
163 logger.debug("Failed to query LanceDB for few-shot examples: %s", exc)
164 return []
def format_few_shot_prompt(examples: list[str]) -> str:
    """Render retrieved examples as a prompt block for the patch step.

    Each example is truncated to its first 800 characters and wrapped in a
    numbered ``---EXAMPLE i---`` header. The block ends with a ``---`` rule.
    An empty input produces an empty string, so nothing is injected into
    the prompt.
    """
    if not examples:
        return ""
    intro = "\n".join((
        "Here are examples of previously-approved documentation patches for similar changes.",
        "Match this style and level of detail:\n",
    ))
    rendered = "".join(
        f"\n---EXAMPLE {number}---\n{example[:800]}\n"
        for number, example in enumerate(examples, 1)
    )
    return f"{intro}{rendered}\n---"
# Lazily-created, process-wide PatchRAG shared by get_rag().
_rag_instance: PatchRAG | None = None


def get_rag() -> PatchRAG:
    """Return the shared PatchRAG, constructing it on first use.

    NOTE(review): there is no lock around first construction, so concurrent
    first calls could each build an instance — confirm callers don't race.
    """
    global _rag_instance  # noqa: PLW0603
    if _rag_instance is not None:
        return _rag_instance
    _rag_instance = PatchRAG()
    return _rag_instance