Coverage for src / documint_mcp / ai.py: 0%

396 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 22:30 -0400

1""" 

2AI patch drafting pipeline for Documint — 4-step chain. 

3 

4Step 1: Symbol diff characterization (haiku, structured output via instructor) 

5 → Given old/new symbols, produce typed DiffReport 

6 → Fast, cheap, machine-readable change description 

7 

8Step 2: Stale section detection (haiku, parallel per section) 

9 → For each doc section: is it stale given the diff? 

10 → Skips unaffected sections in Step 3 

11 

12Step 3: Patch generation (sonnet, surgical context) 

13 → Only stale sections + only relevant source excerpts 

14 → Produces full updated doc 

15 

16Step 4: Patch verification (haiku) 

17 → Checks: did anything change? any hallucinated symbol names? 

18 → Retries Step 3 up to 3 times if verification says RETRY 

19 

20Falls back to single-step Anthropic call, then OpenRouter, then deterministic if 4-step fails. 

21""" 

22from __future__ import annotations 

23 

24import hashlib 

25import json 

26import re 

27import time 

28from dataclasses import dataclass 

29from textwrap import dedent 

30from typing import Any, Literal 

31 

32import httpx 

33import structlog 

34 

35from .config import settings 

36from .models import ArtifactTrace, ArtifactType, DriftFinding, PatchCitation, Project, ProjectSettings 

37from .rag import get_rag, format_few_shot_prompt 

38 

39logger = structlog.get_logger(__name__) 

40 

41 

42# ── Chain config by artifact type ───────────────────────────────────────── 

43 

# Per-artifact-type chain configuration for _draft_4step.
#   "steps": how many chain steps to run — 2 = diff + stale detection with a
#            cheap haiku patch; 3 = full patch but skip verification;
#            4 = full chain including Step 4 verification/retries.
#   "breaking_detector": when True, a "## Breaking Changes" section is
#            appended to the generated patch (see _inject_breaking_changes_section).
CHAIN_CONFIGS: dict[str, dict] = {
    "api_reference": {"steps": 4, "breaking_detector": False},
    "changelog": {"steps": 2, "breaking_detector": False},
    "migration_notes": {"steps": 4, "breaking_detector": True},
    "sdk_guides": {"steps": 3, "breaking_detector": False},
    "mcp_reference": {"steps": 4, "breaking_detector": False},
}
# Fallback config for artifact types not listed above.
DEFAULT_CHAIN_CONFIG = {"steps": 4, "breaking_detector": False}

52 

53 

54# ── Symbol diff LRU cache ────────────────────────────────────────────────── 

55 

56_DIFF_REPORT_CACHE: dict[str, Any] = {} 

57_DIFF_REPORT_CACHE_MAX = 50 

58 

59 

60def _symbols_cache_key(changed_symbols: list[dict]) -> str: 

61 """Stable cache key for a list of changed symbols.""" 

62 s = json.dumps(sorted(changed_symbols, key=lambda x: x.get("n", "")), sort_keys=True) 

63 return hashlib.sha256(s.encode()).hexdigest() 

64 

65 

66def _cache_get_diff_report(key: str) -> Any | None: 

67 return _DIFF_REPORT_CACHE.get(key) 

68 

69 

70def _cache_set_diff_report(key: str, report: Any) -> None: 

71 if len(_DIFF_REPORT_CACHE) >= _DIFF_REPORT_CACHE_MAX: 

72 # Evict the oldest entry (insertion-order in Python 3.7+) 

73 oldest = next(iter(_DIFF_REPORT_CACHE)) 

74 del _DIFF_REPORT_CACHE[oldest] 

75 _DIFF_REPORT_CACHE[key] = report 

76 

77 

78# ── Section helpers ──────────────────────────────────────────────────────── 

79 

80def _split_into_sections(doc: str) -> dict[str, str]: 

81 """Returns {section_title: section_content} dict. 

82 

83 Keys are the bare heading text (no leading #). The preamble (text before 

84 the first ## heading) is stored under the key '_preamble'. 

85 Only ## headings are treated as section boundaries; deeper headings stay 

86 inside their parent section. 

87 """ 

88 sections: dict[str, str] = {} 

89 current_title = "_preamble" 

90 current_lines: list[str] = [] 

91 for line in doc.splitlines(): 

92 if line.startswith("## "): 

93 if current_lines: 

94 sections[current_title] = "\n".join(current_lines) 

95 current_title = line.lstrip("# ").strip() 

96 current_lines = [line] 

97 else: 

98 current_lines.append(line) 

99 if current_lines: 

100 sections[current_title] = "\n".join(current_lines) 

101 return sections 

102 

103 

104def _apply_section_patches(original_doc: str, stale_titles: list[str], patched_sections: str) -> str: 

105 """Replace only the stale sections in the original doc with the patched versions. 

106 

107 Parses *patched_sections* into individual ## sections and substitutes each 

108 one back into *original_doc*, leaving untouched sections intact. 

109 Returns the complete doc with only the stale sections updated. 

110 """ 

111 original_sections = _split_into_sections(original_doc) 

112 patched_map = _split_into_sections(patched_sections) 

113 

114 result_parts: list[str] = [] 

115 for title, content in original_sections.items(): 

116 if title in stale_titles and title in patched_map: 

117 result_parts.append(patched_map[title]) 

118 else: 

119 result_parts.append(content) 

120 return "\n\n".join(result_parts) 

121 

122 

123def _detect_stale_sections( 

124 doc_content: str, 

125 changed_symbols: list[dict], 

126 use_semantic: bool = False, 

127) -> list[str]: 

128 """Return list of stale section titles. 

129 

130 Checks three signals in order: 

131 1. Section title contains a changed symbol name 

132 2. Section body mentions a changed symbol name 

133 3. Section body contains an old/new symbol signature 

134 

135 *use_semantic* is reserved for future Phase-6 ChromaDB embedding support. 

136 TODO: implement semantic staleness detection via ChromaDB embeddings (Phase 6). 

137 """ 

138 if use_semantic: 

139 # TODO (Phase 6): embed each section and compute cosine similarity against 

140 # the changed-symbol descriptions using a ChromaDB ephemeral collection. 

141 pass 

142 

143 sections = _split_into_sections(doc_content) 

144 stale: list[str] = [] 

145 

146 symbol_names = {s.get("n", "").lower() for s in changed_symbols if s.get("n")} 

