Coverage for src / documint_mcp / symbol_extractor.py: 0%

220 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 22:30 -0400

1""" 

2Symbol extraction for semantic drift detection. 

3 

4Extracts exported/public symbols from source files using language-specific parsers: 

5- Python: stdlib ast module (more reliable than tree-sitter for Python) 

6- TypeScript/JavaScript: tree-sitter-typescript 

7- Rust: tree-sitter-rust 

8- Go: tree-sitter-go 

9- Fallback: line-based heuristic regex for unsupported languages 

10 

11Only exports/public symbols are extracted (not internal/private) because docs only 

12describe the public surface area. Private symbol changes should NOT trigger drift. 

13""" 

14from __future__ import annotations 

15 

16import ast 

17import re 

18from dataclasses import dataclass, field 

19from pathlib import Path 

20from typing import Any 

21 

22import structlog 

23 

# Module-level structlog logger named after this module; extraction failures
# are reported through it with keyword context (path, language, error).
logger = structlog.get_logger(__name__)

25 

26 

27@dataclass(frozen=True) 

28class SymbolEntry: 

29 """A single exported/public symbol extracted from source code.""" 

30 

31 name: str 

32 kind: str # "fn", "class", "struct", "enum", "type", "const", "interface", "trait" 

33 params: list[str] = field(default_factory=list) 

34 return_type: str | None = None 

35 line: int = 0 

36 exported: bool = True 

37 

38 def to_lsif_compact(self) -> dict[str, Any]: 

39 """LSIF-compact JSON representation with short keys for dense packing.""" 

40 d: dict[str, Any] = {"n": self.name, "k": self.kind, "e": self.exported, "l": self.line} 

41 if self.params: 

42 d["p"] = self.params 

43 if self.return_type: 

44 d["r"] = self.return_type 

45 return d 

46 

47 def signature(self) -> str: 

48 """Normalized string signature used for hashing. Order-stable.""" 

49 params_str = ",".join(sorted(self.params)) 

50 return f"{self.kind}:{self.name}({params_str})->{self.return_type or 'void'}" 

51 

52 

