Coverage for src / documint_mcp / tree_sitter_extractor.py: 0%

368 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 22:30 -0400

1""" 

2Tree-sitter based symbol extraction for multi-language support. 

3 

4Uses py-tree-sitter (>=0.23) with per-language grammar packages to parse 

5source code into a concrete syntax tree and extract exported/public symbols. 

6 

7This module is OPTIONAL. If tree-sitter or any language grammar is not 

8installed, all functions return empty lists and log a debug message. 

9The caller should fall back to regex-based extraction. 

10 

11Supported languages: 

12- Python (tree-sitter-python) 

13- TypeScript / JavaScript (tree-sitter-typescript, tree-sitter-javascript) 

14- Rust (tree-sitter-rust) [used only when the grammar package is installed] 

15- Go (tree-sitter-go) [used only when the grammar package is installed] 

16""" 

17from __future__ import annotations 

18 

19from typing import Any 

20 

21import structlog 

22 

23logger = structlog.get_logger(__name__) 

24 

25# --------------------------------------------------------------------------- 

26# Lazy availability detection 

27# --------------------------------------------------------------------------- 

28 

_TS_AVAILABLE: bool | None = None  # tri-state cache: None until the first probe


def _check_tree_sitter() -> bool:
    """Report whether the core ``tree_sitter`` package can be imported.

    The import is attempted at most once; the outcome is stored in the
    module-level ``_TS_AVAILABLE`` flag and returned by every later call.
    A failed import is logged at debug level.
    """
    global _TS_AVAILABLE
    if _TS_AVAILABLE is not None:
        return _TS_AVAILABLE

    try:
        import tree_sitter  # noqa: F401
    except ImportError:
        _TS_AVAILABLE = False
        logger.debug("tree_sitter_not_installed")
    else:
        _TS_AVAILABLE = True
    return _TS_AVAILABLE

44 

45 

def _get_language(lang: str) -> Any | None:
    """Load the tree-sitter grammar handle for *lang*.

    Each supported language lives in its own grammar package, imported
    lazily here so that a missing package only disables that one language.

    Args:
        lang: One of "python", "typescript", "tsx", "javascript", "rust", "go".

    Returns:
        Whatever the grammar package's language factory returns, or ``None``
        when *lang* is not recognised, the package is not installed, or
        loading the grammar raised.
    """
    try:
        if lang == "python":
            import tree_sitter_python

            return tree_sitter_python.language()
        if lang == "tsx":
            import tree_sitter_typescript

            return tree_sitter_typescript.language_tsx()
        if lang == "typescript":
            import tree_sitter_typescript

            return tree_sitter_typescript.language_typescript()
        if lang == "javascript":
            import tree_sitter_javascript

            return tree_sitter_javascript.language()
        if lang == "rust":
            import tree_sitter_rust

            return tree_sitter_rust.language()
        if lang == "go":
            import tree_sitter_go

            return tree_sitter_go.language()
    except ImportError:
        logger.debug("tree_sitter_grammar_not_installed", language=lang)
    except Exception as exc:
        logger.debug("tree_sitter_grammar_load_error", language=lang, error=str(exc))
    # Unknown language names fall through to here without logging.
    return None

82 

83 

84# --------------------------------------------------------------------------- 

85# Node text helper 

86# --------------------------------------------------------------------------- 

87 

88 

def _node_text(node: Any, source_bytes: bytes) -> str:
    """Return the source text spanned by *node*.

    The node's byte range is sliced out of *source_bytes* and decoded as
    UTF-8; undecodable bytes are replaced rather than raising.
    """
    raw = source_bytes[node.start_byte : node.end_byte]
    return raw.decode("utf-8", errors="replace")

92 

93 

def _node_line(node: Any) -> int:
    """Return the 1-based line on which *node* starts."""
    zero_based_row = node.start_point[0]
    return zero_based_row + 1

97 

98 

99# --------------------------------------------------------------------------- 

100# Language-specific symbol walkers 

101# --------------------------------------------------------------------------- 

102 

103 