147 symbol_signatures = {s.get("s", "").lower() for s in changed_symbols if s.get("s")} 

148 

149 for title, content in sections.items(): 

150 content_lower = content.lower() 

151 # Check 1: section title contains a changed symbol name 

152 if any(sym in title.lower() for sym in symbol_names): 

153 stale.append(title) 

154 continue 

155 # Check 2: section body mentions a changed symbol name 

156 if any(sym in content_lower for sym in symbol_names): 

157 stale.append(title) 

158 continue 

159 # Check 3: section body contains an old or new signature fragment 

160 if any(sig in content_lower for sig in symbol_signatures): 

161 stale.append(title) 

162 

163 # Fallback: if nothing matched, treat every section as potentially stale 

164 return stale if stale else list(sections.keys()) 

165 

166 

167# ── Structured output models ─────────────────────────────────────────────── 

168 

# Structured-output models for the 4-step chain. Guarded by try/except so the
# module still imports (and the single-step fallback still works) when
# pydantic is not installed. Field descriptions are sent to the model as the
# schema, so they double as prompt text.
try:
    from pydantic import BaseModel as PydanticBaseModel, Field as PydanticField

    # Step 1 output element: one structured record per changed symbol.
    class SymbolChange(PydanticBaseModel):
        symbol_name: str = PydanticField(
            description="Exact name of the function, class, method, or constant that changed — copy verbatim from the source code."
        )
        change_type: Literal["ADDED", "REMOVED", "PARAM_ADDED", "PARAM_REMOVED", "RETURN_CHANGED", "SIGNATURE_CHANGED"] = PydanticField(
            description=(
                "Category of change: ADDED=new symbol, REMOVED=deleted symbol, "
                "PARAM_ADDED=new parameter, PARAM_REMOVED=removed parameter, "
                "RETURN_CHANGED=return type changed, SIGNATURE_CHANGED=other signature change."
            )
        )
        severity: Literal["BREAKING", "ADDITIVE", "COSMETIC"] = PydanticField(
            description=(
                "Impact level: BREAKING=callers must update their code (removed param, removed symbol, "
                "incompatible return type), ADDITIVE=new optional behaviour callers may opt into, "
                "COSMETIC=rename or formatting only."
            )
        )
        detail: str = PydanticField(
            default="",
            description="One-sentence human-readable description of what changed and why it matters for documentation."
        )
        before: str | None = PydanticField(
            default=None,
            description="The old signature or value as it appeared in the previous version, if known."
        )
        after: str | None = PydanticField(
            default=None,
            description="The new signature or value as it appears in the current source code."
        )

    # Step 1 output: the full typed diff characterization (response_model in
    # _step1_diff_report).
    class DiffReport(PydanticBaseModel):
        changes: list[SymbolChange] = PydanticField(
            default_factory=list,
            description="List of every symbol that changed. Include only symbols with evidence from the source excerpt."
        )
        summary: str = PydanticField(
            default="",
            description="One or two sentences summarising the overall nature of this drift (e.g. 'Three functions renamed; one parameter removed from add_memory()')."
        )
        has_breaking_changes: bool = PydanticField(
            default=False,
            description="True if ANY change has severity=BREAKING, meaning existing callers will break."
        )
        affected_doc_sections: list[str] = PydanticField(
            default_factory=list,
            description="Markdown section headings (as they appear in the docs) that need to be updated. Use the exact heading text."
        )

    # Step 2 output: per-section staleness verdict.
    # NOTE(review): not referenced in this file's visible code — _step2_stale_sections
    # uses the heuristic _detect_stale_sections instead; confirm whether an LLM-based
    # Step 2 still consumes this model elsewhere.
    class StalenessCheck(PydanticBaseModel):
        section_name: str = PydanticField(
            description="Markdown heading of the documentation section being evaluated."
        )
        is_stale: bool = PydanticField(
            description="True if this section references any changed symbol and therefore needs to be rewritten."
        )
        reason: str = PydanticField(
            default="",
            description="Brief explanation of why the section is or is not stale."
        )

    # Step 4 output: verification verdict used by the retry loop in _draft_4step.
    class PatchVerification(PydanticBaseModel):
        patch_changed_content: bool = PydanticField(
            description="True if the proposed patch contains substantive differences from the original documentation (not just whitespace)."
        )
        hallucinated_symbols: list[str] = PydanticField(
            default_factory=list,
            description=(
                "List of function/class/method names referenced in the patch that do NOT appear in the "
                "provided source symbol list. Empty list if no hallucinations detected."
            )
        )
        issues_found: list[str] = PydanticField(
            default_factory=list,
            description=(
                "Specific problems with the patch, each as a short sentence. Examples: "
                "'patch is identical to original', 'uses foo() which is not in source', "
                "'missing update for removed parameter bar'. Empty list if no issues."
            )
        )
        verdict: Literal["PASS", "RETRY", "FAIL"] = PydanticField(
            default="PASS",
            description=(
                "PASS=patch is correct and substantive, ready to present to user. "
                "RETRY=patch has fixable issues (same as original, minor hallucinations) — regenerate. "
                "FAIL=patch has severe errors (contradicts source, major hallucinations) — escalate."
            )
        )

    _HAS_PYDANTIC = True
except ImportError:
    # pydantic missing: structured models are undefined, so the 4-step chain
    # cannot run; the single-step and deterministic fallbacks remain available.
    _HAS_PYDANTIC = False
    logger.info("pydantic_not_available", note="structured output models unavailable")

265 

# Memoized instructor-wrapped Anthropic client; created lazily on first use.
_instructor_client = None
# True once the client has been successfully constructed.
_HAS_INSTRUCTOR = False


def _get_instructor_client():
    """Lazy instructor client initialization.

    Creates the instructor-wrapped Anthropic client on first use rather than at
    module import time. This avoids conflating "packages not installed" with
    "API key not yet configured" — the old eager initialization set
    _HAS_INSTRUCTOR=False permanently if the key was missing at import time.

    Returns the memoized client, or None when the packages are missing, the
    API key is not configured yet, or client construction fails.
    """
    global _instructor_client, _HAS_INSTRUCTOR
    if _instructor_client is not None:
        return _instructor_client
    try:
        import instructor
        from anthropic import Anthropic

        key = settings.anthropic_api_key
        if not key:
            # Key not configured yet — stay uninitialized so a later call can retry.
            return None
        _instructor_client = instructor.from_anthropic(Anthropic(api_key=key))
        _HAS_INSTRUCTOR = True
        return _instructor_client
    # Fix: `except (ImportError, Exception)` was redundant — Exception already
    # subsumes ImportError; a single Exception handler covers missing packages
    # and SDK construction errors alike.
    except Exception as exc:
        logger.info("instructor_init_failed", error=str(exc), note="falling back to single-step chain")
        return None