class PythonExtractor:
    """Extract exported symbols from Python source using stdlib ast.

    Only top-level statements are inspected, and a name counts as public when
    it does not start with ``_``. Reported symbols:

    * functions (sync and async) -> kind "fn"
    * classes -> kind "class" (public method names stored in ``params``)
    * capitalized module-level assignments / annotated assignments, and
      Python 3.12+ ``type X = ...`` aliases -> kind "type"
    """

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return the public top-level symbols declared in *source*.

        Source that cannot be parsed yields an empty list. Besides
        ``SyntaxError`` this covers ``ValueError``, which older Python
        versions raise for source containing null bytes.
        """
        try:
            tree = ast.parse(source)
        except (SyntaxError, ValueError):
            return []

        # ast.TypeAlias (``type X = ...``) only exists on Python 3.12+.
        type_alias_cls = getattr(ast, "TypeAlias", None)

        symbols: list[SymbolEntry] = []
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if not node.name.startswith("_"):
                    symbols.append(self._from_funcdef(node))
            elif isinstance(node, ast.ClassDef):
                if not node.name.startswith("_"):
                    symbols.append(self._from_classdef(node))
            elif isinstance(node, ast.Assign):
                # ``MyType = ...``: uppercase first letter is the convention
                # for an exported type alias.
                for target in node.targets:
                    if isinstance(target, ast.Name) and self._is_public_type_name(target.id):
                        symbols.append(SymbolEntry(name=target.id, kind="type", line=node.lineno))
            elif isinstance(node, ast.AnnAssign):
                # ``MyType: TypeAlias = ...`` -- same convention as plain
                # assignment; bare declarations without a value are skipped.
                target = node.target
                if (
                    node.value is not None
                    and isinstance(target, ast.Name)
                    and self._is_public_type_name(target.id)
                ):
                    symbols.append(SymbolEntry(name=target.id, kind="type", line=node.lineno))
            elif type_alias_cls is not None and isinstance(node, type_alias_cls):
                # ``type X = ...`` is explicitly an alias, so only the
                # underscore privacy rule applies (no capitalization check).
                alias = node.name.id
                if not alias.startswith("_"):
                    symbols.append(SymbolEntry(name=alias, kind="type", line=node.lineno))

        return symbols

    @staticmethod
    def _is_public_type_name(name: str) -> bool:
        """True for names that are not underscore-private and start uppercase."""
        return not name.startswith("_") and name[:1].isupper()

    @staticmethod
    def _format_arg(arg: ast.arg) -> str:
        """Render a parameter as ``name`` or ``name:annotation``."""
        if arg.annotation is None:
            return arg.arg
        try:
            return f"{arg.arg}:{ast.unparse(arg.annotation)}"
        except Exception:
            # Unparseable annotation: keep at least the parameter name.
            return arg.arg

    def _from_funcdef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> SymbolEntry:
        """Build a "fn" entry for a top-level function.

        Parameters are listed in source order across every kind:
        positional-only (followed by a ``/`` marker), regular, ``*args`` (or
        a bare ``*`` marker when keyword-only parameters follow without one),
        keyword-only, and ``**kwargs``. Making a parameter keyword-only or
        adding a variadic therefore changes the entry. ``self``/``cls`` are
        omitted.
        """
        args = node.args
        params: list[str] = []
        for arg in args.posonlyargs:
            if arg.arg not in ("self", "cls"):
                params.append(self._format_arg(arg))
        if args.posonlyargs:
            params.append("/")
        for arg in args.args:
            if arg.arg not in ("self", "cls"):
                params.append(self._format_arg(arg))
        if args.vararg is not None:
            params.append("*" + self._format_arg(args.vararg))
        elif args.kwonlyargs:
            params.append("*")
        params.extend(self._format_arg(arg) for arg in args.kwonlyargs)
        if args.kwarg is not None:
            params.append("**" + self._format_arg(args.kwarg))

        return_type = None
        if node.returns is not None:
            try:
                return_type = ast.unparse(node.returns)
            except Exception:
                return_type = "Any"

        return SymbolEntry(
            name=node.name,
            kind="fn",
            params=params,
            return_type=return_type,
            line=node.lineno,
        )

    def _from_classdef(self, node: ast.ClassDef) -> SymbolEntry:
        """Build a "class" entry whose ``params`` lists its public methods.

        Underscore-prefixed methods are skipped, except ``__init__`` and
        ``__call__``, which are part of how callers use the class.
        """
        methods = [
            item.name
            for item in ast.iter_child_nodes(node)
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
            and (not item.name.startswith("_") or item.name in ("__init__", "__call__"))
        ]

        return SymbolEntry(
            name=node.name,
            kind="class",
            params=methods,  # methods list stored in params for class entries
            line=node.lineno,
        )

123 

124 

class TreeSitterExtractor:
    """Base class for tree-sitter based extractors.

    Provides a probe for the optional ``tree_sitter`` package. The regex
    extractors that subclass it here do not call the probe themselves;
    tree-sitter extraction is attempted separately by ``_try_tree_sitter``.
    """

    def _try_import_tree_sitter(self, language_name: str) -> bool | None:
        """Return True if ``tree_sitter`` is importable, otherwise None.

        *language_name* is used only as context in the debug log emitted
        when the package is missing.
        """
        try:
            import tree_sitter  # noqa: F401
        except ImportError:
            logger.debug("tree_sitter_not_available", language=language_name)
            return None
        return True

136 

137 

class TypeScriptExtractor(TreeSitterExtractor):
    """Extract exported symbols from TypeScript/JavaScript using regexes.

    This is the fallback used when tree-sitter is unavailable. It recognises
    ``export`` declarations of functions, classes, interfaces, type aliases,
    arrow-function constants and enums. Matching is purely textual, so
    exports inside comments or string literals are picked up as well.
    """

    # Optional ``default``/``declare``/``async`` modifiers, ``function`` or
    # ``function*``, then an optional generic list (one nesting level, e.g.
    # ``<T extends Array<U>>``) before the parameter list.
    _EXPORT_FUNC = re.compile(
        r"export\s+(?:default\s+)?(?:declare\s+)?(?:async\s+)?function(?:\s*\*\s*|\s+)(\w+)\s*"
        r"(?:<(?:[^<>]|<[^<>]*>)*>)?\s*\(([^)]*)\)\s*(?::\s*([^{;]+))?",
        re.MULTILINE,
    )
    # The lookahead rejects anonymous ``export default class extends Base``.
    _EXPORT_CLASS = re.compile(
        r"export\s+(?:default\s+)?(?:declare\s+)?(?:abstract\s+)?class\s+(?!extends\b)(\w+)",
        re.MULTILINE,
    )
    _EXPORT_INTERFACE = re.compile(r"export\s+(?:declare\s+)?interface\s+(\w+)", re.MULTILINE)
    _EXPORT_TYPE = re.compile(r"export\s+(?:declare\s+)?type\s+(\w+)", re.MULTILINE)
    # ``export const f = (...) =>``, optionally with a type annotation on the
    # constant, ``async`` and a generic list before the parameters.
    _EXPORT_CONST_FN = re.compile(
        r"export\s+const\s+(\w+)\s*(?::[^=]+)?=\s*(?:async\s+)?(?:<(?:[^<>]|<[^<>]*>)*>\s*)?"
        r"\(([^)]*)\)\s*(?::\s*([^=>{]+))?=>",
        re.MULTILINE,
    )
    _EXPORT_ENUM = re.compile(
        r"export\s+(?:declare\s+)?(?:const\s+)?enum\s+(\w+)", re.MULTILINE
    )

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return exported symbols, grouped by declaration kind.

        Order: functions, classes, interfaces, type aliases, arrow-function
        constants, enums. ``line`` is the 1-based line of the ``export``
        keyword.
        """

        def line_of(match: re.Match[str]) -> int:
            # Count newlines in place instead of slicing a copy of the prefix.
            return source.count("\n", 0, match.start()) + 1

        def split_params(raw: str) -> list[str]:
            return [p.strip() for p in raw.split(",") if p.strip()]

        symbols: list[SymbolEntry] = []

        for m in self._EXPORT_FUNC.finditer(source):
            name, raw_params, ret = m.groups()
            symbols.append(
                SymbolEntry(
                    name=name,
                    kind="fn",
                    params=split_params(raw_params),
                    return_type=ret.strip() if ret else None,
                    line=line_of(m),
                )
            )

        for pattern, kind in (
            (self._EXPORT_CLASS, "class"),
            (self._EXPORT_INTERFACE, "interface"),
            (self._EXPORT_TYPE, "type"),
        ):
            for m in pattern.finditer(source):
                symbols.append(SymbolEntry(name=m.group(1), kind=kind, line=line_of(m)))

        for m in self._EXPORT_CONST_FN.finditer(source):
            name, raw_params, ret = m.groups()
            symbols.append(
                SymbolEntry(
                    name=name,
                    kind="fn",
                    params=split_params(raw_params),
                    return_type=ret.strip() if ret else None,
                    line=line_of(m),
                )
            )

        for m in self._EXPORT_ENUM.finditer(source):
            symbols.append(SymbolEntry(name=m.group(1), kind="enum", line=line_of(m)))

        return symbols