def _extract_python_symbols(root: Any, source_bytes: bytes) -> list[dict[str, Any]]:
    """Extract module-level public symbols from a Python CST.

    Only direct children of *root* are inspected. A symbol is public when
    its name does not start with an underscore.

    Reported kinds:
        * ``fn`` -- functions, with ``p`` (parameters) and ``r`` (return type)
        * ``class`` -- classes, with ``p`` holding the public method names
        * ``type`` -- capitalised plain assignments (``MyType = ...``) and
          ``type MyType = ...`` statements

    Args:
        root: Root node of the parsed module.
        source_bytes: The UTF-8 encoded source the tree was parsed from.

    Returns:
        Symbol dicts in source order, each with ``n`` (name), ``k`` (kind)
        and ``l`` (1-based line) plus the kind-specific keys above.
    """
    symbols: list[dict[str, Any]] = []

    for node in root.children:
        # Decorated definitions wrap the real definition; unwrap so the name
        # and line come from the def/class itself.
        actual = node
        if node.type == "decorated_definition":
            actual = next(
                (
                    child
                    for child in node.children
                    if child.type in ("function_definition", "class_definition")
                ),
                node,
            )

        if actual.type == "function_definition":
            name_node = actual.child_by_field_name("name")
            if name_node:
                name = _node_text(name_node, source_bytes)
                if not name.startswith("_"):
                    params = _extract_python_params(actual, source_bytes)
                    ret = _extract_python_return(actual, source_bytes)
                    symbols.append({
                        "n": name,
                        "k": "fn",
                        "p": params,
                        "r": ret,
                        "l": _node_line(actual),
                    })

        elif actual.type == "class_definition":
            name_node = actual.child_by_field_name("name")
            if name_node:
                name = _node_text(name_node, source_bytes)
                if not name.startswith("_"):
                    methods = _extract_python_class_methods(actual, source_bytes)
                    symbols.append({
                        "n": name,
                        "k": "class",
                        "p": methods,
                        "l": _node_line(actual),
                    })

        elif actual.type == "expression_statement":
            # Type alias heuristic: a capitalised name on the left of a plain
            # assignment, e.g. ``MyType = ...``.
            for child in actual.children:
                if child.type != "assignment":
                    continue
                left = child.child_by_field_name("left")
                if left and left.type == "identifier":
                    name = _node_text(left, source_bytes)
                    if not name.startswith("_") and name[0:1].isupper():
                        symbols.append({
                            "n": name,
                            "k": "type",
                            "l": _node_line(actual),
                        })

        elif actual.type == "type_alias_statement":
            # Python 3.12+: ``type MyType = ...``.
            # NOTE(review): the grammar versions I am aware of expose the
            # alias under the ``left`` field (or no field at all) rather than
            # ``name``, so fall back through all three before giving up.
            name_node = actual.child_by_field_name("name") or actual.child_by_field_name("left")
            if not name_node:
                name_node = next(
                    (child for child in actual.children if child.type == "type"),
                    None,
                )
            if name_node:
                # Generic aliases read ``Pair[T]``; keep only the bare name.
                name = _node_text(name_node, source_bytes).split("[", 1)[0].strip()
                if name and not name.startswith("_"):
                    symbols.append({
                        "n": name,
                        "k": "type",
                        "l": _node_line(actual),
                    })

    return symbols

172 

173 

def _extract_python_params(func_node: Any, source_bytes: bytes) -> list[str]:
    """List the parameters of a Python function node.

    Plain identifiers are rendered as ``name``. Typed and/or defaulted
    parameters are rendered as ``name:annotation`` when an annotation is
    present, else ``name``. ``self`` and ``cls`` are omitted. Any other
    child of the parameter list is ignored.
    """
    param_list = func_node.child_by_field_name("parameters")
    if not param_list:
        return []

    implicit = ("self", "cls")
    annotated_kinds = ("typed_parameter", "default_parameter", "typed_default_parameter")
    collected: list[str] = []

    for entry in param_list.children:
        if entry.type == "identifier":
            label = _node_text(entry, source_bytes)
            if label not in implicit:
                collected.append(label)
            continue

        if entry.type not in annotated_kinds:
            continue

        ident = entry.child_by_field_name("name")
        if not ident:
            # Some of these node kinds carry no ``name`` field; the first
            # identifier child is then taken as the name.
            ident = next((sub for sub in entry.children if sub.type == "identifier"), None)
        if not ident:
            continue

        label = _node_text(ident, source_bytes)
        if label in implicit:
            continue

        annotation = entry.child_by_field_name("type")
        if annotation:
            collected.append(f"{label}:{_node_text(annotation, source_bytes)}")
        else:
            collected.append(label)

    return collected