294 

295 

296# ── Token cost helpers ───────────────────────────────────────────────────── 

297 

298# Rates in USD per million tokens 

299_COST_PER_M: dict[str, dict[str, float]] = { 

300 "claude-haiku-4-5-20251001": {"in": 0.25, "out": 1.25}, 

301 # Fallback for any sonnet-4-6 variant string 

302 "claude-sonnet-4-6": {"in": 3.0, "out": 15.0}, 

303} 

304 

305 

306def _model_cost(model: str, input_tokens: int, output_tokens: int) -> float: 

307 """Rough USD cost for a single call.""" 

308 rates = None 

309 for key in _COST_PER_M: 

310 if key in model: 

311 rates = _COST_PER_M[key] 

312 break 

313 if rates is None: 

314 rates = {"in": 3.0, "out": 15.0} # conservative default 

315 return (input_tokens / 1_000_000) * rates["in"] + (output_tokens / 1_000_000) * rates["out"] 

316 

317 

318def _extract_usage(raw_response: Any) -> tuple[int, int]: 

319 """Return (input_tokens, output_tokens) from an instructor response object. 

320 

321 When using instructor with response_model, messages.create() returns the 

322 Pydantic model — not the raw Anthropic response. The Pydantic model has 

323 no ``usage`` attribute, so the old code always returned (0, 0). 

324 

325 Newer versions of the instructor library store the original API response 

326 on ``_raw_response``. We check that first, then fall back to a direct 

327 ``.usage`` attribute (in case the caller passes the raw response directly), 

328 and finally accept (0, 0) with a debug log. 

329 """ 

330 try: 

331 # instructor stores the raw Anthropic response here in newer versions 

332 raw = getattr(raw_response, "_raw_response", None) 

333 if raw is not None: 

334 usage = getattr(raw, "usage", None) 

335 if usage is not None: 

336 return int(getattr(usage, "input_tokens", 0)), int(getattr(usage, "output_tokens", 0)) 

337 

338 # Direct .usage (when caller passes the raw API response itself) 

339 usage = getattr(raw_response, "usage", None) 

340 if usage is not None: 

341 return int(getattr(usage, "input_tokens", 0)), int(getattr(usage, "output_tokens", 0)) 

342 except Exception: 

343 pass 

344 

345 logger.debug("token_usage_unavailable", response_type=type(raw_response).__name__) 

346 return 0, 0 

347 

348 

349# ── Result dataclass ─────────────────────────────────────────────────────── 

350 

@dataclass(frozen=True)
class DraftPatchResult:
    """Immutable result of one patch-drafting run, whichever provider produced it."""

    summary: str  # one-line description of the patch
    rationale: str  # why the patch was drafted (diff-report derived for chain runs)
    proposed_sections: list[str]  # doc section titles the patch targets
    citations: list[PatchCitation]  # source evidence backing the patch
    preview_markdown: str  # the full patched document / markdown fragment
    ai_provider: str  # e.g. "anthropic", "openrouter", "anthropic-4step"
    model_name: str | None  # model used; None for the deterministic fallback
    input_summary: str  # prompt/context summary (4-step appends token/cost metadata here)
    confidence_score: float = 0.5  # heuristic confidence; 0.5 default — presumably 0..1, confirm in _compute_confidence
    chain_steps_used: int = 1  # how many chain steps produced this result (1, 2, 3, or 4)

363 

364 

365# ── Main generator ───────────────────────────────────────────────────────── 

366 

367class PatchGenerator: 

368 """Generate reviewable doc patches from deterministic drift findings.""" 

369 

    def draft_patch(
        self,
        *,
        project: Project,
        trace: ArtifactTrace,
        finding: DriftFinding | None,
        current_doc: str,
        source_content: str = "",
        project_settings: ProjectSettings | None,
        policy: str,
    ) -> DraftPatchResult:
        """Draft a reviewable doc patch for *trace*, trying providers in order.

        Order of attempts:
          1. 4-step structured chain (requires instructor + anthropic + API key);
          2. single-step prompt against Anthropic, then primary and secondary
             OpenRouter models;
          3. deterministic draft as the final fallback.

        *policy* and *project_settings* feed only the single-step prompt and
        the deterministic fallback — the 4-step chain does not use them.
        """
        # Attempt 4-step chain first (requires instructor + anthropic SDK)
        if _get_instructor_client() is not None and settings.anthropic_api_key:
            result = self._draft_4step(
                project=project,
                trace=trace,
                finding=finding,
                current_doc=current_doc,
                source_content=source_content,
            )
            # None means a critical chain step failed — fall through to single-step.
            if result is not None:
                return result

        # Fall back to single-step chain
        prompt = self._build_prompt(
            project=project, trace=trace, finding=finding,
            current_doc=current_doc, source_content=source_content,
            project_settings=project_settings, policy=policy,
        )

        # Providers tried in priority order; each lambda binds its model at
        # call time, and the first non-None response wins.
        for provider_fn, provider_name, model_name in [
            (lambda p: self._draft_with_anthropic(p), "anthropic", settings.anthropic_model),
            (lambda p: self._draft_with_openrouter(p, settings.openrouter_primary_model), "openrouter", settings.openrouter_primary_model),
            (lambda p: self._draft_with_openrouter(p, settings.openrouter_secondary_model), "openrouter", settings.openrouter_secondary_model),
        ]:
            text = provider_fn(prompt)
            if text is not None:
                citations = self._build_citations(trace=trace, finding=finding)
                sections = self._suggest_sections(trace=trace, finding=finding)
                return DraftPatchResult(
                    summary=f"{provider_name.capitalize()} patch for {trace.title}",
                    rationale=f"Generated by {provider_name} from drift context.",
                    proposed_sections=sections,
                    citations=citations,
                    preview_markdown=text,
                    ai_provider=provider_name,
                    model_name=model_name,
                    input_summary=prompt[:500],
                    chain_steps_used=1,
                )

        # Every AI provider failed — produce a deterministic draft so the
        # caller always gets a result.
        return self._deterministic_draft(
            project=project, trace=trace, finding=finding,
            current_doc=current_doc, project_settings=project_settings, policy=policy,
        )

