Coverage for agentos/tools/validation.py: 39%

165 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 17:43 +0800

1""" 

2v1.15.0 — 工具输出验证层:结构化结果验证 + 错误分类 + 自动修复建议。 

3 

4核心功能: 

51. 验证工具返回结果是否符合预期格式 

62. 自动分类工具执行错误 

73. 提供可操作的修复建议 

84. 集成到 ToolExecutor 中,提升 Agent 鲁棒性 

9""" 

10 

11from __future__ import annotations 

12 

13import json 

14import re 

15from dataclasses import dataclass, field 

16from enum import Enum, auto 

17from typing import Any, Optional, Union, List, Dict 

18 

19from .base import ToolResult 

20from ..errors.handler import ErrorCategory, ErrorFormatter 

21 

22 

class ValidationSeverity(str, Enum):
    """Severity level attached to a validation issue."""

    # Informational hint only.
    INFO = "info"
    # Warning: something may be wrong, but processing can continue.
    WARNING = "warning"
    # Error: needs to be fixed.
    ERROR = "error"
    # Critical error: must be fixed.
    CRITICAL = "critical"

29 

30 

class ValidationRule(str, Enum):
    """Kinds of validation rule that can be attached to an output field."""

    # JSON format validation.
    JSON_FORMAT = "json_format"
    # Required-field presence check.
    REQUIRED_FIELD = "required_field"
    # Type check.
    TYPE_CHECK = "type_check"
    # Range check.
    RANGE_CHECK = "range_check"
    # Regular-expression match.
    PATTERN_MATCH = "pattern_match"
    # Length check.
    LENGTH_CHECK = "length_check"
    # Enumerated-value check.
    ENUM_CHECK = "enum_check"
    # Structure check.
    STRUCTURE_CHECK = "structure_check"

41 

42 

@dataclass
class ValidationIssue:
    """One problem found while validating a tool's output."""

    # Rule that produced the issue.
    rule: ValidationRule
    # How serious the issue is.
    severity: ValidationSeverity
    # Human-readable description of the problem.
    message: str
    # Name of the offending output field, when the issue is field-specific.
    field: Optional[str] = None
    # What the rule expected to see.
    expected: Optional[Any] = None
    # What was actually found.
    actual: Optional[Any] = None
    # Actionable hint for fixing the problem.
    suggestion: str = ""

53 

54 

@dataclass
class ValidationResult:
    """Outcome of validating one tool result."""

    # Overall verdict; flipped to False once an ERROR/CRITICAL issue is added.
    is_valid: bool
    # Every issue recorded so far, in insertion order.
    issues: List[ValidationIssue] = field(default_factory=list)
    # Parsed form of the tool output, when parsing succeeded.
    normalized_output: Optional[Any] = None

    @property
    def has_errors(self) -> bool:
        """Whether any recorded issue has ERROR or CRITICAL severity."""
        for recorded in self.issues:
            if recorded.severity in (ValidationSeverity.ERROR, ValidationSeverity.CRITICAL):
                return True
        return False

    @property
    def has_warnings(self) -> bool:
        """Whether any recorded issue has WARNING severity."""
        for recorded in self.issues:
            if recorded.severity == ValidationSeverity.WARNING:
                return True
        return False

    def add_issue(self, issue: ValidationIssue) -> None:
        """Record ``issue``; an ERROR or CRITICAL issue marks the result invalid."""
        self.issues.append(issue)
        if issue.severity not in (ValidationSeverity.ERROR, ValidationSeverity.CRITICAL):
            return
        self.is_valid = False

75 

76 

