Coverage for src / documint_mcp / rag.py: 0%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 22:30 -0400

1"""RAG-based patch style learning using LanceDB Cloud. 

2 

3Uses approved patches as few-shot examples to improve patch quality 

4and teach the AI the team's documentation style over time. 

5 

6LanceDB Cloud provides full-text search (FTS) as the primary retrieval 

7method. When Voyage AI is available (optional), results are reranked 

8for improved relevance. 

9 

10Integrate by calling: 

11 store_approved_patch(patch) — in repository.py's patch-approval method 

12 get_few_shot_examples(...) — in ai.py before the sonnet step-3 prompt 

13""" 

14from __future__ import annotations 

15 

16import hashlib 

17import logging 

18import os 

19import time 

20 

# Cold-start grace period: few-shot retrieval stays disabled until at least
# this many approved patches (not counting the schema seed row) are stored.
MIN_EXAMPLES: int = 5

# Name of the LanceDB table that holds approved patches.
_TABLE_NAME: str = "documint_approved_patches"

logger = logging.getLogger(__name__)

27 

28 

class PatchRAG:
    """Manage a LanceDB Cloud table of approved patches for style learning.

    Connection settings come from the constructor arguments or, when those are
    falsy, from the ``LANCEDB_URI``, ``LANCEDB_API_KEY`` and ``LANCEDB_REGION``
    environment variables (region falls back to ``"us-east-1"``). Any failure
    during setup (missing settings, ``lancedb`` not installed, connection or
    table errors) is logged at DEBUG level and leaves ``available`` False, so
    the public methods become no-ops instead of raising.
    """

    def __init__(
        self,
        uri: str | None = None,
        api_key: str | None = None,
        region: str | None = None,
    ) -> None:
        """Connect to LanceDB and open (or create) the approved-patch table.

        Args:
            uri: LanceDB URI; defaults to ``$LANCEDB_URI``.
            api_key: LanceDB API key; defaults to ``$LANCEDB_API_KEY``.
            region: LanceDB region; defaults to ``$LANCEDB_REGION`` or
                ``"us-east-1"``.
        """
        self.available = False
        self._db = None
        self._table = None
        try:
            import lancedb  # noqa: PLC0415

            _uri = uri or os.getenv("LANCEDB_URI", "")
            _key = api_key or os.getenv("LANCEDB_API_KEY", "")
            _region = region or os.getenv("LANCEDB_REGION", "us-east-1")

            if not _uri or not _key:
                logger.debug("LanceDB URI or API key not set — RAG disabled")
                return

            self._db = lancedb.connect(_uri, api_key=_key, region=_region)
            self._table = self._open_or_create_table()

            self.available = True
            logger.info("LanceDB RAG connected: %s", _uri)
        except Exception as exc:  # noqa: BLE001
            logger.debug("LanceDB unavailable for RAG style learning: %s", exc)

    def _open_or_create_table(self):
        """Open the approved-patch table, creating it if it cannot be opened."""
        # Try to open first rather than testing `name in db.table_names()`:
        # table_names() returns a single page of names (lancedb's default page
        # size is small), so a membership test can miss an existing table and
        # then fail on create_table, silently disabling RAG.
        try:
            return self._db.open_table(_TABLE_NAME)
        # lancedb raises different exception types for a missing table
        # depending on backend, hence the broad catch.
        except Exception as exc:  # noqa: BLE001
            logger.debug("Could not open %s (%s); creating it", _TABLE_NAME, exc)
        return self._create_table()

    def _create_table(self):
        """Create the table (seeded with one schema row) and its FTS index."""
        import pyarrow as pa  # noqa: PLC0415

        schema = pa.schema([
            pa.field("id", pa.string()),
            pa.field("document", pa.string()),
            pa.field("artifact_id", pa.string()),
            pa.field("project_id", pa.string()),
            pa.field("patch_id", pa.string()),
            pa.field("created_at", pa.float64()),
        ])
        # Seed row used to create the schema; queries exclude it via
        # `id != 'seed'` and subtract it from the row count.
        table = self._db.create_table(
            _TABLE_NAME,
            data=[{
                "id": "seed",
                "document": "seed document for schema initialization",
                "artifact_id": "",
                "project_id": "",
                "patch_id": "",
                "created_at": 0.0,
            }],
            schema=schema,
        )
        # Build FTS index on the document field
        table.create_fts_index("document")
        return table

    def store_approved_patch(
        self,
        patch_id: str,
        artifact_id: str,
        project_id: str,
        section_titles: list[str],
        patch_content: str,
    ) -> None:
        """Store an approved patch for future few-shot retrieval.

        The searchable ``document`` is
        ``"<artifact_id>::<section titles joined by ', '>::<first 500 chars>"``
        and its ``id`` is the first 16 hex digits of that text's SHA-256.
        Rows are appended, so storing identical content twice yields duplicate
        rows (``get_few_shot_examples`` returns each document only once).
        Failures are logged at DEBUG level and swallowed.
        """
        # `is None` rather than truthiness: lancedb Table objects may define
        # __len__ (row count), which would turn a truth test into a query.
        if not self.available or self._table is None:
            return
        try:
            doc = f"{artifact_id}::{', '.join(section_titles)}::{patch_content[:500]}"
            doc_id = hashlib.sha256(doc.encode()).hexdigest()[:16]
            self._table.add([{
                "id": doc_id,
                "document": doc,
                "artifact_id": artifact_id,
                "project_id": project_id,
                "patch_id": patch_id,
                "created_at": time.time(),
            }])
        except Exception as exc:  # noqa: BLE001
            logger.debug("Failed to store patch in LanceDB: %s", exc)

    def get_few_shot_examples(
        self,
        artifact_id: str,
        stale_sections: list[str],
        n: int = 3,
    ) -> list[str]:
        """Return up to n similar approved patches as few-shot examples.

        Uses LanceDB full-text search (FTS) over the query
        ``"<artifact_id> <stale sections joined by spaces>"``, excluding the
        seed row. When Voyage AI is available, the FTS candidates are
        reranked for better relevance; if reranking is unavailable or fails,
        the raw FTS ordering is used. Each distinct document appears once.

        Returns an empty list if LanceDB is unavailable, ``n <= 0``, fewer
        than ``MIN_EXAMPLES`` approved patches exist, the query is blank, or
        the search fails.
        """
        if n <= 0 or not self.available or self._table is None:
            return []
        try:
            query = f"{artifact_id} {' '.join(stale_sections)}"
            if not query.strip():
                # No terms to match on.
                return []

            # Subtract 1 for the seed row
            real_count = max(0, self._table.count_rows() - 1)
            if real_count < MIN_EXAMPLES:
                return []

            # Fetch a larger candidate pool so reranking has room to improve
            candidate_limit = min(max(n * 3, 10), real_count)
            results = (
                self._table
                .search(query, query_type="fts")
                .where("id != 'seed'", prefilter=True)
                .limit(candidate_limit)
                .to_list()
            )
            documents = self._unique_documents(results)
        except Exception as exc:  # noqa: BLE001
            logger.debug("Failed to query LanceDB for few-shot examples: %s", exc)
            return []

        if not documents:
            return []
        return self._rerank_or_fallback(query, documents, n)

    @staticmethod
    def _unique_documents(rows: list[dict]) -> list[str]:
        """Return non-empty ``document`` values from *rows*, first occurrence kept, order preserved."""
        return list(dict.fromkeys(row["document"] for row in rows if row.get("document")))

    @staticmethod
    def _rerank_or_fallback(query: str, documents: list[str], n: int) -> list[str]:
        """Return the top *n* documents, Voyage-reranked when possible, else in FTS order."""
        try:
            from . import embeddings  # noqa: PLC0415

            reranked = embeddings.rerank(query, documents, top_k=n)
        except Exception as exc:  # noqa: BLE001
            # Reranking is an optional enhancement: an import or API failure
            # must not throw away the FTS hits we already have.
            logger.debug("Voyage reranking failed, using FTS order: %s", exc)
            reranked = None
        # rerank() signals "not available" by returning None.
        if reranked is not None:
            return list(reranked)[:n]
        # Fall back to raw FTS order, trimmed to n
        return documents[:n]

165 

166 

def format_few_shot_prompt(examples: list[str]) -> str:
    """Render retrieved examples as a prompt block for the patch step.

    Each example is truncated to 800 characters, introduced by a 1-based
    ``---EXAMPLE i---`` marker and followed by a blank line; the block ends
    with a lone ``---`` line. An empty *examples* list yields ``""``.
    """
    if not examples:
        return ""
    parts = [
        "Here are examples of previously-approved documentation patches for similar changes.",
        "Match this style and level of detail:\n",
    ]
    for number, example in enumerate(examples, start=1):
        parts.extend((f"---EXAMPLE {number}---", example[:800], ""))
    parts.append("---")
    return "\n".join(parts)

181 

182 

# Lazily-created module-wide PatchRAG; populated on the first get_rag() call.
_rag_instance: PatchRAG | None = None


def get_rag() -> PatchRAG:
    """Return the shared PatchRAG, constructing it on first use.

    The instance is built with no arguments, so its connection settings come
    from the LANCEDB_* environment variables. Later calls reuse it.
    """
    global _rag_instance  # noqa: PLW0603
    if _rag_instance is not None:
        return _rag_instance
    _rag_instance = PatchRAG()
    return _rag_instance