425 

426 # ── 4-step chain ────────────────────────────────────────────────────── 

427 

    def _draft_4step(
        self,
        *,
        project: Project,
        trace: ArtifactTrace,
        finding: DriftFinding | None,
        current_doc: str,
        source_content: str,
    ) -> DraftPatchResult | None:
        """
        4-step AI chain with structured output and verification.

        The number of steps actually executed depends on the artifact type
        via CHAIN_CONFIGS (steps=2 skips steps 3+4; steps=3 skips step 4).

        Returns None if any critical step fails (caller falls back to single-step).
        """
        try:
            # Resolve chain config for this artifact type
            artifact_type_val = (
                trace.artifact_type.value
                if hasattr(trace, "artifact_type") and trace.artifact_type is not None
                else "unknown"
            )
            chain_cfg = CHAIN_CONFIGS.get(artifact_type_val, DEFAULT_CHAIN_CONFIG)
            max_steps: int = chain_cfg["steps"]
            breaking_detector: bool = chain_cfg["breaking_detector"]

            # Token accumulators.
            # NOTE(review): Step 3 returns only patch text, so its token usage is
            # never added here — these totals cover Steps 1 and 4 only.
            total_input_tokens = 0
            total_output_tokens = 0

            # ── Step 1: Characterize the structural diff ──────────────────
            # Check the LRU cache first using the changed_symbols list as key.
            changed_symbols_raw: list[dict] = []
            if finding and hasattr(finding, "changed_symbols") and finding.changed_symbols:
                changed_symbols_raw = finding.changed_symbols

            cache_key = _symbols_cache_key(changed_symbols_raw) if changed_symbols_raw else None
            cached_report = _cache_get_diff_report(cache_key) if cache_key else None

            if cached_report is not None:
                diff_report = cached_report
                logger.info("step1_cache_hit", cache_key=cache_key[:12] if cache_key else None)
            else:
                diff_report, step1_in, step1_out = self._step1_diff_report(
                    finding=finding, source_content=source_content
                )
                total_input_tokens += step1_in
                total_output_tokens += step1_out
                if diff_report is None:
                    # Step 1 is critical — without a diff report the chain aborts.
                    return None
                if cache_key:
                    _cache_set_diff_report(cache_key, diff_report)

            # ── Step 2: Identify stale sections ───────────────────────────
            stale_sections = self._step2_stale_sections(diff_report=diff_report, current_doc=current_doc)

            # For steps=2 we only needed the diff report + stale detection;
            # generate a lightweight haiku patch and return early.
            if max_steps <= 2:
                patch_text = self._step3_generate_patch(
                    project=project,
                    trace=trace,
                    diff_report=diff_report,
                    stale_sections=stale_sections,
                    current_doc=current_doc,
                    source_content=source_content,
                    use_haiku=True,
                )
                if patch_text is None:
                    return None
                if breaking_detector:
                    patch_text = self._inject_breaking_changes_section(patch_text, diff_report)
                citations = self._build_citations(trace=trace, finding=finding)
                sections = self._suggest_sections(trace=trace, finding=finding)
                confidence = self._compute_confidence(diff_report=diff_report, verification=None)
                return DraftPatchResult(
                    summary=diff_report.summary or f"2-step patch for {trace.title}",
                    rationale=self._format_rationale(diff_report),
                    proposed_sections=sections,
                    citations=citations,
                    preview_markdown=patch_text,
                    ai_provider="anthropic-2step",
                    model_name=settings.anthropic_model,
                    input_summary=f"2-step: {len(diff_report.changes)} changes, {len(stale_sections)} stale sections",
                    confidence_score=confidence,
                    chain_steps_used=2,
                )

            # ── Step 3: Generate patch with surgical context ──────────────
            patch_text = self._step3_generate_patch(
                project=project,
                trace=trace,
                diff_report=diff_report,
                stale_sections=stale_sections,
                current_doc=current_doc,
                source_content=source_content,
            )
            if patch_text is None:
                return None

            if breaking_detector:
                patch_text = self._inject_breaking_changes_section(patch_text, diff_report)

            # For steps=3 we skip verification entirely.
            if max_steps <= 3:
                citations = self._build_citations(trace=trace, finding=finding)
                sections = self._suggest_sections(trace=trace, finding=finding)
                confidence = self._compute_confidence(diff_report=diff_report, verification=None)
                return DraftPatchResult(
                    summary=diff_report.summary or f"3-step patch for {trace.title}",
                    rationale=self._format_rationale(diff_report),
                    proposed_sections=sections,
                    citations=citations,
                    preview_markdown=patch_text,
                    ai_provider="anthropic-3step",
                    model_name=settings.anthropic_model,
                    input_summary=f"3-step: {len(diff_report.changes)} changes, {len(stale_sections)} stale sections",
                    confidence_score=confidence,
                    chain_steps_used=3,
                )

            # ── Step 4: Verify the patch (up to 3 retry attempts) ─────────
            verification = None
            for attempt in range(3):
                verification, v_in, v_out = self._step4_verify(
                    current_doc=current_doc,
                    patch_text=patch_text,
                    source_content=source_content,
                    diff_report=diff_report,
                )
                total_input_tokens += v_in
                total_output_tokens += v_out

                # NOTE(review): a FAIL verdict exits the loop but the patch is
                # still returned below — presumably _compute_confidence downgrades
                # it; confirm that is the intended escalation path.
                if verification is None or verification.verdict in ("PASS", "FAIL"):
                    break

                # verdict == "RETRY"
                logger.info(
                    "patch_verification_retry",
                    artifact=trace.id,
                    attempt=attempt + 1,
                    issues=verification.issues_found,
                )
                if attempt < 2:
                    # TODO: Replace with `await asyncio.sleep()` when the chain
                    # is converted to async. time.sleep() blocks the thread /
                    # event loop, which is harmful under async servers.
                    sleep_secs = min(2 ** attempt, 2)  # cap at 2s
                    logger.warning(
                        "blocking_sleep_in_retry_loop",
                        seconds=sleep_secs,
                        attempt=attempt + 1,
                        note="time.sleep blocks the thread; convert to async",
                    )
                    time.sleep(sleep_secs)
                # Regenerate with the verifier's issues fed back into the prompt;
                # keep the previous patch if regeneration fails.
                patch_text = self._step3_generate_patch(
                    project=project, trace=trace, diff_report=diff_report,
                    stale_sections=stale_sections, current_doc=current_doc,
                    source_content=source_content, retry=True,
                    retry_issues=verification.issues_found,
                ) or patch_text
                if breaking_detector:
                    patch_text = self._inject_breaking_changes_section(patch_text, diff_report)

            citations = self._build_citations(trace=trace, finding=finding)
            sections = self._suggest_sections(trace=trace, finding=finding)
            confidence = self._compute_confidence(diff_report=diff_report, verification=verification)

            # Append token/cost metadata to the input_summary field so callers
            # can surface it without a schema change.
            approx_cost_usd = _model_cost(
                settings.anthropic_model or "claude-sonnet-4-6",
                total_input_tokens,
                total_output_tokens,
            )
            token_meta = (
                f" | tokens in={total_input_tokens} out={total_output_tokens}"
                f" cost≈${approx_cost_usd:.5f}"
            )

            return DraftPatchResult(
                summary=diff_report.summary if diff_report.summary else f"4-step patch for {trace.title}",
                rationale=self._format_rationale(diff_report),
                proposed_sections=sections,
                citations=citations,
                preview_markdown=patch_text,
                ai_provider="anthropic-4step",
                model_name=settings.anthropic_model,
                input_summary=(
                    f"4-step: {len(diff_report.changes)} changes, {len(stale_sections)} stale sections"
                    + token_meta
                ),
                confidence_score=confidence,
                chain_steps_used=4,
            )
        except Exception as exc:
            # Any unexpected failure anywhere in the chain: log and let the
            # caller fall back to the single-step path.
            logger.warning("4step_chain_failed", error=str(exc), artifact=trace.id)
            return None