class ToolOutputValidator:
    """Validate a tool's output against per-field rules.

    Rules are registered per field with :meth:`add_rule` (chainable) and
    applied by :meth:`validate` to the parsed output when it is a dict.
    """

    def __init__(self, tool_name: str):
        """Create a validator with no rules for the tool named ``tool_name``."""
        self.tool_name = tool_name
        # field name -> rules registered for that field, in registration order
        self._rules: Dict[str, List[ValidationRule]] = {}
        # field name -> merged parameters of every rule registered for it
        self._field_schemas: Dict[str, Dict] = {}

    def add_rule(self, field: str, rule: ValidationRule, **kwargs) -> "ToolOutputValidator":
        """Register ``rule`` for ``field`` and return ``self`` for chaining.

        Rule parameters are read from ``kwargs``:

        * TYPE_CHECK: ``expected_type``
        * RANGE_CHECK: ``min`` and/or ``max``
        * PATTERN_MATCH: ``pattern``
        * ENUM_CHECK: ``allowed_values``

        Other rule kinds take no parameters.
        """
        self._rules.setdefault(field, []).append(rule)

        if rule == ValidationRule.TYPE_CHECK:
            params = {"type": kwargs.get("expected_type")}
        elif rule == ValidationRule.RANGE_CHECK:
            params = {"min": kwargs.get("min"), "max": kwargs.get("max")}
        elif rule == ValidationRule.PATTERN_MATCH:
            params = {"pattern": kwargs.get("pattern")}
        elif rule == ValidationRule.ENUM_CHECK:
            params = {"allowed_values": kwargs.get("allowed_values")}
        else:
            params = {}

        if params:
            # Merge instead of overwriting: a field may carry several
            # parameterised rules (e.g. a type and a pattern), and each check
            # must still find its own parameters at validation time.
            self._field_schemas.setdefault(field, {}).update(params)

        return self

    def validate(self, tool_result: ToolResult) -> ValidationResult:
        """Validate ``tool_result`` and return the collected issues.

        A failed execution (truthy ``error``) yields a single ERROR issue, and
        an empty ``output`` yields a single WARNING issue; in both cases no
        rules are applied. Otherwise the output is parsed, stored in
        ``normalized_output``, and the registered rules are applied to it.
        """
        result = ValidationResult(is_valid=True)

        # The tool itself failed: report that and stop.
        if tool_result.error:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.REQUIRED_FIELD,
                severity=ValidationSeverity.ERROR,
                message=f"工具执行失败: {tool_result.error}",
                suggestion="请检查工具参数和依赖环境"
            ))
            return result

        if not tool_result.output:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.REQUIRED_FIELD,
                severity=ValidationSeverity.WARNING,
                message="工具返回空输出",
                suggestion="检查工具是否按预期工作"
            ))
            return result

        # Try to parse the output; a ValidationIssue signals failure.
        parsed_output = self._parse_output(tool_result.output)
        if isinstance(parsed_output, ValidationIssue):
            result.add_issue(parsed_output)
            return result

        result.normalized_output = parsed_output

        # Apply the registered per-field rules.
        self._apply_rules(result, parsed_output)

        return result

    def _parse_output(self, output: str) -> Union[Any, ValidationIssue]:
        """Parse ``output`` as JSON, then as a Python literal, then as text.

        Returns the parsed value, ``{"text": <stripped output>}`` for
        non-blank text that is neither, or a ``ValidationIssue`` when the
        output is blank.
        """
        import ast

        # First choice: JSON.
        try:
            return json.loads(output)
        except json.JSONDecodeError:
            pass

        # Second choice: a Python literal such as "{'key': 'value'}".
        # literal_eval only evaluates literals, but it can raise more than
        # SyntaxError/ValueError on malformed input (e.g. TypeError for an
        # unhashable dict key), so every documented failure mode is caught.
        try:
            return ast.literal_eval(output)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            pass

        # Fall back to treating the output as plain text.
        if output.strip():
            return {"text": output.strip()}

        return ValidationIssue(
            rule=ValidationRule.JSON_FORMAT,
            severity=ValidationSeverity.ERROR,
            message="无法解析工具输出",
            actual=output[:100] if output else "空字符串",
            suggestion="工具应返回 JSON 或结构化文本"
        )

    def _apply_rules(self, result: ValidationResult, data: Any) -> None:
        """Apply every registered rule to ``data``, adding issues to ``result``.

        Non-dict data is left unchecked. A missing field is only reported
        when REQUIRED_FIELD is registered for it; its other rules are skipped.
        """
        if not isinstance(data, dict):
            return

        for field, rules in self._rules.items():
            if field not in data:
                if ValidationRule.REQUIRED_FIELD in rules:
                    result.add_issue(ValidationIssue(
                        rule=ValidationRule.REQUIRED_FIELD,
                        severity=ValidationSeverity.ERROR,
                        message=f"缺少必需字段: {field}",
                        field=field,
                        suggestion=f"工具应返回字段 '{field}'"
                    ))
                continue

            value = data[field]
            schema = self._field_schemas.get(field, {})

            # Compared with == (not a dict lookup) so that plain-string rule
            # values keep matching the str-based enum members.
            for rule in rules:
                if rule == ValidationRule.TYPE_CHECK:
                    self._check_type(result, field, value, schema)
                elif rule == ValidationRule.RANGE_CHECK:
                    self._check_range(result, field, value, schema)
                elif rule == ValidationRule.PATTERN_MATCH:
                    self._check_pattern(result, field, value, schema)
                elif rule == ValidationRule.ENUM_CHECK:
                    self._check_enum(result, field, value, schema)

    @staticmethod
    def _type_name(expected_type: Any) -> str:
        """Return a display name for a type or a tuple of types."""
        if isinstance(expected_type, tuple):
            return " | ".join(ToolOutputValidator._type_name(t) for t in expected_type)
        return getattr(expected_type, "__name__", str(expected_type))

    def _check_type(self, result: ValidationResult, field: str, value: Any, schema: Dict) -> None:
        """Add an ERROR issue when ``value`` is not an instance of the expected type."""
        expected_type = schema["type"]
        if isinstance(value, expected_type):
            return
        type_name = self._type_name(expected_type)
        result.add_issue(ValidationIssue(
            rule=ValidationRule.TYPE_CHECK,
            severity=ValidationSeverity.ERROR,
            message=f"字段类型错误: {field}",
            field=field,
            expected=type_name,
            actual=type(value).__name__,
            suggestion=f"字段 '{field}' 应为 {type_name} 类型"
        ))

    def _check_range(self, result: ValidationResult, field: str, value: Any, schema: Dict) -> None:
        """Add WARNING issues when ``value`` is outside the configured bounds.

        A bound of ``None`` means that side is unbounded. A value that cannot
        be compared with the bounds is reported as a WARNING instead of
        raising.
        """
        lower = schema.get("min")
        upper = schema.get("max")
        try:
            too_small = lower is not None and value < lower
            too_large = upper is not None and value > upper
        except TypeError:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.RANGE_CHECK,
                severity=ValidationSeverity.WARNING,
                message=f"字段值无法进行范围比较: {field}",
                field=field,
                actual=value,
                suggestion=f"字段 '{field}' 应为可比较的数值类型"
            ))
            return

        if too_small:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.RANGE_CHECK,
                severity=ValidationSeverity.WARNING,
                message=f"字段值过小: {field}",
                field=field,
                expected=f">= {lower}",
                actual=value,
                suggestion=f"字段 '{field}' 应大于等于 {lower}"
            ))
        if too_large:
            result.add_issue(ValidationIssue(
                rule=ValidationRule.RANGE_CHECK,
                severity=ValidationSeverity.WARNING,
                message=f"字段值过大: {field}",
                field=field,
                expected=f"<= {upper}",
                actual=value,
                suggestion=f"字段 '{field}' 应小于等于 {upper}"
            ))

    def _check_pattern(self, result: ValidationResult, field: str, value: Any, schema: Dict) -> None:
        """Add an ERROR issue when ``str(value)`` does not match the pattern at its start."""
        pattern = schema["pattern"]
        if re.match(pattern, str(value)):
            return
        result.add_issue(ValidationIssue(
            rule=ValidationRule.PATTERN_MATCH,
            severity=ValidationSeverity.ERROR,
            message=f"字段格式错误: {field}",
            field=field,
            expected=f"匹配模式: {pattern}",
            actual=value,
            suggestion=f"字段 '{field}' 应符合正则表达式: {pattern}"
        ))

    def _check_enum(self, result: ValidationResult, field: str, value: Any, schema: Dict) -> None:
        """Add an ERROR issue when ``value`` is not one of the allowed values."""
        allowed = schema["allowed_values"]
        if value in allowed:
            return
        result.add_issue(ValidationIssue(
            rule=ValidationRule.ENUM_CHECK,
            severity=ValidationSeverity.ERROR,
            message=f"字段值不在允许范围内: {field}",
            field=field,
            expected=allowed,
            actual=value,
            suggestion=f"字段 '{field}' 应为以下值之一: {allowed}"
        ))