204 

205 

def _extract_python_return(func_node: Any, source_bytes: bytes) -> str | None:
    """Return the text of a function's return annotation, or ``None`` if it has none."""
    annotation = func_node.child_by_field_name("return_type")
    return _node_text(annotation, source_bytes) if annotation else None

212 

213 

def _extract_python_class_methods(class_node: Any, source_bytes: bytes) -> list[str]:
    """List the public method names defined directly in a class body.

    Underscore-prefixed methods are skipped, except ``__init__`` and
    ``__call__`` which are always kept. Decorated methods are unwrapped.
    """
    suite = class_node.child_by_field_name("body")
    if not suite:
        return []

    always_listed = ("__init__", "__call__")
    names: list[str] = []

    for member in suite.children:
        definition = member
        if member.type == "decorated_definition":
            definition = next(
                (inner for inner in member.children if inner.type == "function_definition"),
                member,
            )
        if definition.type != "function_definition":
            continue

        ident = definition.child_by_field_name("name")
        if not ident:
            continue

        method_name = _node_text(ident, source_bytes)
        if method_name in always_listed or not method_name.startswith("_"):
            names.append(method_name)

    return names

235 

236 

237# --------------------------------------------------------------------------- 

238# TypeScript / JavaScript 

239# --------------------------------------------------------------------------- 

240 

241 

def _extract_typescript_symbols(root: Any, source_bytes: bytes) -> list[dict[str, Any]]:
    """Collect exported symbols from a TypeScript/JavaScript CST.

    Only top-level ``export_statement`` nodes contribute. Top-level
    declarations without the ``export`` keyword are not exported and are
    deliberately ignored.
    """
    exported: list[dict[str, Any]] = []
    for top_level in root.children:
        if top_level.type == "export_statement":
            _walk_ts_export(top_level, source_bytes, exported)
    return exported

260 

261 

def _walk_ts_export(export_node: Any, source_bytes: bytes, out: list[dict[str, Any]]) -> None:
    """Append the symbols declared by one ``export_statement`` to *out*.

    Functions (including ``export const f = (...) => ...``) are recorded with
    ``p`` and ``r``. Classes, interfaces, type aliases, enums and other
    exported constants are recorded with name, kind and line only.
    Declarations without a name are skipped.
    """
    # Declarations that are recorded with just name / kind / line.
    name_only_kinds = {
        "class_declaration": "class",
        "abstract_class_declaration": "class",
        "interface_declaration": "interface",
        "type_alias_declaration": "type",
        "enum_declaration": "enum",
    }

    for decl in export_node.children:
        decl_type = decl.type

        if decl_type in name_only_kinds:
            ident = decl.child_by_field_name("name")
            if ident:
                out.append({
                    "n": _node_text(ident, source_bytes),
                    "k": name_only_kinds[decl_type],
                    "l": _node_line(decl),
                })

        elif decl_type in ("function_declaration", "function_signature"):
            ident = decl.child_by_field_name("name")
            if ident:
                out.append({
                    "n": _node_text(ident, source_bytes),
                    "k": "fn",
                    "p": _extract_ts_params(decl, source_bytes),
                    "r": _extract_ts_return(decl, source_bytes),
                    "l": _node_line(decl),
                })

        elif decl_type == "lexical_declaration":
            # export const myFn = (...) => { ... }
            for declarator in decl.children:
                if declarator.type != "variable_declarator":
                    continue
                ident = declarator.child_by_field_name("name")
                value = declarator.child_by_field_name("value")
                if not (ident and value):
                    continue

                binding = _node_text(ident, source_bytes)
                if value.type in ("arrow_function", "function_expression"):
                    out.append({
                        "n": binding,
                        "k": "fn",
                        "p": _extract_ts_params(value, source_bytes),
                        "r": _extract_ts_return(value, source_bytes),
                        "l": _node_line(declarator),
                    })
                else:
                    out.append({
                        "n": binding,
                        "k": "const",
                        "l": _node_line(declarator),
                    })