628 

629 # ── Breaking-change injector ─────────────────────────────────────────── 

630 

631 def _inject_breaking_changes_section(self, patch_text: str, diff_report: Any) -> str: 

632 """Append a '## Breaking Changes' section if the diff has BREAKING severity changes.""" 

633 if not diff_report or not diff_report.has_breaking_changes: 

634 return patch_text 

635 breaking = [c for c in diff_report.changes if c.severity == "BREAKING"] 

636 if not breaking: 

637 return patch_text 

638 lines = ["\n\n## Breaking Changes\n"] 

639 for c in breaking: 

640 before_str = f" (was: `{c.before}`)" if c.before else "" 

641 after_str = f" → `{c.after}`" if c.after else "" 

642 lines.append(f"- **`{c.symbol_name}`** ({c.change_type}){before_str}{after_str}: {c.detail}") 

643 return patch_text + "\n".join(lines) 

644 

645 # ── Step implementations ─────────────────────────────────────────────── 

646 

    def _step1_diff_report(
        self, *, finding: DriftFinding | None, source_content: str
    ) -> tuple[Any | None, int, int]:
        """Step 1: Structured diff report via instructor.

        Builds a haiku prompt from the drift finding plus a source excerpt and
        asks instructor for a typed ``DiffReport``.

        Returns (DiffReport | None, input_tokens, output_tokens).
        """
        client = _get_instructor_client()
        if client is None:
            # instructor/anthropic unavailable or key missing — abort Step 1.
            return None, 0, 0

        # Use changed_symbols from the finding if available (from drift_engine)
        # NOTE(review): the keys read here ('type', 'symbol', 'detail') differ
        # from the 'n'/'s' keys used by _symbols_cache_key/_detect_stale_sections —
        # confirm the drift_engine schema supplies both shapes.
        changed_symbols_context = ""
        if finding and hasattr(finding, "changed_symbols") and finding.changed_symbols:
            changed_symbols_context = f"\nStructural changes detected:\n" + "\n".join(
                f" - {c.get('type','?')}: {c.get('symbol','?')}{c.get('detail','')}"
                for c in finding.changed_symbols[:10]
            )

        artifact_id = finding.artifact_id if finding else "unknown"
        drift_signal = finding.summary if finding else "source changed"
        artifact_type = finding.artifact_type.value if finding and hasattr(finding, "artifact_type") else "unknown"

        prompt = dedent(f"""
            You are a code-change analyst. Produce a structured DiffReport for a documentation drift event.

            ARTIFACT: {artifact_id}
            ARTIFACT TYPE: {artifact_type}
            DRIFT SIGNAL: {drift_signal}
            {changed_symbols_context}

            SOURCE CODE (ground truth — the current state of the code):
            {source_content[:6000]}

            INSTRUCTIONS:
            - Populate `changes` with every symbol that has evidence of change in the source above.
            - Use EXACT symbol names from the source — never invent or paraphrase names.
            - Severity guide: BREAKING = callers must change their code; ADDITIVE = new optional feature;
              COSMETIC = rename, docstring, or format change only.
            - Set `has_breaking_changes` = true if any change is BREAKING.
            - For `affected_doc_sections`, list the exact markdown headings from typical API reference docs
              that would need updating (e.g. "Parameters", "Return Value", "add_memory").
            - If you cannot determine what changed from the source excerpt, return an empty `changes` list
              and set `summary` to "Insufficient source context to characterise changes."
        """).strip()

        try:
            # Cheap/fast haiku call; response_model makes instructor parse the
            # reply directly into a DiffReport instance.
            raw = client.messages.create(
                model="claude-haiku-4-5-20251001",
                max_tokens=2048,
                response_model=DiffReport,
                messages=[{"role": "user", "content": prompt}],
            )
            in_tok, out_tok = _extract_usage(raw)
            return raw, in_tok, out_tok
        except Exception as exc:
            logger.warning("step1_diff_report_failed", error=str(exc))
            return None, 0, 0

705 

706 def _step2_stale_sections(self, *, diff_report: Any, current_doc: str) -> list[str]: 

707 """Step 2: Identify stale sections. 

708 

709 Uses the improved _detect_stale_sections helper when the finding's 

710 changed_symbols list is available on the diff_report; falls back to the 

711 original symbol-name line-scan for backwards compatibility. 

712 """ 

713 if not diff_report or not diff_report.changes: 

714 # No structured diff — mark all sections as potentially stale 

715 return self._extract_section_names(current_doc) 

716 

717 # Build changed_symbols in the format expected by _detect_stale_sections 