201 

202 

class RustExtractor(TreeSitterExtractor):
    """Extract ``pub`` symbols from Rust source using regexes.

    Only items declared with plain ``pub`` are matched. Restricted visibility
    such as ``pub(crate)`` is skipped, because every pattern requires
    whitespace directly after ``pub``.
    """

    # Leading indentation is ``[ \t]*`` rather than ``\s*``: under MULTILINE,
    # ``^\s*`` also swallows preceding blank lines, so the match (and the
    # computed line number) started at the first blank line, not at the item.
    # Any sequence of const/async/unsafe/extern "ABI" qualifiers is accepted,
    # and generics may nest one level (``<T: Into<String>>``).
    _PUB_FN = re.compile(
        r'^[ \t]*pub(?:\s+(?:const|async|unsafe|extern(?:\s*"[^"]*")?))*\s+fn\s+(\w+)\s*'
        r"(?:<(?:[^<>]|<[^<>]*>)*>)?\s*\(([^)]*)\)\s*(?:->\s*([^{;]+))?",
        re.MULTILINE,
    )
    _PUB_STRUCT = re.compile(r"^[ \t]*pub\s+struct\s+(\w+)", re.MULTILINE)
    _PUB_ENUM = re.compile(r"^[ \t]*pub\s+enum\s+(\w+)", re.MULTILINE)
    _PUB_TRAIT = re.compile(r"^[ \t]*pub\s+(?:unsafe\s+)?trait\s+(\w+)", re.MULTILINE)
    _PUB_TYPE = re.compile(r"^[ \t]*pub\s+type\s+(\w+)", re.MULTILINE)
    # The lookahead keeps ``pub const fn new`` / ``pub const unsafe fn`` from
    # being recorded as a constant named "fn" / "unsafe".
    _PUB_CONST = re.compile(
        r"^[ \t]*pub\s+const\s+(?!(?:fn|unsafe|extern)\b)(\w+)", re.MULTILINE
    )
    # Receiver parameters: self, mut self, &self, &mut self, &'a self,
    # self: Box<Self>, ... -- not part of the callable surface.
    _SELF_PARAM = re.compile(r"&?(?:'\w+\s+)?(?:mut\s+)?self\b(?:\s*:.*)?", re.DOTALL)
    _WHERE = re.compile(r"\s+where\b")

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return pub functions, then structs, enums, traits, type aliases and consts.

        ``line`` is the 1-based line of the declaration.
        """
        symbols: list[SymbolEntry] = []

        for m in self._PUB_FN.finditer(source):
            name, raw_params, ret = m.groups()
            params = [
                p.strip()
                for p in raw_params.split(",")
                if p.strip() and not self._SELF_PARAM.fullmatch(p.strip())
            ]
            return_type = None
            if ret:
                # A trailing ``where`` clause constrains generics; it is not
                # part of the return type.
                return_type = self._WHERE.split(ret, maxsplit=1)[0].strip() or None
            symbols.append(
                SymbolEntry(
                    name=name,
                    kind="fn",
                    params=params[:8],  # cap at 8 params for hash stability
                    return_type=return_type,
                    line=source.count("\n", 0, m.start()) + 1,
                )
            )

        for pattern, kind in [
            (self._PUB_STRUCT, "struct"),
            (self._PUB_ENUM, "enum"),
            (self._PUB_TRAIT, "trait"),
            (self._PUB_TYPE, "type"),
            (self._PUB_CONST, "const"),
        ]:
            for m in pattern.finditer(source):
                symbols.append(
                    SymbolEntry(
                        name=m.group(1),
                        kind=kind,
                        line=source.count("\n", 0, m.start()) + 1,
                    )
                )

        return symbols

249 

250 

class GoExtractor(TreeSitterExtractor):
    """Extract exported symbols from Go source (exported = capitalized name).

    Regex fallback that only looks at declarations starting in column 0
    (gofmt layout): ``func`` (including methods), ``type``, and ``const``
    (single-line declarations and parenthesized blocks).
    """

    # Optional receiver, exported name, optional type-parameter list (one
    # nesting level of brackets), the parameters, then either a parenthesized
    # result list (group 3) or a single result type such as ``*Server``,
    # ``[]byte``, ``io.Reader`` or ``T`` (group 4). ``[ \t]*`` keeps the
    # result on the declaration line.
    _FUNC = re.compile(
        r"^func\s+(?:\([^)]+\)\s+)?([A-Z]\w*)\s*(?:\[(?:[^\[\]]|\[[^\[\]]*\])*\])?\s*"
        r"\(([^)]*)\)[ \t]*(?:\(([^)]*)\)|([\w*.\[\]]+))?",
        re.MULTILINE,
    )
    # Group 2 is the first word of the definition: ``struct``, ``interface``,
    # an underlying type name, or empty (e.g. ``[]int``, ``*T``).
    _TYPE = re.compile(
        r"^type\s+([A-Z]\w*)(?:\[(?:[^\[\]]|\[[^\[\]]*\])*\])?\s*(?:=\s*)?(\w*)",
        re.MULTILINE,
    )
    _TYPE_KINDS = {"struct": "struct", "interface": "interface"}
    _CONST = re.compile(r"^const\s+([A-Z]\w*)", re.MULTILINE)  # const Foo = ...
    _CONST_BLOCK = re.compile(r"^const\s*\((.*?)^\)", re.MULTILINE | re.DOTALL)
    # Indented exported identifier inside a const block. A continuation line
    # of a multi-line expression that starts with a capitalized identifier
    # would also match -- accepted heuristic limitation.
    _BLOCK_ITEM = re.compile(r"^[ \t]+([A-Z]\w*)", re.MULTILINE)

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return exported funcs, then types, then consts.

        ``line`` is the 1-based line of the declaration (for const blocks,
        the line of the individual item).
        """
        symbols: list[SymbolEntry] = []

        for m in self._FUNC.finditer(source):
            name = m.group(1)
            raw_params = m.group(2) or ""
            ret = m.group(3) or m.group(4)
            params = [p.strip() for p in raw_params.split(",") if p.strip()]
            symbols.append(
                SymbolEntry(
                    name=name,
                    kind="fn",
                    params=params,
                    return_type=ret.strip() if ret else None,
                    line=source.count("\n", 0, m.start()) + 1,
                )
            )

        for m in self._TYPE.finditer(source):
            symbols.append(
                SymbolEntry(
                    name=m.group(1),
                    kind=self._TYPE_KINDS.get(m.group(2), "type"),
                    line=source.count("\n", 0, m.start()) + 1,
                )
            )

        for m in self._CONST.finditer(source):
            symbols.append(
                SymbolEntry(
                    name=m.group(1),
                    kind="const",
                    line=source.count("\n", 0, m.start()) + 1,
                )
            )

        for block in self._CONST_BLOCK.finditer(source):
            body_start = block.start(1)
            for item in self._BLOCK_ITEM.finditer(block.group(1)):
                symbols.append(
                    SymbolEntry(
                        name=item.group(1),
                        kind="const",
                        line=source.count("\n", 0, body_start + item.start()) + 1,
                    )
                )

        return symbols

