Coverage for src / documint_mcp / symbol_extractor.py: 0%
220 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 22:30 -0400
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 22:30 -0400
1"""
2Symbol extraction for semantic drift detection.
4Extracts exported/public symbols from source files using language-specific parsers:
5- Python: stdlib ast module (more reliable than tree-sitter for Python)
6- TypeScript/JavaScript: tree-sitter-typescript
7- Rust: tree-sitter-rust
8- Go: tree-sitter-go
9- Fallback: line-based heuristic regex for unsupported languages
11Only exported/public symbols are extracted (not internal/private) because docs only
12describe the public surface area. Private symbol changes should NOT trigger drift.
13"""
14from __future__ import annotations
16import ast
17import re
18from dataclasses import dataclass, field
19from pathlib import Path
20from typing import Any
22import structlog
24logger = structlog.get_logger(__name__)
@dataclass(frozen=True)
class SymbolEntry:
    """A single exported/public symbol extracted from source code.

    Instances are frozen and compare field-by-field through the
    dataclass-generated ``__eq__``.  Because ``params`` is a list, the hash
    that ``frozen=True`` would otherwise generate raises ``TypeError``; the
    explicit ``__hash__`` below hashes a tuple copy of ``params`` instead so
    entries can be stored in sets and used as dict keys.
    """

    name: str
    # One of "fn", "class", "struct", "enum", "type", "const", "interface", "trait".
    kind: str
    params: list[str] = field(default_factory=list)
    return_type: str | None = None
    line: int = 0
    exported: bool = True

    def __hash__(self) -> int:
        """Hash every field, consistent with the generated ``__eq__``."""
        return hash(
            (
                self.name,
                self.kind,
                tuple(self.params),
                self.return_type,
                self.line,
                self.exported,
            )
        )

    def to_lsif_compact(self) -> dict[str, Any]:
        """Return an LSIF-compact dict with short keys for dense packing.

        ``n``, ``k``, ``e`` and ``l`` are always present; ``p`` and ``r`` are
        only added when ``params`` / ``return_type`` are non-empty.
        """
        compact: dict[str, Any] = {
            "n": self.name,
            "k": self.kind,
            "e": self.exported,
            "l": self.line,
        }
        if self.params:
            # Hand out a copy so callers cannot mutate this frozen entry
            # through the returned dict.
            compact["p"] = list(self.params)
        if self.return_type:
            compact["r"] = self.return_type
        return compact

    def signature(self) -> str:
        """Return the normalized string signature used for hashing.

        Parameters are sorted, so the result does not depend on their order;
        a missing return type is rendered as ``void``.
        """
        params_str = ",".join(sorted(self.params))
        return f"{self.kind}:{self.name}({params_str})->{self.return_type or 'void'}"
