Coverage for src / documint_mcp / tree_sitter_extractor.py: 0%
368 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 22:30 -0400
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 22:30 -0400
"""
Tree-sitter based symbol extraction for multi-language support.

Uses py-tree-sitter (>=0.23) with per-language grammar packages to parse
source code into a concrete syntax tree and extract exported/public symbols.

This module is OPTIONAL. If tree-sitter or the grammar for the requested
language is not installed, extraction returns an empty list and a debug
message is logged. The caller should fall back to regex-based extraction.

Supported languages:
- Python (tree-sitter-python)
- TypeScript / JavaScript (tree-sitter-typescript, tree-sitter-javascript)
- Rust (tree-sitter-rust) [used only when the grammar package is installed]
- Go (tree-sitter-go) [used only when the grammar package is installed]
"""
17from __future__ import annotations
19from typing import Any
21import structlog
23logger = structlog.get_logger(__name__)
25# ---------------------------------------------------------------------------
26# Lazy availability detection
27# ---------------------------------------------------------------------------
_TS_AVAILABLE: bool | None = None  # tri-state cache: None until the first probe


def _check_tree_sitter() -> bool:
    """Report whether the core ``tree_sitter`` package can be imported.

    The import is attempted only while the module-level ``_TS_AVAILABLE``
    flag is still None; the outcome is stored in that flag and returned
    directly on later calls. A failed import is logged at debug level.
    """
    global _TS_AVAILABLE
    if _TS_AVAILABLE is not None:
        return _TS_AVAILABLE

    try:
        import tree_sitter  # noqa: F401
    except ImportError:
        _TS_AVAILABLE = False
        logger.debug("tree_sitter_not_installed")
    else:
        _TS_AVAILABLE = True
    return _TS_AVAILABLE
46def _get_language(lang: str) -> Any | None:
47 """
48 Return a tree-sitter Language object for the given language name.
50 Uses the per-language grammar packages (tree-sitter-python, etc.)
51 which expose a ``language()`` callable returning the Language pointer.
52 Returns None if the grammar package is not installed.
53 """
54 try:
55 if lang == "python":
56 import tree_sitter_python as tspython
58 return tspython.language()
59 elif lang in ("typescript", "tsx"):
60 import tree_sitter_typescript as tstypescript
62 if lang == "tsx":
63 return tstypescript.language_tsx()
64 return tstypescript.language_typescript()
65 elif lang == "javascript":
66 import tree_sitter_javascript as tsjavascript
68 return tsjavascript.language()
69 elif lang == "rust":
70 import tree_sitter_rust as tsrust # noqa: F401
72 return tsrust.language()
73 elif lang == "go":
74 import tree_sitter_go as tsgo # noqa: F401
76 return tsgo.language()
77 except ImportError:
78 logger.debug("tree_sitter_grammar_not_installed", language=lang)
79 except Exception as exc:
80 logger.debug("tree_sitter_grammar_load_error", language=lang, error=str(exc))
81 return None
84# ---------------------------------------------------------------------------
85# Node text helper
86# ---------------------------------------------------------------------------
89def _node_text(node: Any, source_bytes: bytes) -> str:
90 """Extract the text content of a tree-sitter node."""
91 return source_bytes[node.start_byte : node.end_byte].decode("utf-8", errors="replace")
94def _node_line(node: Any) -> int:
95 """Return the 1-based line number for a node."""
96 return node.start_point[0] + 1
99# ---------------------------------------------------------------------------
100# Language-specific symbol walkers
101# ---------------------------------------------------------------------------
def _extract_python_symbols(root: Any, source_bytes: bytes) -> list[dict[str, Any]]:
    """Collect the public, module-level symbols of a parsed Python file.

    Only the direct children of *root* are inspected, and names starting
    with an underscore are skipped. Each entry is a dict:

    - functions: ``n`` (name), ``k`` = ``"fn"``, ``p`` (parameters),
      ``r`` (return annotation or None), ``l`` (line)
    - classes: ``n``, ``k`` = ``"class"``, ``p`` (method names), ``l``
    - ``type X = ...`` statements and assignments to a bare identifier
      whose name starts with an uppercase letter: ``n``, ``k`` = ``"type"``,
      ``l``
    """
    found: list[dict[str, Any]] = []

    for top in root.children:
        # A decorated definition is represented by the def/class it wraps.
        target = top
        if top.type == "decorated_definition":
            target = next(
                (
                    inner
                    for inner in top.children
                    if inner.type in ("function_definition", "class_definition")
                ),
                top,
            )
        node_type = target.type

        if node_type == "expression_statement":
            # Heuristic for old-style aliases: ``MyType = ...``
            for stmt in target.children:
                if stmt.type != "assignment":
                    continue
                lhs = stmt.child_by_field_name("left")
                if not lhs or lhs.type != "identifier":
                    continue
                alias = _node_text(lhs, source_bytes)
                if alias[0:1].isupper() and not alias.startswith("_"):
                    found.append({"n": alias, "k": "type", "l": _node_line(target)})
            continue

        if node_type not in (
            "function_definition",
            "class_definition",
            "type_alias_statement",
        ):
            continue

        ident = target.child_by_field_name("name")
        if not ident:
            continue
        name = _node_text(ident, source_bytes)
        if name.startswith("_"):
            continue

        if node_type == "function_definition":
            found.append({
                "n": name,
                "k": "fn",
                "p": _extract_python_params(target, source_bytes),
                "r": _extract_python_return(target, source_bytes),
                "l": _node_line(target),
            })
        elif node_type == "class_definition":
            found.append({
                "n": name,
                "k": "class",
                "p": _extract_python_class_methods(target, source_bytes),
                "l": _node_line(target),
            })
        else:
            # Python 3.12+ ``type MyType = ...`` statement
            found.append({"n": name, "k": "type", "l": _node_line(target)})

    return found