718 changed_symbols_for_detect = [ 

719 {"n": c.symbol_name, "s": (c.before or "") + " " + (c.after or "")} 

720 for c in diff_report.changes 

721 ] 

722 

723 stale = _detect_stale_sections(current_doc, changed_symbols_for_detect) 

724 return stale[:8] # Cap at 8 sections 

725 

726 def _step3_generate_patch( 

727 self, 

728 *, 

729 project: Project, 

730 trace: ArtifactTrace, 

731 diff_report: Any, 

732 stale_sections: list[str], 

733 current_doc: str, 

734 source_content: str, 

735 retry: bool = False, 

736 retry_issues: list[str] | None = None, 

737 use_haiku: bool = False, 

738 ) -> str | None: 

739 """Step 3: Generate the actual patch with surgical context. 

740 

741 When *use_haiku* is True (steps=2 path) the cheaper haiku model is used. 

742 Uses section-level patching: only the stale sections are sent to the 

743 model, and the result is stitched back into the original document. 

744 """ 

745 changes_text = "" 

746 # Extract valid symbol names from diff report for explicit hallucination guard 

747 valid_symbol_names: list[str] = [] 

748 if diff_report and diff_report.changes: 

749 valid_symbol_names = [c.symbol_name for c in diff_report.changes[:15]] 

750 changes_text = "VERIFIED SYMBOL CHANGES (copy these names exactly — do not paraphrase):\n" + "\n".join( 

751 f" [{c.severity}] {c.change_type}: `{c.symbol_name}`" 

752 + (f"\n Before: {c.before}" if c.before else "") 

753 + (f"\n After: {c.after}" if c.after else "") 

754 + (f"\n Note: {c.detail}" if c.detail else "") 

755 for c in diff_report.changes[:15] 

756 ) 

757 

758 # Also mine symbol names directly from source as a second guard 

759 source_symbols = re.findall(r'\b(?:def |fn |func |class |pub fn |pub struct )\s*(\w+)', source_content) 

760 all_valid_symbols = list(dict.fromkeys(valid_symbol_names + source_symbols[:30])) # deduplicated, order preserved 

761 

762 stale_text = ( 

763 "SECTIONS TO UPDATE (update these and only these; leave other sections verbatim):\n" 

764 + "\n".join(f" - {s}" for s in stale_sections) 

765 ) if stale_sections else "UPDATE ALL SECTIONS (no specific stale sections identified)." 

766 

767 retry_instruction = "" 

768 if retry: 

769 issues_detail = "\n".join(f" - {issue}" for issue in (retry_issues or [])) or " - patch was identical to original or contained hallucinations" 

770 retry_instruction = dedent(f""" 

771 

772 RETRY NOTICE: The previous draft was rejected for these reasons: 

773 {issues_detail} 

774 You MUST fix all listed issues before returning the updated document. 

775 """) 

776 

777 valid_symbols_guard = ( 

778 "\nVALID SYMBOL NAMES (ONLY use names from this list when referencing code):\n" 

779 + ", ".join(f"`{s}`" for s in all_valid_symbols) 

780 ) if all_valid_symbols else "" 

781 

782 # ── Section-level patching: only send stale sections to the model ── 

783 sections_map = _split_into_sections(current_doc) 

784 stale_section_content = "\n\n".join( 

785 sections_map[title] for title in stale_sections if title in sections_map 

786 ) 

787 # If we couldn't extract any stale content, fall back to the full doc 

788 doc_context = stale_section_content if stale_section_content.strip() else current_doc[:2500] 

789 section_instruction = ( 

790 "You are rewriting ONLY the stale sections shown below. " 

791 "Return ONLY those sections as a complete markdown fragment (keep their ## headings). " 

792 "Do not reproduce unrelated sections." 

793 if stale_section_content.strip() 

794 else "Return the COMPLETE updated documentation as a single markdown document." 

795 ) 

796 

797 # ── RAG: retrieve few-shot examples from approved patches ── 

798 few_shot_block = "" 

799 try: 

800 rag = get_rag() 

801 if rag.available: 

802 few_shot_examples = rag.get_few_shot_examples( 

803 artifact_id=trace.id, 

804 stale_sections=stale_sections, 

805 n=3, 

806 ) 

807 few_shot_block = format_few_shot_prompt(few_shot_examples) 

808 except Exception: # noqa: BLE001 

809 pass # RAG is optional — never block the chain 

810 

811 prompt = dedent(f""" 

812 You are a senior technical writer producing a precise documentation patch. 

813 

814 PROJECT: {project.name} 

815 ARTIFACT: {trace.title} (type: {trace.artifact_type.value}) 

816 

817 {changes_text} 

818 

819 {stale_text} 

820 {valid_symbols_guard} 

821 

822 SOURCE CODE (ground truth — the canonical reference for all names and signatures): 

823 {source_content[:6000]} 

824 

825 STALE DOCUMENTATION SECTIONS (the only content you should rewrite): 

826 {doc_context} 

827 

828 {few_shot_block} 

829 

830 RULES — follow every rule without exception: 

831 1. {section_instruction} 

832 2. Use ONLY symbol names that appear in the VALID SYMBOL NAMES list or verbatim in SOURCE CODE above. 

833 If you are unsure of a name, omit the code reference rather than guessing. 

834 3. Update ONLY the sections listed in SECTIONS TO UPDATE. Copy all other sections byte-for-byte. 

835 4. If a section has no changes, reproduce it exactly as-is. 

836 5. Append a "## Changes Made" section at the end that lists each change and the source evidence for it. 

837 6. Do NOT invent parameters, return types, or functions that are not in the source.{retry_instruction} 

838 

839 Return ONLY the updated markdown content. No preamble, no commentary outside the document. 

840 """).strip() 

841 

842 if use_haiku: 

843 # Lightweight path: call haiku directly via raw httpx (same as _draft_with_anthropic 

844 # but forces the haiku model regardless of settings.anthropic_model). 

845 if not settings.anthropic_api_key: 

846 return None 

847 try: 

848 with httpx.Client(timeout=60.0) as client: 