class PythonExtractor:
    """Extract exported symbols from Python source using the stdlib ``ast`` module.

    Only the module's top-level statements are inspected.  A name is treated
    as exported when it does not start with an underscore.
    """

    # ``ast.TypeAlias`` (the ``type X = ...`` statement) only exists on
    # Python 3.12+; ``None`` disables that branch on older interpreters.
    _TYPE_ALIAS_NODE = getattr(ast, "TypeAlias", None)

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return the public module-level symbols found in *source*.

        Args:
            source: Python source text.

        Returns:
            ``fn`` entries for functions, ``class`` entries for classes and
            ``type`` entries for capitalised assignments and ``type``
            statements, in source order.  An empty list when *source*
            cannot be parsed.
        """
        try:
            tree = ast.parse(source)
        except (SyntaxError, ValueError):
            # ValueError: interpreters before 3.12 reject NUL bytes this way.
            return []

        alias_node = self._TYPE_ALIAS_NODE
        symbols: list[SymbolEntry] = []
        for node in tree.body:
            # Module-level functions
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if not node.name.startswith("_"):
                    symbols.append(self._from_funcdef(node))
            # Module-level classes
            elif isinstance(node, ast.ClassDef):
                if not node.name.startswith("_"):
                    symbols.append(self._from_classdef(node))
            # Explicit alias statement: ``type MyType = ...`` (Python 3.12+)
            elif alias_node is not None and isinstance(node, alias_node):
                alias_name = node.name.id
                if not alias_name.startswith("_"):
                    symbols.append(SymbolEntry(name=alias_name, kind="type", line=node.lineno))
            # Assignment-style aliases: ``MyType = ...`` at module level
            elif isinstance(node, ast.Assign):
                for target in node.targets:
                    if not isinstance(target, ast.Name) or target.id.startswith("_"):
                        continue
                    if target.id[0].isupper():  # convention: uppercase = exported type alias
                        symbols.append(SymbolEntry(name=target.id, kind="type", line=node.lineno))

        return symbols

    @staticmethod
    def _format_param(arg: ast.arg, prefix: str = "") -> str:
        """Render one parameter as ``name`` or ``name:annotation``.

        *prefix* carries the ``*`` / ``**`` marker of variadic parameters.
        """
        label = f"{prefix}{arg.arg}"
        if arg.annotation is None:
            return label
        try:
            return f"{label}:{ast.unparse(arg.annotation)}"
        except Exception:
            # Best effort: keep the bare name if the annotation cannot be rendered.
            return label

    @classmethod
    def _collect_params(cls, args: ast.arguments) -> list[str]:
        """List every parameter of a signature in declaration order.

        Covers positional-only, regular, ``*args``, keyword-only and
        ``**kwargs`` parameters.  Positional parameters named ``self`` or
        ``cls`` are skipped.
        """
        params = [
            cls._format_param(arg)
            for arg in (*args.posonlyargs, *args.args)
            if arg.arg not in ("self", "cls")
        ]
        if args.vararg is not None:
            params.append(cls._format_param(args.vararg, "*"))
        params.extend(cls._format_param(arg) for arg in args.kwonlyargs)
        if args.kwarg is not None:
            params.append(cls._format_param(args.kwarg, "**"))
        return params

    def _from_funcdef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> SymbolEntry:
        """Build the ``fn`` entry for a (possibly async) function definition."""
        return_type = None
        if node.returns:
            try:
                return_type = ast.unparse(node.returns)
            except Exception:
                # The annotation exists but cannot be rendered; record a placeholder.
                return_type = "Any"

        return SymbolEntry(
            name=node.name,
            kind="fn",
            params=self._collect_params(node.args),
            return_type=return_type,
            line=node.lineno,
        )

    def _from_classdef(self, node: ast.ClassDef) -> SymbolEntry:
        """Build the ``class`` entry for a class definition.

        The names of the class's public methods, plus ``__init__`` and
        ``__call__`` when defined, are stored in ``params``.
        """
        methods = [
            item.name
            for item in node.body
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
            and (not item.name.startswith("_") or item.name in ("__init__", "__call__"))
        ]

        return SymbolEntry(
            name=node.name,
            kind="class",
            params=methods,  # methods list stored in params for class entries
            line=node.lineno,
        )
class TreeSitterExtractor:
    """Common parent of the tree-sitter based extractors."""

    def _try_import_tree_sitter(self, language_name: str) -> Any | None:
        """Probe whether the ``tree_sitter`` package can be imported.

        Returns ``True`` when the import succeeds.  Otherwise a debug event
        tagged with *language_name* is logged and ``None`` is returned.
        """
        available: Any | None = None
        try:
            import tree_sitter  # noqa: F401

            available = True
        except ImportError:
            logger.debug("tree_sitter_not_available", language=language_name)
        return available
class TypeScriptExtractor(TreeSitterExtractor):
    """Extract exported symbols from TypeScript/JavaScript source using regexes.

    This class implements the regex path only; it does not call tree-sitter.
    Parameter lists are split on commas without regard to nesting, so a
    parameter whose type itself contains a comma is reported as several
    entries.
    """

    # Regex fallback patterns for when tree-sitter is not available.
    # ``export [default] [async] function name[<T>](params)[: ret]``
    _EXPORT_FUNC = re.compile(
        r"export\s+(?:default\s+)?(?:async\s+)?function\s+(\w+)\s*(?:<[^>]*>)?\s*"
        r"\(([^)]*)\)\s*(?::\s*([^{;]+))?",
        re.MULTILINE,
    )
    _EXPORT_CLASS = re.compile(
        r"export\s+(?:default\s+)?(?:abstract\s+)?class\s+(\w+)", re.MULTILINE
    )
    _EXPORT_INTERFACE = re.compile(r"export\s+interface\s+(\w+)", re.MULTILINE)
    _EXPORT_TYPE = re.compile(r"export\s+type\s+(\w+)", re.MULTILINE)
    # ``export const name = [async] [<T>](params)[: ret] =>``
    # The return type runs up to the arrow itself, so generic return types
    # such as ``Promise<void>`` (which contain ``>``) are accepted.
    _EXPORT_CONST_FN = re.compile(
        r"export\s+const\s+(\w+)\s*=\s*(?:async\s+)?(?:<[^>]*>\s*)?"
        r"\(([^)]*)\)\s*(?::\s*((?:(?!=>)[^{;])+))?=>",
        re.MULTILINE,
    )

    def _callable_entries(self, pattern: re.Pattern[str], source: str) -> list[SymbolEntry]:
        """Build ``fn`` entries from a pattern capturing (name, params, return type)."""
        entries: list[SymbolEntry] = []
        for m in pattern.finditer(source):
            name, raw_params, ret = m.group(1, 2, 3)
            entries.append(
                SymbolEntry(
                    name=name,
                    kind="fn",
                    params=[p.strip() for p in raw_params.split(",") if p.strip()],
                    return_type=ret.strip() if ret else None,
                    # 1-based line of the ``export`` keyword; counting in
                    # place avoids copying the source prefix for every match.
                    line=source.count("\n", 0, m.start()) + 1,
                )
            )
        return entries

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return the exported symbols found in *source*.

        Entries are grouped by pattern rather than by position: function
        declarations first, then classes, interfaces, type aliases and
        finally ``const`` arrow functions.
        """
        symbols = self._callable_entries(self._EXPORT_FUNC, source)

        for pattern, kind in (
            (self._EXPORT_CLASS, "class"),
            (self._EXPORT_INTERFACE, "interface"),
            (self._EXPORT_TYPE, "type"),
        ):
            for m in pattern.finditer(source):
                symbols.append(
                    SymbolEntry(
                        name=m.group(1),
                        kind=kind,
                        line=source.count("\n", 0, m.start()) + 1,
                    )
                )

        symbols.extend(self._callable_entries(self._EXPORT_CONST_FN, source))
        return symbols