def _extract_python_params(func_node: Any, source_bytes: bytes) -> list[str]:
    """List a Python function's parameters, leaving out ``self`` and ``cls``.

    Bare identifiers yield their name. ``typed_parameter``,
    ``default_parameter`` and ``typed_default_parameter`` nodes yield
    ``name:annotation`` when the node has a ``type`` field and the bare
    name otherwise. Any other child of the parameter list is ignored.
    """
    receiver_names = ("self", "cls")
    annotated_kinds = ("typed_parameter", "default_parameter", "typed_default_parameter")

    param_list = func_node.child_by_field_name("parameters")
    if not param_list:
        return []

    collected: list[str] = []
    for item in param_list.children:
        if item.type == "identifier":
            label = _node_text(item, source_bytes)
            if label not in receiver_names:
                collected.append(label)
            continue
        if item.type not in annotated_kinds:
            continue

        # Not every node kind exposes a "name" field; fall back to the
        # first identifier child.
        ident = item.child_by_field_name("name") or next(
            (sub for sub in item.children if sub.type == "identifier"),
            None,
        )
        if not ident:
            continue
        label = _node_text(ident, source_bytes)
        if label in receiver_names:
            continue

        annotation = item.child_by_field_name("type")
        if annotation:
            collected.append(f"{label}:{_node_text(annotation, source_bytes)}")
        else:
            collected.append(label)
    return collected
def _extract_python_return(func_node: Any, source_bytes: bytes) -> str | None:
    """Return the source text of the function's ``return_type`` field, or None if it has none."""
    annotation = func_node.child_by_field_name("return_type")
    return _node_text(annotation, source_bytes) if annotation else None
