Coverage for agentos/tools/search_tools.py: 23%

105 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 18:40 +0800

1"""搜索工具 — 文件内容搜索、文件名匹配、代码符号搜索。""" 

2 

3from __future__ import annotations 

4 

5import fnmatch 

6import os 

7import re 

8from typing import Any, Dict 

9 

10from agentos.tools.base import BaseTool, ToolResult 

11 

12 

class GrepTool(BaseTool):
    """File content search tool: recursively search a directory for matching text."""

    name = "grep"
    description = "在目录中递归搜索文件内容,支持正则表达式,返回匹配路径和行号"

    @property
    def parameters(self) -> dict:
        """Return the JSON-schema style description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "pattern": {"type": "string", "description": "搜索的文本或正则表达式"},
                "directory": {"type": "string", "description": "搜索目录,默认当前目录"},
                "file_pattern": {"type": "string", "description": "文件名匹配模式,如 *.py"},
                "max_results": {"type": "integer", "description": "最大结果数,默认 50"},
                "case_sensitive": {"type": "boolean", "description": "是否区分大小写,默认 true"},
            },
            "required": ["pattern"],
        }

    async def execute(self, arguments: dict, sandbox=None) -> ToolResult:
        """Search files under a directory for lines matching a regular expression.

        Args:
            arguments: Mapping with the keys ``pattern`` (regex, default ``""``),
                ``directory`` (default ``"."``), ``file_pattern`` (fnmatch glob
                applied to the file name, default ``"*"``), ``max_results``
                (default 50) and ``case_sensitive`` (default True).
            sandbox: Accepted for interface compatibility; not used here.

        Returns:
            A failed ``ToolResult`` when ``pattern`` is not a valid regex.
            Otherwise a successful ``ToolResult`` whose output holds one
            ``path:lineno: text`` entry per matching line (text stripped and
            truncated to 200 characters), or ``"No matches found"``.
        """
        pattern = arguments.get("pattern", "")
        directory = arguments.get("directory", ".")
        file_pattern = arguments.get("file_pattern", "*")
        max_results = arguments.get("max_results", 50)
        case_sensitive = arguments.get("case_sensitive", True)

        flags = 0 if case_sensitive else re.IGNORECASE
        try:
            regex = re.compile(pattern, flags)
        except re.error as e:
            return ToolResult.fail(call_id="", error=f"Invalid regex: {e}")

        # Directories never descended into, in addition to any hidden (dot) directory.
        skip_dirs = ("node_modules", "__pycache__", "dist", "build", ".git")

        results = []
        for root, dirs, files in os.walk(os.path.abspath(directory)):
            # Assigning to dirs[:] prunes the walk in place so os.walk skips these subtrees.
            dirs[:] = [d for d in dirs if not d.startswith(".") and d not in skip_dirs]
            for filename in files:
                if not fnmatch.fnmatch(filename, file_pattern):
                    continue
                filepath = os.path.join(root, filename)
                try:
                    with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                        for lineno, line in enumerate(f, 1):
                            if regex.search(line):
                                results.append(f"{filepath}:{lineno}: {line.strip()[:200]}")
                                if len(results) >= max_results:
                                    return ToolResult.ok(call_id="", output="\n".join(results))
                except (OSError, UnicodeDecodeError):
                    # Best effort: one unreadable entry (permission denied, dangling
                    # symlink, file removed during the walk, device error, ...) must
                    # not abort the whole search. OSError covers PermissionError and
                    # IsADirectoryError as well.
                    continue

        return ToolResult.ok(call_id="", output="\n".join(results) if results else "No matches found")

64 

65 

class FileSearchTool(BaseTool):
    """File search tool: locate files whose names match a glob pattern."""

    name = "file_search"
    description = "按文件名模式搜索文件,支持 glob 通配符,返回匹配的文件路径列表"

    @property
    def parameters(self) -> dict:
        """Return the JSON-schema style description of the accepted arguments."""
        properties = {
            "pattern": {"type": "string", "description": "文件名匹配模式,如 *.py, report*.pdf"},
            "directory": {"type": "string", "description": "搜索目录,默认当前目录"},
            "max_results": {"type": "integer", "description": "最大结果数,默认 100"},
        }
        return {"type": "object", "properties": properties, "required": ["pattern"]}

    async def execute(self, arguments: dict, sandbox=None) -> ToolResult:
        """Walk a directory tree and collect paths of files matching ``pattern``.

        Args:
            arguments: Mapping with the keys ``pattern`` (fnmatch glob applied to
                the file name, default ``""``), ``directory`` (default ``"."``)
                and ``max_results`` (default 100).
            sandbox: Accepted for interface compatibility; not used here.

        Returns:
            A successful ``ToolResult`` whose output is the newline-joined list
            of matching paths, or ``"No files found"`` when nothing matched.
        """
        glob = arguments.get("pattern", "")
        start = os.path.abspath(arguments.get("directory", "."))
        limit = arguments.get("max_results", 100)

        # Directory names that are pruned besides hidden (dot) directories.
        excluded = ("node_modules", "__pycache__", "dist", "build", ".git")

        found = []
        for current, subdirs, names in os.walk(start):
            # In-place slice assignment so os.walk does not descend into pruned dirs.
            subdirs[:] = [s for s in subdirs if not (s.startswith(".") or s in excluded)]
            for hit in (n for n in names if fnmatch.fnmatch(n, glob)):
                found.append(os.path.join(current, hit))
                # Stop as soon as the cap is reached.
                if len(found) >= limit:
                    return ToolResult.ok(call_id="", output="\n".join(found))

        if not found:
            return ToolResult.ok(call_id="", output="No files found")
        return ToolResult.ok(call_id="", output="\n".join(found))

99 

100 

class CodeSearchTool(BaseTool):
    """Code symbol search tool: find function/class/import definitions via the AST."""

    name = "code_search"
    description = "在 Python 代码中搜索函数定义、类定义、导入等符号,返回符号名和位置"

    @property
    def parameters(self) -> dict:
        """Return the JSON-schema style description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "搜索的函数名或类名"},
                "directory": {"type": "string", "description": "代码目录,默认当前目录"},
                "symbol_type": {"type": "string", "description": "符号类型:function/class/import/all,默认 all"},
                "max_results": {"type": "integer", "description": "最大结果数,默认 30"},
            },
            "required": ["query"],
        }

    @staticmethod
    def _describe_node(node, needle: str, symbol_type: str) -> list:
        """Return the result labels an AST node contributes for a lowercase ``needle``.

        Definitions yield at most one label; an ``import a, b`` statement may
        yield one label per matching alias. Nodes that are not of the requested
        ``symbol_type`` or do not contain ``needle`` yield an empty list.
        """
        import ast

        wants_functions = symbol_type in ("function", "all")
        wants_classes = symbol_type in ("class", "all")
        wants_imports = symbol_type in ("import", "all")

        if isinstance(node, ast.FunctionDef):
            kind = "function" if wants_functions else None
        elif isinstance(node, ast.AsyncFunctionDef):
            kind = "async_function" if wants_functions else None
        elif isinstance(node, ast.ClassDef):
            kind = "class" if wants_classes else None
        elif isinstance(node, ast.Import):
            if not wants_imports:
                return []
            return [f"import {alias.name}" for alias in node.names if needle in alias.name.lower()]
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for "from . import x"; match against "" in that case.
            if wants_imports and needle in (node.module or "").lower():
                return [f"from {node.module} import ..."]
            return []
        else:
            return []

        if kind and node.name and needle in node.name.lower():
            return [f"[{kind}] {node.name}"]
        return []

    async def execute(self, arguments: dict, sandbox=None) -> ToolResult:
        """Search ``.py`` files under a directory for symbols whose name contains ``query``.

        Matching is a case-insensitive substring test against function, async
        function and class names, and against imported module names.

        Args:
            arguments: Mapping with the keys ``query`` (default ``""``),
                ``directory`` (default ``"."``), ``symbol_type`` (one of
                ``function``/``class``/``import``/``all``, default ``"all"``)
                and ``max_results`` (default 30).
            sandbox: Accepted for interface compatibility; not used here.

        Returns:
            A successful ``ToolResult`` whose output holds at most
            ``max_results`` entries of the form ``path:lineno: label``, or
            ``"No symbols found"`` when nothing matched.
        """
        import ast

        query = arguments.get("query", "")
        directory = arguments.get("directory", ".")
        symbol_type = arguments.get("symbol_type", "all")
        max_results = arguments.get("max_results", 30)

        # A non-positive limit can never admit a result; skip the walk entirely.
        if max_results <= 0:
            return ToolResult.ok(call_id="", output="No symbols found")

        needle = query.lower()
        # Directories never descended into, in addition to any hidden (dot) directory.
        skip_dirs = ("node_modules", "__pycache__", "dist", "build", ".git")

        results = []
        for root, dirs, files in os.walk(os.path.abspath(directory)):
            # Assigning to dirs[:] prunes the walk in place so os.walk skips these subtrees.
            dirs[:] = [d for d in dirs if not d.startswith(".") and d not in skip_dirs]
            for filename in files:
                if not filename.endswith(".py"):
                    continue
                filepath = os.path.join(root, filename)
                try:
                    with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                        source = f.read()
                    tree = ast.parse(source, filename=filepath)
                except (SyntaxError, ValueError, RecursionError, OSError):
                    # Best effort: skip files that cannot be read or parsed.
                    # ValueError covers sources containing NUL bytes on Python
                    # versions before 3.12 (and UnicodeDecodeError); OSError
                    # covers PermissionError, dangling symlinks and similar.
                    continue

                for node in ast.walk(tree):
                    for label in self._describe_node(node, needle, symbol_type):
                        results.append(f"{filepath}:{node.lineno}: {label}")
                        # Return immediately once the cap is hit so no further
                        # files are read or parsed and the cap is never exceeded.
                        if len(results) >= max_results:
                            return ToolResult.ok(call_id="", output="\n".join(results))

        return ToolResult.ok(call_id="", output="\n".join(results) if results else "No symbols found")