class RustExtractor(TreeSitterExtractor):
    """Extract ``pub`` items from Rust source using regexes.

    Only items introduced by a bare ``pub`` are matched.  Parameter lists are
    split on commas without regard to nesting.
    """

    # Patterns anchor on ``^[ \t]*`` rather than ``^\s*`` so that a match
    # starts on the item's own line; ``\s*`` would also swallow preceding
    # blank lines and shift the reported line number.
    #
    # Any run of ``const`` / ``async`` / ``unsafe`` / ``extern "ABI"``
    # qualifiers may sit between ``pub`` and ``fn``.
    _PUB_FN = re.compile(
        r'^[ \t]*pub(?:\s+(?:const|async|unsafe|extern(?:\s*"[^"]*")?))*\s+fn\s+(\w+)'
        r"\s*(?:<[^>]*>)?\s*\(([^)]*)\)\s*(?:->\s*([^{;]+))?",
        re.MULTILINE,
    )
    _PUB_STRUCT = re.compile(r"^[ \t]*pub\s+struct\s+(\w+)", re.MULTILINE)
    _PUB_ENUM = re.compile(r"^[ \t]*pub\s+enum\s+(\w+)", re.MULTILINE)
    _PUB_TRAIT = re.compile(r"^[ \t]*pub\s+trait\s+(\w+)", re.MULTILINE)
    _PUB_TYPE = re.compile(r"^[ \t]*pub\s+type\s+(\w+)", re.MULTILINE)
    # The lookahead keeps ``pub const fn name`` (and ``pub const unsafe fn``
    # etc.) from being reported as a constant named after the keyword.
    _PUB_CONST = re.compile(
        r"^[ \t]*pub\s+const\s+(?!(?:fn|unsafe|async|extern)\b)(\w+)", re.MULTILINE
    )
    # Method receivers: ``self``, ``mut self``, ``&self``, ``&mut self``,
    # ``&'a self``, ``&'a mut self`` and typed forms such as ``self: Box<Self>``.
    _RECEIVER = re.compile(r"^(?:&\s*(?:'\w+\s+)?)?(?:mut\s+)?self\b")

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return the ``pub`` symbols found in *source*.

        Functions come first, followed by structs, enums, traits, type
        aliases and constants.  Receivers are dropped from function
        parameters.
        """
        symbols: list[SymbolEntry] = []

        for m in self._PUB_FN.finditer(source):
            name, raw_params, ret = m.group(1, 2, 3)
            params = [
                piece
                for piece in (p.strip() for p in raw_params.split(","))
                if piece and not self._RECEIVER.match(piece)
            ]
            symbols.append(
                SymbolEntry(
                    name=name,
                    kind="fn",
                    params=params[:8],  # cap at 8 params for hash stability
                    return_type=ret.strip() if ret else None,
                    line=source.count("\n", 0, m.start()) + 1,
                )
            )

        for pattern, kind in (
            (self._PUB_STRUCT, "struct"),
            (self._PUB_ENUM, "enum"),
            (self._PUB_TRAIT, "trait"),
            (self._PUB_TYPE, "type"),
            (self._PUB_CONST, "const"),
        ):
            for m in pattern.finditer(source):
                symbols.append(
                    SymbolEntry(
                        name=m.group(1),
                        kind=kind,
                        line=source.count("\n", 0, m.start()) + 1,
                    )
                )

        return symbols
class GoExtractor(TreeSitterExtractor):
    """Extract exported symbols from Go source (exported = capitalized name).

    Only declarations that start in column 0 are matched.
    """

    # ``func [(recv)] Name(params) [(results) | result]``.  The result must
    # follow on the same line (``[ \t]*``) so that a body-less declaration
    # does not pick up the first word of the next line as its return type.
    # A single result may be a pointer, slice/map or package-qualified type.
    _FUNC = re.compile(
        r"^func\s+(?:\([^)]+\)\s+)?([A-Z]\w*)\s*\(([^)]*)\)[ \t]*"
        r"(?:\(([^)]*)\)|([\w*\[\].]+))?",
        re.MULTILINE,
    )
    # Group 2 is set only for ``struct`` / ``interface`` declarations.
    _TYPE = re.compile(r"^type\s+([A-Z]\w*)\s+(?:(struct|interface)\b|\w)", re.MULTILINE)
    # NOTE(review): extract() below does not use this pattern.
    _CONST = re.compile(r"^\s+([A-Z]\w*)\s*(?:=|\w)", re.MULTILINE)  # exported const in const block

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return the exported functions and types found in *source*.

        Functions (including methods with a receiver) are listed first with
        kind ``fn``.  Type declarations follow with kind ``struct`` or
        ``interface`` when declared as such, and ``type`` otherwise.
        """
        symbols: list[SymbolEntry] = []

        for m in self._FUNC.finditer(source):
            raw_params = m.group(2) or ""
            # Group 3: parenthesised result list; group 4: single bare result.
            ret = m.group(3) or m.group(4)
            symbols.append(
                SymbolEntry(
                    name=m.group(1),
                    kind="fn",
                    params=[p.strip() for p in raw_params.split(",") if p.strip()],
                    return_type=ret.strip() if ret else None,
                    line=source.count("\n", 0, m.start()) + 1,
                )
            )

        for m in self._TYPE.finditer(source):
            symbols.append(
                SymbolEntry(
                    name=m.group(1),
                    kind=m.group(2) or "type",
                    line=source.count("\n", 0, m.start()) + 1,
                )
            )

        return symbols
