Coverage for agentos/rag/citation.py: 0%

118 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

"""Citation tracing for RAG pipeline.

Tracks which source documents contributed to generated text,
enabling answer provenance and fact-checking.
"""

6 

from __future__ import annotations

import hashlib
import html
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

13 

14 

@dataclass
class Citation:
    """A single citation linking generated text to a source chunk.

    Attributes:
        chunk_id: Identifier of the cited chunk.
        chunk_text: Full text of the cited chunk.
        source_doc: Source document identifier.
        relevance_score: Relevance score carried over from the source entry.
        context: Surrounding context window.
        start_char: Start position of the cited span in the generated answer.
        end_char: End position of the cited span in the generated answer.
    """

    chunk_id: str
    chunk_text: str
    source_doc: str = ""
    relevance_score: float = 0.0
    context: str = ""
    start_char: int = 0
    end_char: int = 0

25 

26 

@dataclass
class CitationReport:
    """Complete citation analysis for a generated response.

    Attributes:
        answer: The generated answer text that was analysed.
        citations: Citations matched against the answer.
        source_count: Number of sources that were supplied.
        coverage: Fraction of the answer covered by citations.
        unused_sources: Chunk ids of supplied sources that produced no citation.
    """

    answer: str = ""
    citations: List["Citation"] = field(default_factory=list)
    source_count: int = 0
    coverage: float = 0.0
    unused_sources: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict summary of the report.

        Each citation is reduced to its chunk id, source document, relevance,
        a ``"start-end"`` span string and the first 200 characters of the
        chunk text. ``unused_sources`` is included as a copy so that callers
        mutating the result do not alter the report.
        """
        return {
            "answer": self.answer,
            "num_citations": len(self.citations),
            "source_count": self.source_count,
            "coverage": self.coverage,
            "citations": [
                {
                    "chunk_id": c.chunk_id,
                    "source_doc": c.source_doc,
                    "relevance": c.relevance_score,
                    "span": f"{c.start_char}-{c.end_char}",
                    "text_preview": c.chunk_text[:200],
                }
                for c in self.citations
            ],
            "unused_sources": list(self.unused_sources),
        }

53 

54 

class CitationTracer:
    """Track which retrieved chunks contributed to an answer.

    Two modes:
    - token_overlap: Match answer spans to chunk texts by verbatim substring
      overlap (whole sentences first, then shorter windows of a sentence).
    - explicit: Parse answer for explicit citation markers like [1], [doc1],
      [source 2] or [doc:3].

    Any ``mode`` value other than ``"explicit"`` selects overlap tracing.
    """

    # Citation markers: optional doc/source/ref prefix, optional separator
    # (whitespace, ':' or '#'), then the 1-based source number.
    _MARKER_RE = re.compile(r'\[(?:doc|source|ref)?[\s:#]*(\d+)\]', re.IGNORECASE)

    # Split after sentence-ending punctuation (ASCII and CJK) followed by
    # whitespace.
    _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?。！？])\s+')

    def __init__(
        self,
        mode: str = "token_overlap",
        min_overlap: int = 20,  # minimum characters of overlap
        overlap_ratio: float = 0.3,  # minimum overlap ratio
    ):
        """Store the tracing configuration.

        Args:
            mode: ``"explicit"`` for marker parsing; anything else for
                substring-overlap tracing.
            min_overlap: Upper bound on the minimum match length in
                characters (lowered for chunks shorter than four times this
                value).
            overlap_ratio: Stored on the instance; the matching code in this
                class does not read it.
        """
        self.mode = mode
        self.min_overlap = min_overlap
        self.overlap_ratio = overlap_ratio

    @staticmethod
    def _chunk_id(src: Dict[str, Any], position: int) -> str:
        """Return the chunk id for a source, preferring its 'index' key."""
        return f"chunk_{src.get('index', position)}"

    @staticmethod
    def _source_doc(src: Dict[str, Any]) -> Any:
        """Return the source document id ('source', else 'document', else '')."""
        return src.get("source", src.get("document", ""))

    def trace(
        self,
        answer: str,
        sources: List[Dict[str, Any]],
    ) -> "CitationReport":
        """Trace answer back to source chunks.

        Args:
            answer: Generated text response.
            sources: Retrieved chunks; the keys 'text', 'score', 'index' and
                'source'/'document' are read when present.

        Returns:
            CitationReport with matched citations, the coverage fraction
            (rounded to three decimals) and the ids of sources that produced
            no citation.
        """
        if self.mode == "explicit":
            citations = self._trace_explicit(answer, sources)
        else:
            citations = self._trace_overlap(answer, sources)

        # Coverage is only defined for a non-empty answer with citations.
        coverage = 0.0
        if answer and citations:
            coverage = self._compute_covered_chars(answer, citations) / len(answer)

        used_ids = {c.chunk_id for c in citations}
        all_ids = (self._chunk_id(src, i) for i, src in enumerate(sources))
        unused = [cid for cid in all_ids if cid not in used_ids]

        return CitationReport(
            answer=answer,
            citations=citations,
            source_count=len(sources),
            coverage=round(coverage, 3),
            unused_sources=unused,
        )

    def _trace_overlap(
        self,
        answer: str,
        sources: List[Dict[str, Any]],
    ) -> List["Citation"]:
        """Find answer spans that overlap with source chunks.

        Sources with empty or missing 'text' are skipped. Each matched span
        yields one Citation whose start/end are offsets into ``answer`` and
        whose context is taken from the chunk around the matched text.
        """
        citations = []

        for i, src in enumerate(sources):
            chunk_text = src.get("text", "")
            if not chunk_text:
                continue

            chunk_id = self._chunk_id(src, i)

            for start, end in self._find_substring_matches(answer, chunk_text):
                # The span is expressed in answer offsets; translate it to
                # chunk offsets before cutting the context window.
                ctx_start, ctx_end = self._locate_in_chunk(
                    chunk_text, answer[start:end]
                )
                citations.append(Citation(
                    chunk_id=chunk_id,
                    chunk_text=chunk_text,
                    source_doc=self._source_doc(src),
                    relevance_score=src.get("score", 0.0),
                    context=self._get_context(chunk_text, ctx_start, ctx_end),
                    start_char=start,
                    end_char=end,
                ))

        return citations

    def _trace_explicit(
        self,
        answer: str,
        sources: List[Dict[str, Any]],
    ) -> List["Citation"]:
        """Parse explicit citation markers like [1], [source1], [doc:1].

        Marker numbers are 1-based positions in ``sources``; markers whose
        number is out of range are ignored. The citation span is the marker
        itself and the context is the first 500 characters of the chunk.
        """
        citations = []

        for m in self._MARKER_RE.finditer(answer):
            ref_num = int(m.group(1))
            if not 1 <= ref_num <= len(sources):
                continue

            src = sources[ref_num - 1]
            text = src.get("text", "")
            citations.append(Citation(
                chunk_id=self._chunk_id(src, ref_num - 1),
                chunk_text=text,
                source_doc=self._source_doc(src),
                relevance_score=src.get("score", 0.0),
                context=text[:500],
                start_char=m.start(),
                end_char=m.end(),
            ))

        return citations

    def _find_substring_matches(
        self,
        answer: str,
        chunk: str,
    ) -> List[Tuple[int, int]]:
        """Find spans in answer that match substrings from chunk.

        Each sentence of the chunk is searched for verbatim in the answer
        (first occurrence only). If the whole sentence is absent, half-
        sentence windows are tried and the first hit is used. Overlapping
        spans are merged before returning.

        Returns:
            Sorted list of ``(start, end)`` offsets into ``answer``.
        """
        matches = []
        # Never allow a zero minimum: it would let empty sentences "match"
        # at offset 0 and make the window step below zero.
        min_len = max(1, min(self.min_overlap, len(chunk) // 4))

        for sent in self._SENTENCE_SPLIT_RE.split(chunk):
            sent = sent.strip()
            if len(sent) < min_len:
                continue

            pos = answer.find(sent)
            if pos >= 0:
                matches.append((pos, pos + len(sent)))
                continue

            # Try shorter windows; step must be >= 1 for range().
            window = max(min_len, len(sent) // 2)
            step = max(1, window // 2)
            for start in range(0, len(sent) - window + 1, step):
                sub = sent[start:start + window]
                pos = answer.find(sub)
                if pos >= 0:
                    matches.append((pos, pos + len(sub)))
                    break

        return self._merge_overlapping(matches)

    def _merge_overlapping(
        self,
        spans: List[Tuple[int, int]],
    ) -> List[Tuple[int, int]]:
        """Merge overlapping (or touching) citation spans into sorted spans."""
        if not spans:
            return []

        sorted_spans = sorted(spans)
        merged = [sorted_spans[0]]

        for span in sorted_spans[1:]:
            last = merged[-1]
            if span[0] <= last[1]:
                merged[-1] = (last[0], max(last[1], span[1]))
            else:
                merged.append(span)

        return merged

    @staticmethod
    def _locate_in_chunk(chunk_text: str, snippet: str) -> Tuple[int, int]:
        """Return chunk offsets of ``snippet``, or of its longest found prefix.

        A merged answer span may not occur contiguously in the chunk; in that
        case the longest prefix of the snippet that does occur is used. If
        nothing is found the result is ``(0, 0)``.
        """
        pos = chunk_text.find(snippet)
        if pos >= 0:
            return pos, pos + len(snippet)

        # Binary search on prefix length: if a prefix occurs in the chunk,
        # every shorter prefix does too.
        found_len, missing_len = 0, len(snippet)
        found_pos = 0
        while missing_len - found_len > 1:
            mid = (found_len + missing_len) // 2
            pos = chunk_text.find(snippet[:mid])
            if pos >= 0:
                found_len, found_pos = mid, pos
            else:
                missing_len = mid
        return found_pos, found_pos + found_len

    def _get_context(self, chunk_text: str, start: int, end: int) -> str:
        """Get surrounding context around a match.

        ``start`` and ``end`` are offsets into ``chunk_text``; up to 100
        characters on each side are included.
        """
        ctx_start = max(0, start - 100)
        ctx_end = min(len(chunk_text), end + 100)
        return chunk_text[ctx_start:ctx_end]

    def _compute_covered_chars(
        self,
        answer: str,
        citations: List["Citation"],
    ) -> int:
        """Compute total characters covered by citations.

        Spans are clipped to the answer bounds and overlapping spans are
        counted once.
        """
        limit = len(answer)
        clipped = []
        for c in citations:
            lo = max(0, c.start_char)
            hi = min(c.end_char, limit)
            if hi > lo:
                clipped.append((lo, hi))

        return sum(hi - lo for lo, hi in self._merge_overlapping(clipped))

    def build_attribution_map(
        self,
        answer: str,
        sources: List[Dict[str, Any]],
    ) -> str:
        """Build HTML attribution map for the answer.

        Wraps cited spans in <cite> tags with source references. Answer text
        and attribute values are HTML-escaped. When spans from different
        sources overlap, the earliest-starting (then longest) span is kept
        and the overlapping ones are skipped so tags never nest or split.
        """
        citations = self._trace_overlap(answer, sources)
        citations.sort(key=lambda c: (c.start_char, c.start_char - c.end_char))

        pieces = []
        cursor = 0
        for c in citations:
            if c.start_char < cursor:
                # Overlaps a span that is already wrapped.
                continue

            pieces.append(html.escape(answer[cursor:c.start_char], quote=False))

            source_id = c.chunk_id.replace("chunk_", "")
            cited = html.escape(answer[c.start_char:c.end_char], quote=False)
            pieces.append(
                f'<cite data-source="{html.escape(str(c.source_doc))}" '
                f'data-chunk="{html.escape(source_id)}" '
                f'data-score="{c.relevance_score:.2f}">'
                f'{cited}</cite>'
            )
            cursor = c.end_char

        pieces.append(html.escape(answer[cursor:], quote=False))
        return "".join(pieces)

277 

278 

def hash_chunk_id(text: str, index: int) -> str:
    """Generate a stable chunk ID from text content.

    Args:
        text: Chunk text; hashed as UTF-8 with SHA-256.
        index: Chunk index embedded in the id.

    Returns:
        ``"chunk_<index>_<digest>"`` where ``<digest>`` is the first 12 hex
        characters of the SHA-256 digest of ``text``.
    """
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
    return f"chunk_{index}_{digest}"

282 return f"chunk_{index}_{digest}"