def _extract_python_class_methods(class_node: Any, source_bytes: bytes) -> list[str]:
    """List the method names a Python class exposes.

    Looks at the direct children of the class ``body`` (unwrapping decorated
    functions). A method is listed when its name does not start with an
    underscore, or when it is ``__init__`` or ``__call__``.
    """
    always_listed = ("__init__", "__call__")

    class_body = class_node.child_by_field_name("body")
    if not class_body:
        return []

    names: list[str] = []
    for member in class_body.children:
        definition = member
        if member.type == "decorated_definition":
            definition = next(
                (inner for inner in member.children if inner.type == "function_definition"),
                member,
            )
        if definition.type != "function_definition":
            continue
        ident = definition.child_by_field_name("name")
        if not ident:
            continue
        method = _node_text(ident, source_bytes)
        if method in always_listed or not method.startswith("_"):
            names.append(method)
    return names
237# ---------------------------------------------------------------------------
238# TypeScript / JavaScript
239# ---------------------------------------------------------------------------
def _extract_typescript_symbols(root: Any, source_bytes: bytes) -> list[dict[str, Any]]:
    """Collect exported symbols from a TypeScript/JavaScript CST.

    Only ``export_statement`` nodes that are direct children of *root* are
    walked. Top-level declarations without an export statement around them
    are not exported and are therefore ignored.
    """
    exported: list[dict[str, Any]] = []
    export_statements = (stmt for stmt in root.children if stmt.type == "export_statement")
    for stmt in export_statements:
        _walk_ts_export(stmt, source_bytes, exported)
    return exported
def _walk_ts_export(export_node: Any, source_bytes: bytes, out: list[dict[str, Any]]) -> None:
    """Append one symbol dict to *out* per declaration inside an export_statement.

    Handled children of *export_node*:

    - function declarations/signatures -> kind ``"fn"`` with params and return type
    - (abstract) class, interface, type alias and enum declarations ->
      kind ``"class"`` / ``"interface"`` / ``"type"`` / ``"enum"``
    - ``lexical_declaration``: each declarator with both a name and a value
      becomes ``"fn"`` when the value is an arrow function or function
      expression, otherwise ``"const"``

    Declarations without a ``name`` field value are skipped.
    """
    # Declarations that are reported with just name, kind and line.
    name_only_kinds = {
        "class_declaration": "class",
        "abstract_class_declaration": "class",
        "interface_declaration": "interface",
        "type_alias_declaration": "type",
        "enum_declaration": "enum",
    }
    function_values = ("arrow_function", "function_expression")

    for decl in export_node.children:
        node_type = decl.type

        if node_type in name_only_kinds:
            ident = decl.child_by_field_name("name")
            if ident:
                out.append({
                    "n": _node_text(ident, source_bytes),
                    "k": name_only_kinds[node_type],
                    "l": _node_line(decl),
                })
            continue

        if node_type in ("function_declaration", "function_signature"):
            ident = decl.child_by_field_name("name")
            if ident:
                out.append({
                    "n": _node_text(ident, source_bytes),
                    "k": "fn",
                    "p": _extract_ts_params(decl, source_bytes),
                    "r": _extract_ts_return(decl, source_bytes),
                    "l": _node_line(decl),
                })
            continue

        if node_type != "lexical_declaration":
            continue

        # export const myFn = (...) => { ... }
        for binding in decl.children:
            if binding.type != "variable_declarator":
                continue
            target = binding.child_by_field_name("name")
            initializer = binding.child_by_field_name("value")
            if not (target and initializer):
                continue

            entry: dict[str, Any] = {"n": _node_text(target, source_bytes)}
            if initializer.type in function_values:
                entry["k"] = "fn"
                entry["p"] = _extract_ts_params(initializer, source_bytes)
                entry["r"] = _extract_ts_return(initializer, source_bytes)
            else:
                entry["k"] = "const"
            entry["l"] = _node_line(binding)
            out.append(entry)