849 response = client.post( 

850 "https://api.anthropic.com/v1/messages", 

851 headers={ 

852 "x-api-key": settings.anthropic_api_key, 

853 "anthropic-version": "2023-06-01", 

854 "content-type": "application/json", 

855 }, 

856 json={ 

857 "model": "claude-haiku-4-5-20251001", 

858 "max_tokens": 2048, 

859 "system": "You are a senior technical writer. Draft precise, factual documentation patches based on code drift signals. Return only the updated markdown document.", 

860 "messages": [{"role": "user", "content": prompt}], 

861 }, 

862 ) 

863 response.raise_for_status() 

864 patched_fragment = response.json()["content"][0]["text"] 

865 except Exception as exc: 

866 logger.warning("step3_haiku_failed", error=str(exc)) 

867 return None 

868 else: 

869 patched_fragment = self._draft_with_anthropic(prompt) 

870 

871 if patched_fragment is None: 

872 return None 

873 

874 # Stitch the patched sections back into the full document 

875 if stale_section_content.strip() and stale_sections: 

876 return _apply_section_patches(current_doc, stale_sections, patched_fragment) 

877 return patched_fragment 

878 

879 def _step4_verify( 

880 self, 

881 *, 

882 current_doc: str, 

883 patch_text: str, 

884 source_content: str, 

885 diff_report: Any, 

886 ) -> tuple[Any | None, int, int]: 

887 """Step 4: Verify the patch is correct, substantive, and free of hallucinations. 

888 

889 Returns (PatchVerification | None, input_tokens, output_tokens). 

890 """ 

891 client = _get_instructor_client() 

892 if client is None: 

893 return None, 0, 0 

894 

895 # Extract all identifiers from source — this is the hallucination allowlist 

896 source_names = sorted(set(re.findall(r'\b(?:def |fn |func |class |pub fn |pub struct )\s*(\w+)', source_content))) 

897 # Also include names from the diff report (already validated in step 1) 

898 if diff_report and diff_report.changes: 

899 source_names = sorted(set(source_names + [c.symbol_name for c in diff_report.changes])) 

900 

901 # Use generous windows: enough to catch hallucinated names that appear mid-document 

902 original_excerpt = current_doc[:1500] 

903 patch_excerpt = patch_text[:2000] 

904 

905 prompt = dedent(f""" 

906 You are a documentation QA reviewer. Evaluate whether this documentation patch is correct. 

907 

908 VALID SOURCE SYMBOLS (the only code names that should appear in the patch): 

909 {', '.join(source_names[:40]) or '(none extracted — skip hallucination check)'} 

910 

911 ORIGINAL DOCUMENTATION (first 1500 chars): 

912 {original_excerpt} 

913 

914 PROPOSED PATCH (first 2000 chars): 

915 {patch_excerpt} 

916 

917 EVALUATION CHECKLIST: 

918 1. `patch_changed_content`: Is the patch substantively different from the original? 

919 (Ignore whitespace-only differences. True = real content changed.) 

920 2. `hallucinated_symbols`: List every function/class/method name in the patch that does NOT 

921 appear in VALID SOURCE SYMBOLS. Only list code identifiers, not prose words. 

922 Empty list = no hallucinations. 

923 3. `issues_found`: List each specific problem as a short sentence. Common issues: 

924 - "patch is nearly identical to original" 

925 - "uses `foo()` which is not in source symbols" 

926 - "missing '## Changes Made' section" 

927 - "updated wrong section" 

928 Empty list = no issues. 

929 4. `verdict`: 

930 - PASS: patch is substantive, no hallucinations, no major issues 

931 - RETRY: patch has fixable problems (same as original, minor hallucinations, missing section) 

932 - FAIL: patch is severely wrong (contradicts source, multiple major hallucinations) 

933 

934 Be strict: any hallucinated symbol name that would mislead a developer = at minimum RETRY. 

935 """).strip() 

936 

937 try: 

938 raw = client.messages.create( 

939 model="claude-haiku-4-5-20251001", 

940 max_tokens=512, # raised from 256 — issues_found list needs room 

941 response_model=PatchVerification, 

942 messages=[{"role": "user", "content": prompt}], 

943 ) 

944 in_tok, out_tok = _extract_usage(raw) 

945 return raw, in_tok, out_tok 

946 except Exception as exc: 

947 logger.warning("step4_verification_failed", error=str(exc)) 

948 return None, 0, 0 

949 

950 def _compute_confidence(self, *, diff_report: Any, verification: Any) -> float: 

951 """ 

952 Compute a meaningful confidence score (0.0–1.0) from chain signals. 

953 

954 Scoring model: 

955 Base: 0.70 (4-step chain completed) 

956 +0.15 diff_report has changes and they are characterised 

957 -0.15 has_breaking_changes (higher uncertainty) 

958 +0.10 verification PASS 

959 -0.20 verification RETRY or hallucinations found 

960 -0.40 verification FAIL 

961 Score is clamped to [0.10, 0.95]. 

962 """ 

963 score = 0.70 

964 

965 if diff_report: 

966 if diff_report.changes: 

967 score += 0.15 

968 if diff_report.has_breaking_changes: 

969 score -= 0.15 

970 

971 if verification: 

972 if verification.verdict == "PASS" and not verification.hallucinated_symbols: 

973 score += 0.10 

974 elif verification.verdict == "RETRY" or verification.hallucinated_symbols: 

975 score -= 0.20 

976 elif verification.verdict == "FAIL": 

977 score -= 0.40 

978 

979 return round(max(0.10, min(0.95, score)), 2) 

980 

981 def _format_rationale(self, diff_report: Any) -> str: 

982 if not diff_report or not diff_report.changes: 

983 return "Generated by 4-step AI chain from deterministic drift signal." 

984 changes_str = "; ".join( 

985 f"{c.change_type} {c.symbol_name}" for c in diff_report.changes[:5] 

986 ) 

987 suffix = f"... (+{len(diff_report.changes)-5} more)" if len(diff_report.changes) > 5 else "" 

988 return f"Structural changes: {changes_str}{suffix}. Confidence: {'HIGH' if not diff_report.has_breaking_changes else 'MEDIUM'}." 

989 

990 def _extract_section_names(self, doc: str) -> list[str]: 

991 """Extract section heading names, matching only ## headings. 

992 

993 Must stay consistent with _split_into_sections which only splits on 

994 '## ' boundaries. Previously this matched all heading levels (#, ##, 

995 ###, ...) causing subsection names to appear in the stale-section list 

996 even though _split_into_sections never creates keys for them — patches 

997 for those subsections silently failed to apply. 

998 """ 