338 

339 

def _extract_ts_params(node: Any, source_bytes: bytes) -> list[str]:
    """Render the parameters of a TS/JS function-like node.

    Required, optional and rest parameters become ``name:type`` when a type
    annotation is present, else ``name``. Bare identifier parameters are
    rendered as ``name``. Other children of the parameter list are ignored.
    """
    param_list = node.child_by_field_name("parameters")
    if not param_list:
        # Arrow functions: fall back to a formal_parameters child.
        param_list = next(
            (child for child in node.children if child.type == "formal_parameters"),
            None,
        )
    if not param_list:
        return []

    typed_kinds = ("required_parameter", "optional_parameter", "rest_parameter")
    rendered: list[str] = []

    for param in param_list.children:
        if param.type == "identifier":
            rendered.append(_node_text(param, source_bytes))
            continue
        if param.type not in typed_kinds:
            continue

        binding = param.child_by_field_name("pattern")
        if not binding:
            # No pattern field: use the first identifier child, if any.
            binding = next((sub for sub in param.children if sub.type == "identifier"), None)
            if binding is None:
                continue

        label = _node_text(binding, source_bytes)
        annotation = param.child_by_field_name("type")
        if annotation:
            # The annotation text carries its leading ": "; drop it.
            rendered.append(f"{label}:{_node_text(annotation, source_bytes).lstrip(': ')}")
        else:
            rendered.append(label)

    return rendered

379 

380 

def _extract_ts_return(node: Any, source_bytes: bytes) -> str | None:
    """Return a TS function's return annotation without its leading colon.

    Yields ``None`` when the node has no ``return_type`` field or when the
    annotation is empty after stripping leading ``:`` and spaces.
    """
    annotation = node.child_by_field_name("return_type")
    if not annotation:
        return None
    stripped = _node_text(annotation, source_bytes).lstrip(": ")
    return stripped or None

388 

389 

390# --------------------------------------------------------------------------- 

391# Rust 

392# --------------------------------------------------------------------------- 

393 

394 

def _extract_rust_symbols(root: Any, source_bytes: bytes) -> list[dict[str, Any]]:
    """Collect ``pub`` symbols from a Rust CST by walking it from *root*."""
    collected: list[dict[str, Any]] = []
    _walk_rust_items(root, source_bytes, collected)
    return collected

400 

401 

def _walk_rust_items(node: Any, source_bytes: bytes, out: list[dict[str, Any]]) -> None:
    """Append the ``pub`` items found among *node*'s children to *out*.

    Children without a visibility modifier are skipped, except that
    ``impl_item`` and ``declaration_list`` nodes are descended into so that
    pub methods inside impl blocks are found. Functions are recorded with
    up to eight parameters and their return type; structs, enums, traits,
    type aliases and consts are recorded with name, kind and line only.
    """
    name_only_kinds = {
        "struct_item": "struct",
        "enum_item": "enum",
        "trait_item": "trait",
        "type_item": "type",
        "const_item": "const",
    }

    for item in node.children:
        item_type = item.type

        if not _is_rust_pub(item):
            # impl blocks and their bodies carry no visibility themselves.
            if item_type in ("impl_item", "declaration_list"):
                _walk_rust_items(item, source_bytes, out)
            continue

        if item_type != "function_item" and item_type not in name_only_kinds:
            continue

        ident = item.child_by_field_name("name")
        if not ident:
            continue

        if item_type == "function_item":
            params = _extract_rust_params(item, source_bytes)
            ret = _extract_rust_return(item, source_bytes)
            out.append({
                "n": _node_text(ident, source_bytes),
                "k": "fn",
                "p": params[:8],
                "r": ret,
                "l": _node_line(item),
            })
        else:
            out.append({
                "n": _node_text(ident, source_bytes),
                "k": name_only_kinds[item_type],
                "l": _node_line(item),
            })

471 

472 

def _is_rust_pub(node: Any) -> bool:
    """Return True when *node* has a direct ``visibility_modifier`` child.

    Any visibility modifier counts; the modifier's own content is not inspected.
    """
    return any(child.type == "visibility_modifier" for child in node.children)