class HeuristicExtractor:
    """Fallback extractor for unsupported languages using line-based heuristics.

    Every pattern is anchored at the start of a line, so only definitions
    beginning in column 0 are considered.  All results have kind ``fn``.
    """

    _DEF_PATTERNS = [
        re.compile(
            r"^(?:public\s+)?(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)", re.MULTILINE
        ),  # generic method
        re.compile(r"^def\s+([a-z]\w*)\s*\(", re.MULTILINE),  # Ruby/Python style
        re.compile(r"^(?:export\s+)?(?:function|class|interface)\s+(\w+)", re.MULTILINE),
    ]

    # The generic-method pattern also matches statements such as ``if (x)``
    # or ``return (y)``; these words are never symbol names.
    _CONTROL_KEYWORDS = frozenset(
        {
            "if",
            "elif",
            "else",
            "for",
            "foreach",
            "while",
            "until",
            "unless",
            "switch",
            "case",
            "catch",
            "return",
            "throw",
        }
    )

    def extract(self, source: str) -> list[SymbolEntry]:
        """Return one ``fn`` entry per distinct name matched in *source*.

        Names starting with an underscore and control-flow keywords are
        skipped.  When several patterns match the same name, the first
        match (in pattern order) wins.
        """
        seen: set[str] = set()
        symbols: list[SymbolEntry] = []
        for pattern in self._DEF_PATTERNS:
            for m in pattern.finditer(source):
                name = m.group(1)
                if name in seen or name.startswith("_") or name in self._CONTROL_KEYWORDS:
                    continue
                seen.add(name)
                symbols.append(
                    SymbolEntry(name=name, kind="fn", line=source.count("\n", 0, m.start()) + 1)
                )
        return symbols