250 

251 

class ToolErrorClassifier:
    """Maps tool error messages to error categories and recovery hints."""

    @staticmethod
    def classify(tool_result: ToolResult) -> ErrorCategory:
        """Pick an error category from keywords in ``tool_result.error``.

        Matching is a case-insensitive substring search. Keyword groups are
        tried in the order listed and the first group with a hit wins. No
        error, or no hit, yields ``ErrorCategory.UNKNOWN``.
        """
        if not tool_result.error:
            return ErrorCategory.UNKNOWN

        lowered = tool_result.error.lower()

        # (keywords, category attribute name), in precedence order. The
        # category is resolved by name only for the group that matches.
        keyword_groups = (
            (("permission", "access denied", "forbidden"), "AUTH"),
            (("timeout", "timed out"), "TIMEOUT"),
            (("network", "connection", "dns"), "NETWORK"),
            (("not found", "file not found", "no such file"), "VALIDATION"),
            (("authentication", "auth", "api key", "invalid key", "unauthorized"), "AUTH"),
            (("memory", "disk", "resource"), "RESOURCE"),
            (("syntax", "invalid", "malformed"), "VALIDATION"),
            (("rate limit", "too many", "quota"), "RATE_LIMIT"),
        )
        for keywords, category_name in keyword_groups:
            for keyword in keywords:
                if keyword in lowered:
                    return getattr(ErrorCategory, category_name)

        return ErrorCategory.UNKNOWN

    @staticmethod
    def get_recovery_suggestions(category: ErrorCategory, tool_name: str) -> List[str]:
        """Return recovery suggestions for ``category``.

        Some suggestions mention ``tool_name``. A category without an entry
        gets a single generic suggestion.
        """
        catalog = {
            ErrorCategory.AUTH: [
                f"检查 {tool_name} 工具所需的权限",
                "确认当前用户有足够的访问权限",
                "检查 API Key 或认证令牌是否有效",
            ],
            ErrorCategory.TIMEOUT: [
                f"增加 {tool_name} 工具的超时时间",
                "检查目标服务是否正常运行",
                "考虑使用更轻量的查询参数",
            ],
            ErrorCategory.NETWORK: [
                "检查网络连接",
                "确认目标服务地址是否正确",
                "尝试使用代理或 VPN",
            ],
            ErrorCategory.VALIDATION: [
                f"检查 {tool_name} 工具的输入参数",
                "确认文件路径或资源是否存在",
                "验证输入数据的格式和类型",
            ],
            ErrorCategory.RESOURCE: [
                "清理磁盘空间",
                "增加系统内存",
                "减少并发请求数量",
            ],
            ErrorCategory.RATE_LIMIT: [
                "降低请求频率",
                "使用指数退避重试",
                "检查 API 配额限制",
            ],
            ErrorCategory.UNKNOWN: [
                f"查看 {tool_name} 工具的详细日志",
                "检查工具依赖是否完整",
                "尝试重启相关服务",
            ],
        }

        if category in catalog:
            return catalog[category]
        return ["请查看详细错误信息"]

