Coverage for agentos/tools/search_tools.py: 23%
105 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 18:40 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 18:40 +0800
1"""搜索工具 — 文件内容搜索、文件名匹配、代码符号搜索。"""
3from __future__ import annotations
5import fnmatch
6import os
7import re
8from typing import Any, Dict
10from agentos.tools.base import BaseTool, ToolResult
class GrepTool(BaseTool):
    """File-content search tool: recursively search a directory tree for matching text."""

    name = "grep"
    description = "在目录中递归搜索文件内容,支持正则表达式,返回匹配路径和行号"

    @property
    def parameters(self) -> dict:
        """Return the JSON-Schema-style dict describing the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "pattern": {"type": "string", "description": "搜索的文本或正则表达式"},
                "directory": {"type": "string", "description": "搜索目录,默认当前目录"},
                "file_pattern": {"type": "string", "description": "文件名匹配模式,如 *.py"},
                "max_results": {"type": "integer", "description": "最大结果数,默认 50"},
                "case_sensitive": {"type": "boolean", "description": "是否区分大小写,默认 true"},
            },
            "required": ["pattern"],
        }

    async def execute(self, arguments: dict, sandbox=None) -> ToolResult:
        """Search files under a directory for lines matching a regular expression.

        Args:
            arguments: Tool-call arguments. ``pattern`` is the regular
                expression (default ``""``). Optional keys: ``directory``
                (default ``"."``), ``file_pattern`` (an ``fnmatch`` pattern
                applied to the bare file name, default ``"*"``),
                ``max_results`` (default 50) and ``case_sensitive``
                (default ``True``).
            sandbox: Accepted but not used by this method.

        Returns:
            A successful ``ToolResult`` whose output has one
            ``path:lineno: text`` entry per line (the text is stripped and
            cut to 200 characters), or ``"No matches found"`` when nothing
            matched. A failed ``ToolResult`` when ``pattern`` is not a valid
            regular expression.
        """
        pattern = arguments.get("pattern", "")
        directory = arguments.get("directory", ".")
        file_pattern = arguments.get("file_pattern", "*")
        max_results = arguments.get("max_results", 50)
        case_sensitive = arguments.get("case_sensitive", True)

        flags = 0 if case_sensitive else re.IGNORECASE
        try:
            regex = re.compile(pattern, flags)
        except re.error as e:
            return ToolResult.fail(call_id="", error=f"Invalid regex: {e}")

        skipped_dirs = ("node_modules", "__pycache__", "dist", "build", ".git")
        results = []
        for root, dirs, files in os.walk(os.path.abspath(directory)):
            # Prune in place so os.walk never descends into hidden or
            # dependency/build directories.
            dirs[:] = [d for d in dirs if not d.startswith(".") and d not in skipped_dirs]
            for filename in files:
                if not fnmatch.fnmatch(filename, file_pattern):
                    continue
                filepath = os.path.join(root, filename)
                try:
                    with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                        for lineno, line in enumerate(f, 1):
                            if regex.search(line):
                                results.append(f"{filepath}:{lineno}: {line.strip()[:200]}")
                                # Stop the whole search as soon as the cap is hit.
                                if len(results) >= max_results:
                                    return ToolResult.ok(call_id="", output="\n".join(results))
                except (OSError, UnicodeDecodeError):
                    # Best effort: one unreadable entry (dangling symlink,
                    # permission problem, file removed mid-walk, ...) must
                    # not abort the search of the remaining files.
                    continue

        return ToolResult.ok(call_id="", output="\n".join(results) if results else "No matches found")
class FileSearchTool(BaseTool):
    """File search tool: locate files whose names match a glob pattern."""

    name = "file_search"
    description = "按文件名模式搜索文件,支持 glob 通配符,返回匹配的文件路径列表"

    @property
    def parameters(self) -> dict:
        """Return the JSON-Schema-style dict describing the accepted arguments."""
        properties = {
            "pattern": {"type": "string", "description": "文件名匹配模式,如 *.py, report*.pdf"},
            "directory": {"type": "string", "description": "搜索目录,默认当前目录"},
            "max_results": {"type": "integer", "description": "最大结果数,默认 100"},
        }
        return {"type": "object", "properties": properties, "required": ["pattern"]}

    async def execute(self, arguments: dict, sandbox=None) -> ToolResult:
        """Walk a directory tree and collect paths of files matching a name pattern.

        Args:
            arguments: Tool-call arguments. ``pattern`` is an ``fnmatch``
                pattern applied to the bare file name (default ``""``);
                ``directory`` defaults to ``"."`` and ``max_results`` to 100.
            sandbox: Accepted but not used by this method.

        Returns:
            A successful ``ToolResult`` whose output lists one matching path
            per line, or ``"No files found"`` when nothing matched.
        """
        name_glob = arguments.get("pattern", "")
        start_dir = os.path.abspath(arguments.get("directory", "."))
        limit = arguments.get("max_results", 100)

        ignored = ("node_modules", "__pycache__", "dist", "build", ".git")
        found = []
        for current, subdirs, filenames in os.walk(start_dir):
            # In-place pruning keeps os.walk out of hidden and ignored directories.
            subdirs[:] = [s for s in subdirs if not (s.startswith(".") or s in ignored)]
            for entry in filenames:
                if not fnmatch.fnmatch(entry, name_glob):
                    continue
                found.append(os.path.join(current, entry))
                # The cap is checked only after a hit has been recorded.
                if len(found) >= limit:
                    return ToolResult.ok(call_id="", output="\n".join(found))

        if found:
            listing = "\n".join(found)
        else:
            listing = "No files found"
        return ToolResult.ok(call_id="", output=listing)
class CodeSearchTool(BaseTool):
    """Code symbol search tool: find function, class and import definitions via the AST."""

    name = "code_search"
    description = "在 Python 代码中搜索函数定义、类定义、导入等符号,返回符号名和位置"

    @property
    def parameters(self) -> dict:
        """Return the JSON-Schema-style dict describing the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "搜索的函数名或类名"},
                "directory": {"type": "string", "description": "代码目录,默认当前目录"},
                "symbol_type": {"type": "string", "description": "符号类型:function/class/import/all,默认 all"},
                "max_results": {"type": "integer", "description": "最大结果数,默认 30"},
            },
            "required": ["query"],
        }

    async def execute(self, arguments: dict, sandbox=None) -> ToolResult:
        """Search ``.py`` files under a directory for symbols whose name contains the query.

        Args:
            arguments: Tool-call arguments. ``query`` is matched
                case-insensitively as a substring (default ``""``).
                Optional keys: ``directory`` (default ``"."``),
                ``symbol_type`` (``function``/``class``/``import``/``all``,
                default ``"all"``) and ``max_results`` (default 30).
            sandbox: Accepted but not used by this method.

        Returns:
            A successful ``ToolResult`` whose output has one
            ``path:lineno: description`` entry per line, holding at most
            ``max_results`` entries, or ``"No symbols found"`` when nothing
            matched.
        """
        from itertools import islice

        query = arguments.get("query", "")
        directory = arguments.get("directory", ".")
        symbol_type = arguments.get("symbol_type", "all")
        max_results = arguments.get("max_results", 30)

        # The generator is lazy, so islice both enforces the cap exactly
        # (even for multi-name import statements) and stops opening and
        # parsing further files once the cap is reached. islice rejects
        # negative counts, hence the clamp.
        matches = self._iter_matches(directory, query.lower(), symbol_type)
        results = list(islice(matches, max(max_results, 0)))

        return ToolResult.ok(call_id="", output="\n".join(results) if results else "No symbols found")

    @staticmethod
    def _iter_matches(directory, needle, symbol_type):
        """Lazily yield one result line per matching symbol found under ``directory``."""
        import ast

        skipped_dirs = ("node_modules", "__pycache__", "dist", "build", ".git")
        want_functions = symbol_type in ("function", "all")
        want_classes = symbol_type in ("class", "all")
        want_imports = symbol_type in ("import", "all")

        for root, dirs, files in os.walk(os.path.abspath(directory)):
            # Prune in place so os.walk never descends into hidden or
            # dependency/build directories.
            dirs[:] = [d for d in dirs if not d.startswith(".") and d not in skipped_dirs]
            for filename in files:
                if not filename.endswith(".py"):
                    continue
                filepath = os.path.join(root, filename)
                try:
                    with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                        source = f.read()
                    tree = ast.parse(source, filename=filepath)
                except (SyntaxError, ValueError, RecursionError, OSError):
                    # Best effort: skip files that cannot be read (OSError,
                    # e.g. dangling symlink) or parsed. ValueError covers
                    # sources with NUL bytes on interpreters that report
                    # them that way; RecursionError guards against
                    # pathologically nested code.
                    continue

                for node in ast.walk(tree):
                    name = None
                    stype = None
                    if isinstance(node, ast.FunctionDef) and want_functions:
                        name, stype = node.name, "function"
                    elif isinstance(node, ast.AsyncFunctionDef) and want_functions:
                        name, stype = node.name, "async_function"
                    elif isinstance(node, ast.ClassDef) and want_classes:
                        name, stype = node.name, "class"
                    elif isinstance(node, ast.Import) and want_imports:
                        for alias in node.names:
                            if needle in alias.name.lower():
                                yield f"{filepath}:{node.lineno}: import {alias.name}"
                    elif isinstance(node, ast.ImportFrom) and want_imports:
                        if needle in (node.module or "").lower():
                            # Render relative imports with their leading dots
                            # instead of printing "None" or dropping the dots.
                            module = "." * (node.level or 0) + (node.module or "")
                            yield f"{filepath}:{node.lineno}: from {module} import ..."

                    if name and stype and needle in name.lower():
                        yield f"{filepath}:{node.lineno}: [{stype}] {name}"