Coverage for agentos/tools/validation.py: 39%
165 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 17:43 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 17:43 +0800
1"""
2v1.15.0 — 工具输出验证层:结构化结果验证 + 错误分类 + 自动修复建议。
4核心功能:
51. 验证工具返回结果是否符合预期格式
62. 自动分类工具执行错误
73. 提供可操作的修复建议
84. 集成到 ToolExecutor 中,提升 Agent 鲁棒性
9"""
11from __future__ import annotations
13import json
14import re
15from dataclasses import dataclass, field
16from enum import Enum, auto
17from typing import Any, Optional, Union, List, Dict
19from .base import ToolResult
20from ..errors.handler import ErrorCategory, ErrorFormatter
class ValidationSeverity(str, Enum):
    """Severity level attached to a validation issue.

    Members subclass ``str``, so each compares equal to its raw value
    (for example ``ValidationSeverity.ERROR == "error"``).
    """

    INFO = "info"          # informational note only
    WARNING = "warning"    # possibly wrong, but processing may continue
    ERROR = "error"        # a problem that needs fixing
    CRITICAL = "critical"  # a severe problem that must be fixed
class ValidationRule(str, Enum):
    """Kind of check that a validation issue originates from.

    Members subclass ``str``, so each compares equal to its raw value.
    """

    JSON_FORMAT = "json_format"          # output must be parseable
    REQUIRED_FIELD = "required_field"    # field must be present
    TYPE_CHECK = "type_check"            # field value must have a given type
    RANGE_CHECK = "range_check"          # field value must lie within bounds
    PATTERN_MATCH = "pattern_match"      # field value must match a regex
    LENGTH_CHECK = "length_check"        # length constraint
    ENUM_CHECK = "enum_check"            # field value must be one of a set
    STRUCTURE_CHECK = "structure_check"  # structural constraint
@dataclass
class ValidationIssue:
    """One problem found while validating a tool's output.

    Only ``rule``, ``severity`` and ``message`` are required; the remaining
    attributes default to ``None`` / an empty suggestion.
    """

    rule: ValidationRule            # the check that produced this issue
    severity: ValidationSeverity    # how serious the issue is
    message: str                    # human-readable description
    field: Optional[str] = None     # name of the offending output field, if any
    expected: Optional[Any] = None  # what the check expected
    actual: Optional[Any] = None    # what was actually observed
    suggestion: str = ""            # actionable hint for fixing the problem
@dataclass
class ValidationResult:
    """Outcome of validating one tool result.

    ``is_valid`` starts at whatever the caller passes and is switched to
    ``False`` by :meth:`add_issue` whenever an ERROR or CRITICAL issue is
    recorded.
    """

    is_valid: bool
    issues: List[ValidationIssue] = field(default_factory=list)
    normalized_output: Optional[Any] = None  # parsed output, set by the validator on success

    @property
    def has_errors(self) -> bool:
        """Whether any recorded issue has ERROR or CRITICAL severity."""
        blocking = (ValidationSeverity.ERROR, ValidationSeverity.CRITICAL)
        return any(entry.severity in blocking for entry in self.issues)

    @property
    def has_warnings(self) -> bool:
        """Whether any recorded issue has exactly WARNING severity."""
        for entry in self.issues:
            if entry.severity == ValidationSeverity.WARNING:
                return True
        return False

    def add_issue(self, issue: ValidationIssue) -> None:
        """Append *issue*; an ERROR or CRITICAL issue also marks the result invalid."""
        self.issues.append(issue)
        blocking = (ValidationSeverity.ERROR, ValidationSeverity.CRITICAL)
        if issue.severity in blocking:
            self.is_valid = False