999 return [ 

1000 line[3:].strip() 

1001 for line in doc.splitlines() 

1002 if line.startswith("## ") and not line.startswith("### ") 

1003 ][:8] 

1004 

1005 # ── Single-step providers (fallback) ────────────────────────────────── 

1006 

1007 def _draft_with_anthropic(self, prompt: str) -> str | None: 

1008 if not settings.anthropic_api_key: 

1009 return None 

1010 try: 

1011 with httpx.Client(timeout=60.0) as client: 

1012 response = client.post( 

1013 "https://api.anthropic.com/v1/messages", 

1014 headers={ 

1015 "x-api-key": settings.anthropic_api_key, 

1016 "anthropic-version": "2023-06-01", 

1017 "content-type": "application/json", 

1018 }, 

1019 json={ 

1020 "model": settings.anthropic_model, 

1021 "max_tokens": 2048, 

1022 "system": "You are a senior technical writer. Draft precise, factual documentation patches based on code drift signals. Return only the updated markdown document.", 

1023 "messages": [{"role": "user", "content": prompt}], 

1024 }, 

1025 ) 

1026 response.raise_for_status() 

1027 return response.json()["content"][0]["text"] 

1028 except Exception as e: 

1029 logger.warning("anthropic_draft_failed", error=str(e)) 

1030 return None 

1031 

1032 def _draft_with_openrouter(self, prompt: str, model: str) -> str | None: 

1033 if not settings.openrouter_api_key: 

1034 return None 

1035 try: 

1036 with httpx.Client(timeout=60.0) as client: 

1037 response = client.post( 

1038 "https://openrouter.ai/api/v1/chat/completions", 

1039 headers={ 

1040 "Authorization": f"Bearer {settings.openrouter_api_key}", 

1041 "HTTP-Referer": "https://documint.xyz", 

1042 "X-Title": "Documint", 

1043 "content-type": "application/json", 

1044 }, 

1045 json={ 

1046 "model": model, 

1047 "max_tokens": 2048, 

1048 "messages": [ 

1049 {"role": "system", "content": "You are a senior technical writer. Return only updated markdown."}, 

1050 {"role": "user", "content": prompt}, 

1051 ], 

1052 }, 

1053 ) 

1054 response.raise_for_status() 

1055 return response.json()["choices"][0]["message"]["content"] 

1056 except Exception as e: 

1057 logger.warning("openrouter_draft_failed", model=model, error=str(e)) 

1058 return None 

1059 

1060 def _deterministic_draft(self, *, project, trace, finding, current_doc, project_settings, policy) -> DraftPatchResult: 

1061 del project_settings, policy 

1062 citations = self._build_citations(trace=trace, finding=finding) 

1063 sections = self._suggest_sections(trace=trace, finding=finding) 

1064 summary = finding.summary if finding is not None else f"Refresh {trace.title}" 

1065 rationale = finding.rationale if finding is not None else "Documentation should be refreshed against latest source files." 

1066 preview_markdown = "\n".join([ 

1067 f"# Proposed update for {trace.title}", "", 

1068 f"- Project: `{project.slug}`", 

1069 f"- Artifact: `{trace.artifact_type.value}`", 

1070 f"- Existing doc: `{len(current_doc.splitlines())}` lines", "", 

1071 "## Why this patch exists", rationale, "", 

1072 "## Suggested sections", *[f"- {s}" for s in sections], "", 

1073 "## Source citations", 

1074 *[f"- `{c.path}` at `{c.ref}`: {c.note}" for c in citations], 

1075 ]) 

1076 return DraftPatchResult( 

1077 summary=summary, rationale=rationale, proposed_sections=sections, 

1078 citations=citations, preview_markdown=preview_markdown, 

1079 ai_provider="deterministic", model_name=None, 

1080 input_summary=f"{trace.id}:{len(trace.source_paths)} sources", 

1081 chain_steps_used=0, 

1082 ) 

1083 

1084 def _build_prompt(self, *, project, trace, finding, current_doc, source_content="", project_settings, policy) -> str: 

1085 finding_summary = finding.summary if finding else "manual refresh" 

1086 source_paths = ", ".join(trace.source_paths[:10]) or "(none)" 

1087 source_block = ( 

1088 f"\nSOURCE FILES (ground truth):\n{source_content[:4000]}\n" 

1089 if source_content.strip() else "" 

1090 ) 

1091 return dedent(f""" 

1092 You are a senior technical writer drafting a precise documentation patch. 

1093 

1094 CONTEXT: 

1095 - Project: {project.name} 

1096 - Artifact: {trace.title} (type: {trace.artifact_type.value}) 

1097 - Drift finding: {finding_summary} 

1098 - Changed source files: {source_paths} 

1099 {source_block} 

1100 CURRENT DOCUMENTATION: 

1101 {current_doc[:3000]} 

1102 

1103 TASK: 

1104 1. Write a complete updated version of the above documentation 

1105 2. Incorporate ALL changes reflected in the source files above 

1106 3. Preserve sections not affected by the change 

1107 4. Use exact function names, types, and signatures from the source code 

1108 5. End with a "## Changes Made" section listing what you updated and why 

1109 

1110 Return ONLY the updated markdown document. 

1111 """).strip() 

1112 

1113 def _suggest_sections(self, *, trace, finding) -> list[str]: 

1114 if finding is not None and finding.suggested_actions: 

1115 return finding.suggested_actions 

1116 return [ 

1117 "Refresh the artifact overview against the latest source files.", 

1118 "Update examples, endpoint names, and linked flows that appear in the trace.", 

1119 "Record the latest verification and publish metadata before merge.", 

1120 ] 

1121 

1122 def _build_citations(self, *, trace, finding) -> list[PatchCitation]: 

1123 reference = ( 

1124 finding.source_revision.ref 

1125 if finding is not None and finding.source_revision is not None 

1126 else (trace.latest_source_revision.ref if trace.latest_source_revision is not None else "WORKTREE") 

1127 ) 

1128 return [ 

1129 PatchCitation(path=path, ref=reference, note="Mapped source input for this artifact.") 

1130 for path in trace.source_paths[:5] 

1131 ] 

1132 

1133 

# Module-level singleton: one PatchGenerator shared by all callers via get_patch_generator().
_generator = PatchGenerator()

1135 

1136 

def get_patch_generator() -> PatchGenerator:
    """Return the process-wide PatchGenerator singleton."""
    return _generator