479 

480 

def _extract_rust_params(func_node: Any, source_bytes: bytes) -> list[str]:
    """Render the parameters of a Rust ``function_item``.

    Each ``parameter`` child becomes ``pattern:type`` when it has a type,
    else ``pattern``. Receivers (``self`` in any form) and every other kind
    of child are left out.
    """
    param_list = func_node.child_by_field_name("parameters")
    if not param_list:
        return []

    receivers = ("self", "&self", "&mut self", "mut self")
    rendered: list[str] = []

    for param in param_list.children:
        # self_parameter nodes and punctuation are not of type "parameter".
        if param.type != "parameter":
            continue
        binding = param.child_by_field_name("pattern")
        if not binding:
            continue

        label = _node_text(binding, source_bytes)
        if label in receivers:
            continue

        annotation = param.child_by_field_name("type")
        if annotation:
            rendered.append(f"{label}:{_node_text(annotation, source_bytes)}")
        else:
            rendered.append(label)

    return rendered

503 

504 

def _extract_rust_return(func_node: Any, source_bytes: bytes) -> str | None:
    """Return the return type of a Rust ``function_item`` as text.

    Leading ``-``, ``>`` and space characters and surrounding whitespace are
    stripped. ``None`` is returned when there is no ``return_type`` field or
    nothing remains after stripping.
    """
    declared = func_node.child_by_field_name("return_type")
    if not declared:
        return None
    cleaned = _node_text(declared, source_bytes).lstrip("-> ").strip()
    return cleaned or None

512 

513 

514# --------------------------------------------------------------------------- 

515# Go 

516# --------------------------------------------------------------------------- 

517 

518 

def _extract_go_symbols(root: Any, source_bytes: bytes) -> list[dict[str, Any]]:
    """Extract exported (capitalized) top-level symbols from a Go CST.

    Only direct children of *root* are inspected. A symbol is reported when
    the first character of its name is upper-case.

    Reported kinds:
        * ``fn`` -- function and method declarations, with ``p`` and ``r``
        * ``struct`` / ``interface`` -- type specs whose body is a struct or
          interface type
        * ``type`` -- every other named type, and ``type A = B`` aliases
        * ``const`` -- constants and package-level variables (variables keep
          the ``const`` kind they have always been reported with)

    Args:
        root: Root node of the parsed Go file.
        source_bytes: The UTF-8 encoded source the tree was parsed from.

    Returns:
        Symbol dicts in source order, each with ``n``, ``k`` and ``l``.
    """
    symbols: list[dict[str, Any]] = []

    def exported_names(spec: Any) -> list[str]:
        # One spec can bind several identifiers (``const A, B = 1, 2``), so
        # read every child tagged with the ``name`` field, not just the first.
        names = (_node_text(ident, source_bytes) for ident in spec.children_by_field_name("name"))
        return [name for name in names if name[0:1].isupper()]

    def specs_of(declaration: Any, spec_type: str) -> list[Any]:
        # Specs are normally direct children, but a parenthesised group may
        # be wrapped in a ``<spec>_list`` node (e.g. ``var_spec_list``)
        # depending on the grammar version, so look one level into those too.
        wrapper_type = f"{spec_type}_list"
        found: list[Any] = []
        for child in declaration.children:
            if child.type == spec_type:
                found.append(child)
            elif child.type == wrapper_type:
                found.extend(inner for inner in child.children if inner.type == spec_type)
        return found

    for node in root.children:
        if node.type in ("function_declaration", "method_declaration"):
            name_node = node.child_by_field_name("name")
            if not name_node:
                continue
            name = _node_text(name_node, source_bytes)
            if name[0:1].isupper():
                params = _extract_go_params(node, source_bytes)
                ret = _extract_go_return(node, source_bytes)
                symbols.append({
                    "n": name,
                    "k": "fn",
                    "p": params,
                    "r": ret,
                    "l": _node_line(node),
                })

        elif node.type == "type_declaration":
            for spec in node.children:
                if spec.type not in ("type_spec", "type_alias"):
                    continue
                name_node = spec.child_by_field_name("name")
                if not name_node:
                    continue
                name = _node_text(name_node, source_bytes)
                if not name[0:1].isupper():
                    continue

                # Classify by the declared body; only real struct/interface
                # bodies get those kinds, everything else is a plain type.
                type_node = spec.child_by_field_name("type")
                body_type = type_node.type if type_node else None
                if spec.type == "type_alias":
                    kind = "type"
                elif body_type == "interface_type":
                    kind = "interface"
                elif body_type == "struct_type":
                    kind = "struct"
                else:
                    kind = "type"
                symbols.append({
                    "n": name,
                    "k": kind,
                    "l": _node_line(spec),
                })

        elif node.type in ("const_declaration", "var_declaration"):
            spec_type = "const_spec" if node.type == "const_declaration" else "var_spec"
            for spec in specs_of(node, spec_type):
                for name in exported_names(spec):
                    symbols.append({
                        "n": name,
                        "k": "const",
                        "l": _node_line(spec),
                    })

    return symbols