class ToolOutputValidator:
    """Validate a tool's output against per-field rules.

    Rules are registered with :meth:`add_rule` (which returns ``self`` so calls
    can be chained) and applied by :meth:`validate`.
    """

    def __init__(self, tool_name: str):
        self.tool_name = tool_name
        # field name -> rules registered for that field, in registration order
        self._rules: Dict[str, List[ValidationRule]] = {}
        # field name -> merged parameters of every rule on that field
        # (possible keys: "type", "min", "max", "pattern", "allowed_values")
        self._field_schemas: Dict[str, Dict] = {}

    def add_rule(self, field: str, rule: ValidationRule, **kwargs) -> "ToolOutputValidator":
        """Register *rule* for *field* and return ``self`` for chaining.

        Rule parameters are read from ``kwargs``:

        * TYPE_CHECK: ``expected_type`` (a type or a tuple of types)
        * RANGE_CHECK: ``min`` and/or ``max`` (``None`` means unbounded)
        * PATTERN_MATCH: ``pattern`` (applied with ``re.match``)
        * ENUM_CHECK: ``allowed_values``

        Parameters of different rules on the same field are merged, so e.g. a
        TYPE_CHECK and a RANGE_CHECK can coexist on one field.
        """
        self._rules.setdefault(field, []).append(rule)
        # Merge instead of replace: previously a later rule wiped the parameters
        # of an earlier one (e.g. RANGE_CHECK erased TYPE_CHECK's "type").
        schema = self._field_schemas.setdefault(field, {})

        if rule == ValidationRule.TYPE_CHECK:
            schema["type"] = kwargs.get("expected_type")
        elif rule == ValidationRule.RANGE_CHECK:
            schema["min"] = kwargs.get("min")
            schema["max"] = kwargs.get("max")
        elif rule == ValidationRule.PATTERN_MATCH:
            schema["pattern"] = kwargs.get("pattern")
        elif rule == ValidationRule.ENUM_CHECK:
            schema["allowed_values"] = kwargs.get("allowed_values")

        return self

    def validate(self, tool_result: ToolResult) -> ValidationResult:
        """Validate *tool_result* and return the collected issues.

        Returns early with a single issue when the tool reported an error
        (ERROR), produced empty output (WARNING), or produced output that could
        not be parsed (ERROR). Otherwise the parsed value is stored in
        ``normalized_output`` and the registered rules are applied to it.
        """
        result = ValidationResult(is_valid=True)

        # The tool itself failed: nothing to validate.
        if tool_result.error:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.REQUIRED_FIELD,
                severity=ValidationSeverity.ERROR,
                message=f"工具执行失败: {tool_result.error}",
                suggestion="请检查工具参数和依赖环境"
            ))
            return result

        if not tool_result.output:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.REQUIRED_FIELD,
                severity=ValidationSeverity.WARNING,
                message="工具返回空输出",
                suggestion="检查工具是否按预期工作"
            ))
            return result

        parsed_output = self._parse_output(tool_result.output)
        if isinstance(parsed_output, ValidationIssue):
            result.add_issue(parsed_output)
            return result

        result.normalized_output = parsed_output
        self._apply_rules(result, parsed_output)
        return result

    def _parse_output(self, output: str) -> Union[Any, ValidationIssue]:
        """Parse raw tool output into a Python value.

        Tries, in order: JSON; a Python literal via ``ast.literal_eval`` (which
        evaluates literals only and never executes code, e.g. "{'key': 'value'}");
        finally any non-blank text is wrapped as ``{"text": <stripped text>}``.
        Whitespace-only output yields a ``ValidationIssue`` instead.
        """
        import ast

        try:
            return json.loads(output)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            pass

        try:
            return ast.literal_eval(output)
        # literal_eval raises more than SyntaxError/ValueError: e.g. TypeError for
        # "{[1]: 2}" (unhashable key) and RecursionError/MemoryError for deep nesting.
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            pass

        text = output.strip()
        if text:
            return {"text": text}

        return ValidationIssue(
            rule=ValidationRule.JSON_FORMAT,
            severity=ValidationSeverity.ERROR,
            message="无法解析工具输出",
            actual=output[:100] if output else "空字符串",
            suggestion="工具应返回 JSON 或结构化文本"
        )

    def _apply_rules(self, result: ValidationResult, data: Any) -> None:
        """Apply every registered rule to *data*, recording issues on *result*.

        Only dict outputs are checked; any other parsed value passes untouched.
        A missing field produces an ERROR only when it carries REQUIRED_FIELD;
        either way its remaining rules are skipped.
        """
        if not isinstance(data, dict):
            return

        for field, rules in self._rules.items():
            if field not in data:
                if ValidationRule.REQUIRED_FIELD in rules:
                    result.add_issue(ValidationIssue(
                        rule=ValidationRule.REQUIRED_FIELD,
                        severity=ValidationSeverity.ERROR,
                        message=f"缺少必需字段: {field}",
                        field=field,
                        suggestion=f"工具应返回字段 '{field}'"
                    ))
                continue

            value = data[field]
            schema = self._field_schemas.get(field, {})
            for rule in rules:
                if rule == ValidationRule.TYPE_CHECK:
                    self._check_type(result, field, value, schema)
                elif rule == ValidationRule.RANGE_CHECK:
                    self._check_range(result, field, value, schema)
                elif rule == ValidationRule.PATTERN_MATCH:
                    self._check_pattern(result, field, value, schema)
                elif rule == ValidationRule.ENUM_CHECK:
                    self._check_enum(result, field, value, schema)

    @staticmethod
    def _check_type(result: ValidationResult, field: str, value: Any, schema: Dict) -> None:
        """Record an ERROR when *value* is not an instance of the expected type."""
        expected_type = schema.get("type")
        if expected_type is None:
            return  # no expected_type was supplied, so there is nothing to check
        if isinstance(value, expected_type):
            return

        if isinstance(expected_type, tuple):
            type_name = " | ".join(t.__name__ for t in expected_type)
        else:
            type_name = expected_type.__name__
        result.add_issue(ValidationIssue(
            rule=ValidationRule.TYPE_CHECK,
            severity=ValidationSeverity.ERROR,
            message=f"字段类型错误: {field}",
            field=field,
            expected=type_name,
            actual=type(value).__name__,
            suggestion=f"字段 '{field}' 应为 {type_name} 类型"
        ))

    @staticmethod
    def _check_range(result: ValidationResult, field: str, value: Any, schema: Dict) -> None:
        """Record WARNINGs when *value* lies outside the configured bounds.

        A bound that is ``None`` or was never configured is unbounded. A value
        that cannot be compared with the bounds (e.g. a string against a
        number) is recorded as an ERROR instead of raising ``TypeError``.
        """
        low = schema.get("min")
        high = schema.get("max")
        try:
            too_small = low is not None and value < low
            too_large = high is not None and value > high
        except TypeError:
            bounds = [f">= {low}"] if low is not None else []
            if high is not None:
                bounds.append(f"<= {high}")
            result.add_issue(ValidationIssue(
                rule=ValidationRule.RANGE_CHECK,
                severity=ValidationSeverity.ERROR,
                message=f"字段值无法与范围比较: {field}",
                field=field,
                expected=", ".join(bounds),
                actual=value,
                suggestion=f"字段 '{field}' 应为可比较的数值类型"
            ))
            return

        if too_small:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.RANGE_CHECK,
                severity=ValidationSeverity.WARNING,
                message=f"字段值过小: {field}",
                field=field,
                expected=f">= {low}",
                actual=value,
                suggestion=f"字段 '{field}' 应大于等于 {low}"
            ))
        if too_large:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.RANGE_CHECK,
                severity=ValidationSeverity.WARNING,
                message=f"字段值过大: {field}",
                field=field,
                expected=f"<= {high}",
                actual=value,
                suggestion=f"字段 '{field}' 应小于等于 {high}"
            ))

    @staticmethod
    def _check_pattern(result: ValidationResult, field: str, value: Any, schema: Dict) -> None:
        """Record an ERROR when ``str(value)`` does not match the pattern at its start.

        ``re.match`` anchors only at the beginning of the string; add ``$`` to the
        pattern when a full match is required.
        """
        pattern = schema.get("pattern")
        if pattern is None:
            return
        if not re.match(pattern, str(value)):
            result.add_issue(ValidationIssue(
                rule=ValidationRule.PATTERN_MATCH,
                severity=ValidationSeverity.ERROR,
                message=f"字段格式错误: {field}",
                field=field,
                expected=f"匹配模式: {pattern}",
                actual=value,
                suggestion=f"字段 '{field}' 应符合正则表达式: {pattern}"
            ))

    @staticmethod
    def _check_enum(result: ValidationResult, field: str, value: Any, schema: Dict) -> None:
        """Record an ERROR when *value* is not among the allowed values."""
        allowed = schema.get("allowed_values")
        if allowed is None:
            return
        try:
            is_allowed = value in allowed
        except TypeError:  # e.g. an unhashable list/dict value checked against a set
            is_allowed = False
        if not is_allowed:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.ENUM_CHECK,
                severity=ValidationSeverity.ERROR,
                message=f"字段值不在允许范围内: {field}",
                field=field,
                expected=allowed,
                actual=value,
                suggestion=f"字段 '{field}' 应为以下值之一: {allowed}"
            ))