312def _detect_language(path: str) -> str:
313 ext = Path(path).suffix.lower()
314 return {
315 ".py": "python",
316 ".ts": "typescript",
317 ".tsx": "typescript",
318 ".js": "javascript",
319 ".jsx": "javascript",
320 ".rs": "rust",
321 ".go": "go",
322 }.get(ext, "unknown")
325def _try_tree_sitter(source: str, lang: str) -> list[SymbolEntry] | None:
326 """
327 Attempt tree-sitter extraction for *lang*.
329 Returns a list of SymbolEntry on success, or None if tree-sitter is
330 unavailable or the language grammar is missing (so the caller should
331 fall back to regex).
332 """
333 try:
334 from .tree_sitter_extractor import extract_symbols_ts, is_language_supported
335 except ImportError:
336 return None
338 if not is_language_supported(lang):
339 return None
341 raw = extract_symbols_ts(source, lang)
342 if not raw:
343 return None
345 entries: list[SymbolEntry] = []
346 for sym in raw:
347 entries.append(
348 SymbolEntry(
349 name=sym.get("n", ""),
350 kind=sym.get("k", "fn"),
351 params=sym.get("p", []),
352 return_type=sym.get("r"),
353 line=sym.get("l", 0),
354 exported=sym.get("e", True),
355 )
356 )
357 return entries
def extract_symbols(source: str, path: str = "", language: str = "") -> list[SymbolEntry]:
    """
    Extract exported/public symbols from source code.

    Extraction priority per language:
    - Python: ast module (stdlib) first, tree-sitter second
    - TypeScript/JavaScript: tree-sitter first, regex fallback
    - Rust: tree-sitter first, regex fallback
    - Go: tree-sitter first, regex fallback
    - Other: heuristic regex

    Failures of the ast/regex extractors are logged as
    ``symbol_extraction_failed`` and treated as "no symbols".

    Args:
        source: Source code content
        path: File path (used to detect language if not specified)
        language: Override language detection (matched case-insensitively)

    Returns:
        List of SymbolEntry objects, deduplicated by name
    """
    lang = language.lower() if language else _detect_language(path)

    def run_logged(extractor_cls: Any) -> list[SymbolEntry] | None:
        """Run one extractor; log and return None if it raises."""
        try:
            return extractor_cls().extract(source)
        except Exception as exc:
            logger.warning("symbol_extraction_failed", path=path, language=lang, error=str(exc))
            return None

    # Languages that try tree-sitter first and fall back to a regex extractor.
    regex_fallbacks = {
        "typescript": TypeScriptExtractor,
        "javascript": TypeScriptExtractor,
        "rust": RustExtractor,
        "go": GoExtractor,
    }

    symbols: list[SymbolEntry] | None
    if lang == "python":
        # Python: prefer stdlib ast (more reliable for Python specifically),
        # fall back to tree-sitter if ast fails or finds nothing.
        symbols = run_logged(PythonExtractor)
        if not symbols:
            symbols = _try_tree_sitter(source, lang)
    elif lang in regex_fallbacks:
        symbols = _try_tree_sitter(source, lang)
        if symbols is None:
            symbols = run_logged(regex_fallbacks[lang])
    else:
        symbols = run_logged(HeuristicExtractor)

    # Deduplicate by name (keep first occurrence — usually the declaration).
    # dict preserves insertion order, and setdefault never overwrites.
    first_by_name: dict[str, SymbolEntry] = {}
    for entry in symbols or []:
        first_by_name.setdefault(entry.name, entry)
    return list(first_by_name.values())
def extract_symbols_from_files(file_contents: dict[str, str]) -> list[SymbolEntry]:
    """Extract symbols from several files and merge them into one list.

    *file_contents* maps a path to that file's source text.  Files are
    processed in the dict's iteration order; when the same symbol name shows
    up more than once, only the first entry encountered is kept.
    """
    merged: dict[str, SymbolEntry] = {}
    for file_path, text in file_contents.items():
        for entry in extract_symbols(text, path=file_path):
            # setdefault leaves an already-recorded name untouched.
            merged.setdefault(entry.name, entry)
    return list(merged.values())