322 

323 

def validate_tool_output(tool_name: str, tool_result: ToolResult,
                         expected_schema: Optional[Dict] = None) -> ValidationResult:
    """
    Convenience wrapper: build a validator from a schema dict and run it.

    Each ``expected_schema`` entry maps a field name to a dict that may
    contain ``type``, ``required``, ``pattern``, ``enum``, ``min`` and
    ``max``; every key present registers the matching rule for that field.

    Args:
        tool_name: Name of the tool.
        tool_result: Result of the tool execution.
        expected_schema: Expected output schema (optional).

    Returns:
        ValidationResult: The validation outcome.
    """
    checker = ToolOutputValidator(tool_name)

    for field_name, spec in (expected_schema or {}).items():
        if "type" in spec:
            checker.add_rule(field_name, ValidationRule.TYPE_CHECK,
                             expected_type=spec["type"])
        if spec.get("required"):
            checker.add_rule(field_name, ValidationRule.REQUIRED_FIELD)
        if "pattern" in spec:
            checker.add_rule(field_name, ValidationRule.PATTERN_MATCH,
                             pattern=spec["pattern"])
        if "enum" in spec:
            checker.add_rule(field_name, ValidationRule.ENUM_CHECK,
                             allowed_values=spec["enum"])
        if "min" in spec or "max" in spec:
            checker.add_rule(field_name, ValidationRule.RANGE_CHECK,
                             min=spec.get("min"),
                             max=spec.get("max"))

    return checker.validate(tool_result)

358 

359 

def classify_tool_error(tool_result: ToolResult, tool_name: str = "unknown") -> Dict[str, Any]:
    """
    Classify a tool error and return structured information about it.

    Args:
        tool_result: Result of the tool execution.
        tool_name: Name of the tool, used in the recovery suggestions.
            Defaults to "unknown", which keeps the previous behaviour for
            callers that do not pass it.

    Returns:
        Dict: ``category`` (the category's name), ``error_message`` (the raw
        error from ``tool_result``), ``suggestions`` (recovery hints) and
        ``severity`` ("warning" for the UNKNOWN category, "error" otherwise).
    """
    category = ToolErrorClassifier.classify(tool_result)
    # Pass the real tool name through so the hints can name the tool.
    suggestions = ToolErrorClassifier.get_recovery_suggestions(category, tool_name)

    return {
        "category": category.name,
        "error_message": tool_result.error,
        "suggestions": suggestions,
        "severity": "error" if category != ErrorCategory.UNKNOWN else "warning"
    }

377 "severity": "error" if category != ErrorCategory.UNKNOWN else "warning" 

378 }