def _extract_ts_params(node: Any, source_bytes: bytes) -> list[str]:
    """Render the parameters of a TS/JS function-like node as strings.

    The parameter list is the node's ``parameters`` field or, failing that,
    its first ``formal_parameters`` child. Required, optional and rest
    parameters yield ``name:type`` when a ``type`` field is present (leading
    colons and spaces removed from the annotation text) and the bare name
    otherwise; plain identifiers yield their text. Other children are ignored.
    """
    typed_kinds = ("required_parameter", "optional_parameter", "rest_parameter")

    param_list = node.child_by_field_name("parameters") or next(
        (part for part in node.children if part.type == "formal_parameters"),
        None,
    )
    if not param_list:
        return []

    rendered: list[str] = []
    for item in param_list.children:
        if item.type == "identifier":
            rendered.append(_node_text(item, source_bytes))
            continue
        if item.type not in typed_kinds:
            continue

        # Prefer the "pattern" field; otherwise use the first identifier child.
        binding = item.child_by_field_name("pattern")
        if not binding:
            binding = next((sub for sub in item.children if sub.type == "identifier"), None)
            if binding is None:
                continue
        label = _node_text(binding, source_bytes)

        annotation = item.child_by_field_name("type")
        if annotation:
            type_text = _node_text(annotation, source_bytes).lstrip(': ')
            rendered.append(f"{label}:{type_text}")
        else:
            rendered.append(label)
    return rendered
def _extract_ts_return(node: Any, source_bytes: bytes) -> str | None:
    """Return the text of the node's ``return_type`` field without leading colons/spaces.

    Yields None when the field is absent or nothing is left after stripping.
    """
    annotation = node.child_by_field_name("return_type")
    if not annotation:
        return None
    stripped = _node_text(annotation, source_bytes).lstrip(": ")
    return stripped or None
390# ---------------------------------------------------------------------------
391# Rust
392# ---------------------------------------------------------------------------
def _extract_rust_symbols(root: Any, source_bytes: bytes) -> list[dict[str, Any]]:
    """Collect ``pub`` symbols from a Rust CST by walking it from *root*."""
    collected: list[dict[str, Any]] = []
    _walk_rust_items(root, source_bytes, collected)
    return collected
def _walk_rust_items(node: Any, source_bytes: bytes, out: list[dict[str, Any]]) -> None:
    """Append a symbol dict to *out* for each ``pub`` item among *node*'s children.

    Children without a visibility modifier are skipped, except that
    ``impl_item`` and ``declaration_list`` nodes are descended into so that
    pub methods inside impl blocks are found. Functions are reported with
    kind ``"fn"``, at most their first eight parameters, and their return
    type; structs, enums, traits, type items and consts are reported with
    name, kind and line only.
    """
    # Items reported with just name, kind and line.
    name_only_kinds = {
        "struct_item": "struct",
        "enum_item": "enum",
        "trait_item": "trait",
        "type_item": "type",
        "const_item": "const",
    }

    for item in node.children:
        item_type = item.type

        if not _is_rust_pub(item):
            # impl blocks and their bodies carry no visibility themselves.
            if item_type in ("impl_item", "declaration_list"):
                _walk_rust_items(item, source_bytes, out)
            continue

        if item_type != "function_item" and item_type not in name_only_kinds:
            continue
        ident = item.child_by_field_name("name")
        if not ident:
            continue

        if item_type == "function_item":
            signature = _extract_rust_params(item, source_bytes)
            returns = _extract_rust_return(item, source_bytes)
            out.append({
                "n": _node_text(ident, source_bytes),
                "k": "fn",
                "p": signature[:8],
                "r": returns,
                "l": _node_line(item),
            })
        else:
            out.append({
                "n": _node_text(ident, source_bytes),
                "k": name_only_kinds[item_type],
                "l": _node_line(item),
            })