class ToolErrorClassifier:
    """Map tool error messages to an ``ErrorCategory`` and suggest recovery steps."""

    @staticmethod
    def classify(tool_result: ToolResult) -> ErrorCategory:
        """Classify *tool_result* by keyword matching on its error message.

        Matching is a case-insensitive substring search over an ordered table;
        the first keyword group that matches wins. A result without an error,
        or whose error matches no keyword, is ``ErrorCategory.UNKNOWN``.
        """
        if not tool_result.error:
            return ErrorCategory.UNKNOWN

        # str() guards against non-string error payloads (e.g. exception objects).
        error_msg = str(tool_result.error).lower()

        # Specific phrases must precede generic ones. Previously "api key not found"
        # hit the "not found" VALIDATION branch before the AUTH keywords, and
        # rate-limit phrases were checked last, so "invalid request: rate limit
        # exceeded" or "403 forbidden: api rate limit exceeded" were misclassified.
        keyword_table = (
            (("rate limit", "too many requests", "quota"), ErrorCategory.RATE_LIMIT),
            (("permission", "access denied", "forbidden"), ErrorCategory.AUTH),
            (("timeout", "timed out"), ErrorCategory.TIMEOUT),
            (("network", "connection", "dns"), ErrorCategory.NETWORK),
            (("authentication", "auth", "api key", "invalid key", "unauthorized"),
             ErrorCategory.AUTH),
            (("not found", "file not found", "no such file"), ErrorCategory.VALIDATION),
            (("memory", "disk", "resource"), ErrorCategory.RESOURCE),
            (("syntax", "invalid", "malformed"), ErrorCategory.VALIDATION),
            # Broad fallback kept from the original keyword list.
            (("too many",), ErrorCategory.RATE_LIMIT),
        )
        for keywords, category in keyword_table:
            if any(kw in error_msg for kw in keywords):
                return category

        return ErrorCategory.UNKNOWN

    @staticmethod
    def get_recovery_suggestions(category: ErrorCategory, tool_name: str) -> List[str]:
        """Return recovery suggestions for *category*, mentioning *tool_name* where relevant.

        A new list is built on every call, so callers may mutate the result.
        Categories without an entry get a generic "see the error details" hint.
        """
        base_suggestions = {
            ErrorCategory.AUTH: [
                f"检查 {tool_name} 工具所需的权限",
                "确认当前用户有足够的访问权限",
                "检查 API Key 或认证令牌是否有效"
            ],
            ErrorCategory.TIMEOUT: [
                f"增加 {tool_name} 工具的超时时间",
                "检查目标服务是否正常运行",
                "考虑使用更轻量的查询参数"
            ],
            ErrorCategory.NETWORK: [
                "检查网络连接",
                "确认目标服务地址是否正确",
                "尝试使用代理或 VPN"
            ],
            ErrorCategory.VALIDATION: [
                f"检查 {tool_name} 工具的输入参数",
                "确认文件路径或资源是否存在",
                "验证输入数据的格式和类型"
            ],
            ErrorCategory.RESOURCE: [
                "清理磁盘空间",
                "增加系统内存",
                "减少并发请求数量"
            ],
            ErrorCategory.RATE_LIMIT: [
                "降低请求频率",
                "使用指数退避重试",
                "检查 API 配额限制"
            ],
            ErrorCategory.UNKNOWN: [
                f"查看 {tool_name} 工具的详细日志",
                "检查工具依赖是否完整",
                "尝试重启相关服务"
            ]
        }

        return base_suggestions.get(category, ["请查看详细错误信息"])