599 

600 

def _extract_go_params(func_node: Any, source_bytes: bytes) -> list[str]:
    """Return the text of each ``parameter_declaration`` of a Go function/method.

    Each declaration is kept verbatim (whitespace-trimmed); declarations
    that are empty after trimming are dropped.
    """
    param_list = func_node.child_by_field_name("parameters")
    if not param_list:
        return []

    declarations = (
        _node_text(entry, source_bytes).strip()
        for entry in param_list.children
        if entry.type == "parameter_declaration"
    )
    return [text for text in declarations if text]

614 

615 

def _extract_go_return(func_node: Any, source_bytes: bytes) -> str | None:
    """Return the trimmed text of a Go function's ``result`` field.

    ``None`` is returned when the field is absent or its text is blank.
    """
    result_node = func_node.child_by_field_name("result")
    if not result_node:
        return None
    return _node_text(result_node, source_bytes).strip() or None

623 

624 

625# --------------------------------------------------------------------------- 

626# Public API 

627# --------------------------------------------------------------------------- 

628 

# Public language name -> (grammar key passed to _get_language, symbol walker).
# TypeScript, TSX and JavaScript share one walker but load different grammars.
_LANG_MAP: dict[str, tuple[str, Any]] = {
    "python": ("python", _extract_python_symbols),
    **{
        dialect: (dialect, _extract_typescript_symbols)
        for dialect in ("typescript", "tsx", "javascript")
    },
    "rust": ("rust", _extract_rust_symbols),
    "go": ("go", _extract_go_symbols),
}

638 

639 

def extract_symbols_ts(source_code: str, language: str) -> list[dict[str, Any]]:
    """
    Parse *source_code* with tree-sitter and return its exported/public symbols.

    Args:
        source_code: The source file contents as a string.
        language: One of "python", "typescript", "javascript", "tsx", "rust", "go".

    Returns:
        A list of symbol dicts with keys matching SymbolEntry.to_lsif_compact():
        ``{n: name, k: kind, p: [params], r: return_type, l: line}``.
        An empty list is returned when tree-sitter or the language grammar is
        unavailable, when *language* is not supported, or when parsing or
        walking raises (the error is logged as a warning).
    """
    if not _check_tree_sitter():
        return []

    if language not in _LANG_MAP:
        logger.debug("tree_sitter_unsupported_language", language=language)
        return []
    grammar_key, walker_fn = _LANG_MAP[language]

    grammar = _get_language(grammar_key)
    if grammar is None:
        return []

    try:
        from tree_sitter import Language, Parser

        parser = Parser(Language(grammar))
        encoded = source_code.encode("utf-8")
        root = parser.parse(encoded).root_node
        return walker_fn(root, encoded)
    except Exception as exc:
        # Best-effort API: callers fall back to regex extraction on [].
        logger.warning(
            "tree_sitter_parse_error",
            language=language,
            error=str(exc),
        )
        return []

681 

682 

def is_language_supported(language: str) -> bool:
    """Tell whether tree-sitter extraction can run for *language*.

    True only when the core tree-sitter package is importable, *language*
    is a key of ``_LANG_MAP``, and its grammar loads.
    """
    if not _check_tree_sitter() or language not in _LANG_MAP:
        return False
    grammar_key = _LANG_MAP[language][0]
    return _get_language(grammar_key) is not None