473def _is_rust_pub(node: Any) -> bool:
474 """Check if a Rust item has a visibility_modifier child indicating pub."""
475 for child in node.children:
476 if child.type == "visibility_modifier":
477 return True
478 return False
def _extract_rust_params(func_node: Any, source_bytes: bytes) -> list[str]:
    """Render the parameters of a Rust ``function_item`` as strings.

    Only ``parameter`` children of the ``parameters`` field are considered
    (so ``self_parameter`` nodes are left out), and a parameter whose pattern
    text is a ``self`` receiver form is dropped as well. Each remaining
    parameter yields ``pattern:type`` when it has a ``type`` field and the
    bare pattern text otherwise.
    """
    receiver_forms = ("self", "&self", "&mut self", "mut self")

    param_list = func_node.child_by_field_name("parameters")
    if not param_list:
        return []

    rendered: list[str] = []
    for item in param_list.children:
        if item.type != "parameter":
            continue
        binding = item.child_by_field_name("pattern")
        if not binding:
            continue
        label = _node_text(binding, source_bytes)
        if label in receiver_forms:
            continue
        annotation = item.child_by_field_name("type")
        if annotation:
            rendered.append(f"{label}:{_node_text(annotation, source_bytes)}")
        else:
            rendered.append(label)
    return rendered
def _extract_rust_return(func_node: Any, source_bytes: bytes) -> str | None:
    """Return the cleaned text of a Rust function's ``return_type`` field.

    Leading ``-``, ``>`` and space characters are removed and the result is
    whitespace-stripped. Yields None when the field is absent or the cleaned
    text is empty.
    """
    annotation = func_node.child_by_field_name("return_type")
    if not annotation:
        return None
    cleaned = _node_text(annotation, source_bytes).lstrip("-> ").strip()
    return cleaned or None
514# ---------------------------------------------------------------------------
515# Go
516# ---------------------------------------------------------------------------
def _extract_go_symbols(root: Any, source_bytes: bytes) -> list[dict[str, Any]]:
    """Extract exported (capitalized) symbols from a Go CST.

    Only direct children of *root* are inspected, and a symbol is kept only
    when its name starts with an uppercase character. Each entry is a dict:

    - functions and methods: ``n`` (name), ``k`` = ``"fn"``, ``p``
      (parameters), ``r`` (result text or None), ``l`` (line)
    - type specs: ``n``, ``k``, ``l`` where ``k`` is ``"struct"`` for a
      ``struct_type`` body, ``"interface"`` for an ``interface_type`` body
      and ``"type"`` for anything else (e.g. ``type ID string``)
    - const and var specs: ``n``, ``k`` = ``"const"``, ``l``; package-level
      variables are deliberately reported with the same kind as constants
    """
    symbols: list[dict[str, Any]] = []

    for node in root.children:
        if node.type in ("function_declaration", "method_declaration"):
            name_node = node.child_by_field_name("name")
            if not name_node:
                continue
            name = _node_text(name_node, source_bytes)
            if name[0:1].isupper():
                symbols.append({
                    "n": name,
                    "k": "fn",
                    "p": _extract_go_params(node, source_bytes),
                    "r": _extract_go_return(node, source_bytes),
                    "l": _node_line(node),
                })

        elif node.type == "type_declaration":
            for spec in node.children:
                if spec.type != "type_spec":
                    continue
                name_node = spec.child_by_field_name("name")
                if not name_node:
                    continue
                name = _node_text(name_node, source_bytes)
                if not name[0:1].isupper():
                    continue
                # Classify by the declared body. Anything that is neither a
                # struct nor an interface is a plain named type, not a struct.
                type_node = spec.child_by_field_name("type")
                body_type = type_node.type if type_node else None
                if body_type == "struct_type":
                    kind = "struct"
                elif body_type == "interface_type":
                    kind = "interface"
                else:
                    kind = "type"
                symbols.append({
                    "n": name,
                    "k": kind,
                    "l": _node_line(spec),
                })

        elif node.type in ("const_declaration", "var_declaration"):
            spec_type = "const_spec" if node.type == "const_declaration" else "var_spec"
            for spec in node.children:
                if spec.type != spec_type:
                    continue
                # NOTE(review): only the first name of a multi-name spec
                # (``const A, B = 1, 2``) is picked up here - confirm whether
                # the remaining names should be reported as well.
                name_node = spec.child_by_field_name("name")
                if not name_node:
                    continue
                name = _node_text(name_node, source_bytes)
                if name[0:1].isupper():
                    symbols.append({
                        "n": name,
                        "k": "const",
                        "l": _node_line(spec),
                    })

    return symbols