def validate_tool_output(tool_name: str, tool_result: ToolResult,
                         expected_schema: Optional[Dict] = None) -> ValidationResult:
    """
    Convenience wrapper: build a ToolOutputValidator from a schema and run it.

    Args:
        tool_name: name of the tool whose output is validated
        tool_result: the tool's execution result
        expected_schema: optional mapping of field name -> spec dict. Recognised
            spec keys are "type", "required" (truthy), "pattern", "enum" and
            "min"/"max"; rules are registered in that order.

    Returns:
        ValidationResult: outcome of the validation
    """
    validator = ToolOutputValidator(tool_name)

    for field_name, spec in (expected_schema or {}).items():
        if "type" in spec:
            validator.add_rule(field_name, ValidationRule.TYPE_CHECK,
                               expected_type=spec["type"])

        needs_field = "required" in spec and spec["required"]
        if needs_field:
            validator.add_rule(field_name, ValidationRule.REQUIRED_FIELD)

        if "pattern" in spec:
            validator.add_rule(field_name, ValidationRule.PATTERN_MATCH,
                               pattern=spec["pattern"])

        if "enum" in spec:
            validator.add_rule(field_name, ValidationRule.ENUM_CHECK,
                               allowed_values=spec["enum"])

        # A single RANGE_CHECK carries both bounds; an absent bound is passed as None.
        has_bounds = "min" in spec or "max" in spec
        if has_bounds:
            validator.add_rule(field_name, ValidationRule.RANGE_CHECK,
                               min=spec.get("min"),
                               max=spec.get("max"))

    return validator.validate(tool_result)
def classify_tool_error(tool_result: ToolResult, tool_name: str = "unknown") -> Dict[str, Any]:
    """
    Classify a tool error and return structured information about it.

    Args:
        tool_result: the tool's execution result
        tool_name: tool name inserted into tool-specific recovery suggestions.
            Defaults to "unknown", which reproduces the previous output.

    Returns:
        Dict with keys "category" (the ErrorCategory member name),
        "error_message" (``tool_result.error`` unchanged), "suggestions"
        (list of str) and "severity" ("warning" when the category is UNKNOWN,
        otherwise "error").
    """
    category = ToolErrorClassifier.classify(tool_result)
    # Previously hard-coded to "unknown", producing suggestions such as
    # "check the permissions of the unknown tool" even when the name was known.
    suggestions = ToolErrorClassifier.get_recovery_suggestions(category, tool_name)

    return {
        "category": category.name,
        "error_message": tool_result.error,
        "suggestions": suggestions,
        "severity": "error" if category != ErrorCategory.UNKNOWN else "warning"
    }