285 

286 

class HeuristicExtractor:
    """Fallback extractor for unsupported languages using line-based heuristics.

    Only declarations starting in column 0 are considered. Names starting
    with ``_``, and control-flow keywords that merely look like calls (e.g.
    ``if (x)``), are skipped. Results are unique by name; the first match wins.
    """

    # Every pattern exposes the symbol name as the ``name`` group. The keyword
    # pattern runs first so that ``class``/``interface`` declarations keep
    # their real kind instead of being claimed as "fn" by the call-shaped
    # generic pattern (e.g. ``class Foo(Base):``).
    _DEF_PATTERNS = [
        re.compile(
            r"^(?:export\s+)?(?P<kw>function|class|interface)\s+(?P<name>\w+)", re.MULTILINE
        ),
        re.compile(
            r"^(?:public\s+)?(?:static\s+)?(?:\w+\s+)?(?P<name>\w+)\s*\([^)]*\)", re.MULTILINE
        ),  # generic method
        re.compile(r"^def\s+(?P<name>[a-z]\w*)\s*\(", re.MULTILINE),  # Ruby/Python style
    ]
    _KIND_BY_KEYWORD = {"class": "class", "interface": "interface"}
    # Words the call-shaped pattern can capture that are never declarations.
    _NOT_SYMBOLS = frozenset(
        {
            "if", "elif", "elseif", "else", "for", "foreach", "while", "do",
            "switch", "catch", "return", "until", "unless", "with", "sizeof",
            "function",
        }
    )

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return heuristic symbols from *source*, each marked exported.

        ``kind`` is "class"/"interface" when found via that keyword, else "fn".
        """
        seen: set[str] = set()
        symbols: list[SymbolEntry] = []
        for pattern in self._DEF_PATTERNS:
            for m in pattern.finditer(source):
                name = m.group("name")
                if name in seen or name.startswith("_") or name in self._NOT_SYMBOLS:
                    continue
                seen.add(name)
                kind = self._KIND_BY_KEYWORD.get(m.groupdict().get("kw"), "fn")
                symbols.append(
                    SymbolEntry(
                        name=name,
                        kind=kind,
                        line=source.count("\n", 0, m.start()) + 1,
                    )
                )
        return symbols

310 

311 

# File extension (lower-cased, with dot) -> language key used by extract_symbols.
_LANGUAGE_BY_EXTENSION: dict[str, str] = {
    ".py": "python",
    ".pyi": "python",  # type stubs declare the same public surface
    ".ts": "typescript",
    ".tsx": "typescript",
    ".mts": "typescript",  # ES-module TypeScript
    ".cts": "typescript",  # CommonJS TypeScript
    ".js": "javascript",
    ".jsx": "javascript",
    ".mjs": "javascript",  # ES-module JavaScript
    ".cjs": "javascript",  # CommonJS JavaScript
    ".rs": "rust",
    ".go": "go",
}


def _detect_language(path: str) -> str:
    """Map *path*'s extension (case-insensitively) to a language key.

    Returns ``"unknown"`` for unrecognised or missing extensions.
    """
    return _LANGUAGE_BY_EXTENSION.get(Path(path).suffix.lower(), "unknown")

323 

324 

def _try_tree_sitter(source: str, lang: str) -> list[SymbolEntry] | None:
    """
    Attempt tree-sitter extraction for *lang*.

    Returns a list of SymbolEntry on success, or None when the caller should
    fall back to regex extraction. That happens when:

    * the optional ``tree_sitter_extractor`` module cannot be imported;
    * it reports *lang* as unsupported;
    * it returns no symbols;
    * it raises while checking support or parsing.
    """
    try:
        from .tree_sitter_extractor import extract_symbols_ts, is_language_supported
    except ImportError:
        return None

    try:
        if not is_language_supported(lang):
            return None
        raw = extract_symbols_ts(source, lang)
    except Exception as exc:
        # A grammar/parser failure must not block the documented regex fallback.
        logger.debug("tree_sitter_extraction_failed", language=lang, error=str(exc))
        return None

    if not raw:
        return None

    # Raw entries are read with the same short keys SymbolEntry.to_lsif_compact emits.
    return [
        SymbolEntry(
            name=sym.get("n", ""),
            kind=sym.get("k", "fn"),
            params=list(sym.get("p") or []),  # copy; tolerate an explicit None
            return_type=sym.get("r"),
            line=sym.get("l", 0),
            exported=sym.get("e", True),
        )
        for sym in raw
    ]

358 

359 

def extract_symbols(source: str, path: str = "", language: str = "") -> list[SymbolEntry]:
    """
    Extract exported/public symbols from source code.

    Extraction priority per language:
    - Python: ast module (stdlib) first; tree-sitter when ast fails or finds nothing
    - TypeScript/JavaScript: tree-sitter first, regex fallback
    - Rust: tree-sitter first, regex fallback
    - Go: tree-sitter first, regex fallback
    - Other: heuristic regex

    Extractor errors are logged and treated as "no symbols" rather than raised.

    Args:
        source: Source code content
        path: File path (used to detect language if not specified, and in logs)
        language: Override language detection (case-insensitive)

    Returns:
        List of SymbolEntry objects, deduplicated by name (first occurrence kept)
    """
    lang = (language or _detect_language(path)).lower()

    def run_extractor(extractor_cls: type) -> list[SymbolEntry] | None:
        # Log, never raise: one malformed file must not abort a multi-file scan.
        try:
            return extractor_cls().extract(source)
        except Exception as exc:
            logger.warning("symbol_extraction_failed", path=path, language=lang, error=str(exc))
            return None

    def run_tree_sitter() -> list[SymbolEntry] | None:
        # None means "use the fallback"; an unexpected tree-sitter error is
        # treated the same way instead of propagating to the caller.
        try:
            return _try_tree_sitter(source, lang)
        except Exception as exc:
            logger.debug("tree_sitter_extraction_failed", path=path, language=lang, error=str(exc))
            return None

    regex_fallbacks: dict[str, type] = {
        "typescript": TypeScriptExtractor,
        "javascript": TypeScriptExtractor,
        "rust": RustExtractor,
        "go": GoExtractor,
    }

    if lang == "python":
        # stdlib ast is more reliable for Python; tree-sitter only as backup.
        symbols = run_extractor(PythonExtractor) or run_tree_sitter()
    elif lang in regex_fallbacks:
        symbols = run_tree_sitter()
        if symbols is None:
            symbols = run_extractor(regex_fallbacks[lang])
    else:
        symbols = run_extractor(HeuristicExtractor)

    # Deduplicate by name (keep first occurrence -- usually the declaration).
    unique: dict[str, SymbolEntry] = {}
    for sym in symbols or []:
        unique.setdefault(sym.name, sym)
    return list(unique.values())

436 

437 

def extract_symbols_from_files(file_contents: dict[str, str]) -> list[SymbolEntry]:
    """Extract symbols from several files and merge them into one list.

    Files are processed in the mapping's iteration order, with the language
    detected from each path. When a symbol name occurs in more than one
    file, only its first occurrence is kept.

    Args:
        file_contents: Mapping of file path to source text.

    Returns:
        Merged list of SymbolEntry objects, unique by name.
    """
    # dict keeps insertion order, and setdefault never overwrites, so the
    # first entry seen for each name wins.
    merged: dict[str, SymbolEntry] = {}
    for file_path, text in file_contents.items():
        for entry in extract_symbols(text, path=file_path):
            merged.setdefault(entry.name, entry)
    return list(merged.values())