def _extract_go_params(func_node: Any, source_bytes: bytes) -> list[str]:
    """Return the stripped source text of each Go parameter declaration.

    Considers the ``parameter_declaration`` children of the node's
    ``parameters`` field; declarations whose text is empty after stripping
    are left out.
    """
    param_list = func_node.child_by_field_name("parameters")
    if not param_list:
        return []

    declaration_texts = (
        _node_text(decl, source_bytes).strip()
        for decl in param_list.children
        if decl.type == "parameter_declaration"
    )
    return [text for text in declaration_texts if text]
def _extract_go_return(func_node: Any, source_bytes: bytes) -> str | None:
    """Return the stripped text of a Go function's ``result`` field.

    Yields None when the field is absent or its text is empty after stripping.
    """
    result_node = func_node.child_by_field_name("result")
    if not result_node:
        return None
    stripped = _node_text(result_node, source_bytes).strip()
    return stripped or None
625# ---------------------------------------------------------------------------
626# Public API
627# ---------------------------------------------------------------------------
# Registry of supported languages: name -> (grammar key handed to
# _get_language, walker that turns a parsed tree into symbol dicts).
# The grammar key is the language name itself for every entry, and
# TypeScript, TSX and JavaScript all share one walker.
_LANG_MAP: dict[str, tuple[str, Any]] = {
    name: (name, walker)
    for name, walker in (
        ("python", _extract_python_symbols),
        ("typescript", _extract_typescript_symbols),
        ("tsx", _extract_typescript_symbols),
        ("javascript", _extract_typescript_symbols),
        ("rust", _extract_rust_symbols),
        ("go", _extract_go_symbols),
    )
}
def extract_symbols_ts(source_code: str, language: str) -> list[dict[str, Any]]:
    """
    Parse *source_code* with tree-sitter and return its exported/public symbols.

    Args:
        source_code: The source file contents as a string.
        language: One of "python", "typescript", "javascript", "tsx", "rust", "go".

    Returns:
        A list of symbol dicts with keys matching SymbolEntry.to_lsif_compact():
        ``{n: name, k: kind, p: [params], r: return_type, l: line}``.
        The list is empty when tree-sitter is unavailable, the language is
        not registered, its grammar cannot be loaded, or parsing/walking
        raises (that last case is logged as a warning).
    """
    if not _check_tree_sitter():
        return []

    registered = _LANG_MAP.get(language)
    if registered is None:
        logger.debug("tree_sitter_unsupported_language", language=language)
        return []
    grammar_key, walk = registered

    grammar = _get_language(grammar_key)
    if grammar is None:
        return []

    # Best-effort boundary: any failure here means "no symbols", so the
    # caller can fall back to another extraction strategy.
    try:
        from tree_sitter import Language, Parser

        parser = Parser(Language(grammar))
        encoded = source_code.encode("utf-8")
        tree = parser.parse(encoded)
        return walk(tree.root_node, encoded)
    except Exception as exc:
        logger.warning(
            "tree_sitter_parse_error",
            language=language,
            error=str(exc),
        )
        return []
def is_language_supported(language: str) -> bool:
    """Tell whether tree-sitter extraction can currently run for *language*.

    True only when the core tree-sitter package is importable, *language*
    is registered in ``_LANG_MAP``, and its grammar loads successfully.
    """
    if not _check_tree_sitter() or language not in _LANG_MAP:
        return False
    grammar_key = _LANG_MAP[language][0]
    return _get_language(grammar_key) is not None