Coverage for agentos/validation/schema_enforcer.py: 28%
156 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""AgentOS v1.3.9 - Schema Enforcer 模块。
3对 Agent 输出执行 Pydantic schema 校验,校验失败时自动修复/重试。
4支持 JSON 修复、字段回退、LLM 辅助修正三种修复策略。
5"""
7from __future__ import annotations
9import asyncio
10import json
11import logging
12from dataclasses import dataclass, field
13from enum import Enum, auto
14from typing import Any, Callable
# Module-level logger; named after the importing module so log records can be
# filtered per package through the standard logging hierarchy.
logger = logging.getLogger(__name__)
class FixStrategy(Enum):
    """Repair strategies that can be attempted after schema validation fails."""

    # Parse / clean up JSON text (dispatched to ``SchemaEnforcer._json_repair``).
    JSON_REPAIR = auto()
    # Keep known fields and fill missing ones with schema defaults
    # (dispatched to ``SchemaEnforcer._field_fallback``).
    FIELD_FALLBACK = auto()
    # Ask an LLM to correct the output (dispatched to ``SchemaEnforcer._llm_fix``).
    LLM_ASSISTED = auto()
    # NOTE(review): no branch in ``SchemaEnforcer._apply_fix`` handles this
    # member; presumably it is meant to signal "raise instead of repairing" --
    # confirm the intended semantics before relying on it.
    RAISE = auto()
@dataclass
class EnforcerResult:
    """Outcome of a single schema-enforcement run.

    Attributes:
        is_valid: True when the output validated, either directly or after a
            repair / fallback.
        original_output: The value that was passed in, untouched.
        repaired_output: The validated (possibly repaired) object, or None when
            nothing valid could be produced.
        errors: Messages collected from every failed validation / repair step.
        fix_strategy_used: Strategy credited with the successful repair, or
            None when no repair was needed or none succeeded.
        fix_attempts: 1-based retry round in which the repair succeeded
            (0 when no repair took place).
    """

    is_valid: bool
    original_output: Any
    repaired_output: Any | None = None
    errors: list[str] = field(default_factory=list)
    fix_strategy_used: FixStrategy | None = None
    fix_attempts: int = 0
@dataclass
class EnforcerConfig:
    """Configuration for the schema enforcer.

    Attributes:
        max_retries: Number of rounds in which the whole ``strategy_order``
            list is tried.
        strategy_order: Repair strategies, tried in this order in every round.
        llm_fix_prompt_template: Prompt template for LLM-assisted repair; an
            empty string means no template is configured.
        default_value_fallback: When True, an all-defaults instance of the
            schema is built as a last resort after every repair has failed.
        log_rejections: When True, the enforcer emits log records about its
            repair activity.
    """

    max_retries: int = 3
    strategy_order: list[FixStrategy] = field(
        default_factory=lambda: [
            FixStrategy.JSON_REPAIR,
            FixStrategy.FIELD_FALLBACK,
            FixStrategy.LLM_ASSISTED,
        ]
    )
    llm_fix_prompt_template: str = ""
    default_value_fallback: bool = True
    log_rejections: bool = True
@dataclass
class EnforcerStats:
    """Running counters describing what the enforcer has done so far.

    Attributes:
        total_checks: Number of enforcement runs started.
        total_rejections: Number of runs whose raw output failed validation.
        total_repairs: Number of rejected outputs that ended up valid
            (through a repair strategy or the default-value fallback).
        repairs_by_strategy: Successful repairs keyed by strategy name.
    """

    total_checks: int = 0
    total_rejections: int = 0
    total_repairs: int = 0
    repairs_by_strategy: dict[str, int] = field(default_factory=dict)
class SchemaEnforcer:
    """Validate agent output against a Pydantic schema and auto-repair failures.

    Core flow:
        1. Try ``schema_model.model_validate`` on the raw output.
        2. On failure, run the strategies in ``config.strategy_order`` for up
           to ``config.max_retries`` rounds, re-validating every candidate.
        3. If nothing validates and ``config.default_value_fallback`` is set,
           return an all-defaults instance of the schema (best effort).
    """

    def __init__(self, config: EnforcerConfig | None = None):
        """Create an enforcer with ``config`` (defaults apply when omitted)."""
        self.config = config or EnforcerConfig()
        self.stats = EnforcerStats()

    async def enforce(
        self,
        output: dict | str | Any,
        schema_model: type,
        context: dict | None = None,
    ) -> EnforcerResult:
        """Validate a single output against ``schema_model``, repairing if needed.

        Args:
            output: Raw agent output (typically a dict or a JSON string).
            schema_model: Class exposing ``model_validate`` / ``model_fields``
                (a Pydantic v2 model).
            context: Opaque extra data forwarded to the LLM-assisted strategy.

        Returns:
            An ``EnforcerResult``. ``is_valid`` is True when the output
            validated directly, after a repair, or through the default-value
            fallback; ``errors`` lists every failure met on the way.
        """
        self.stats.total_checks += 1
        errors: list[str] = []

        try:
            validated = schema_model.model_validate(output)
        except Exception as e:
            errors.append(str(e))
            self.stats.total_rejections += 1
        else:
            return EnforcerResult(is_valid=True, original_output=output, repaired_output=validated)

        # ``result.errors`` aliases ``errors`` so later failures show up in it.
        result = EnforcerResult(is_valid=False, original_output=output, errors=errors)

        # Strategies are chained: whatever one strategy produced becomes the
        # input of the next one, so e.g. a JSON string parsed by JSON_REPAIR
        # can still be cleaned up by FIELD_FALLBACK.
        candidate = output
        for attempt in range(self.config.max_retries):
            for strategy in self.config.strategy_order:
                try:
                    repaired = await self._apply_fix(strategy, candidate, schema_model, errors, context)
                    if repaired is None:
                        continue
                    candidate = repaired
                    validated = schema_model.model_validate(repaired)
                except Exception as fix_error:
                    errors.append(f"[{strategy.name}] {fix_error}")
                    continue

                self._record_repair(strategy.name)
                result.is_valid = True
                result.repaired_output = validated
                result.fix_strategy_used = strategy
                result.fix_attempts = attempt + 1
                if self.config.log_rejections:
                    logger.info(
                        "Schema fixed via %s (attempt %d/%d)",
                        strategy.name,
                        attempt + 1,
                        self.config.max_retries,
                    )
                return result

        if self.config.default_value_fallback:
            try:
                fallback = self._build_fallback(schema_model)
            except Exception as fallback_error:
                # Still best effort, but leave a trace instead of hiding it.
                errors.append(f"[FALLBACK] {fallback_error}")
                if self.config.log_rejections:
                    logger.warning(
                        "Default-value fallback failed for %s: %s",
                        getattr(schema_model, "__name__", schema_model),
                        fallback_error,
                    )
            else:
                self._record_repair("FALLBACK")
                result.is_valid = True
                result.repaired_output = fallback
                result.fix_strategy_used = FixStrategy.FIELD_FALLBACK
                result.fix_attempts = self.config.max_retries

        return result

    def _record_repair(self, key: str) -> None:
        """Count one successful repair under ``key`` in the statistics."""
        self.stats.total_repairs += 1
        by_strategy = self.stats.repairs_by_strategy
        by_strategy[key] = by_strategy.get(key, 0) + 1

    async def _apply_fix(
        self, strategy: FixStrategy, output: Any, model: type, errors: list[str], context: dict | None
    ) -> Any:
        """Dispatch ``strategy`` to its implementation.

        Returns the repaired candidate, or None when the strategy could not
        produce one.
        """
        if strategy == FixStrategy.JSON_REPAIR:
            return self._json_repair(output)
        elif strategy == FixStrategy.FIELD_FALLBACK:
            return self._field_fallback(output, model, errors)
        elif strategy == FixStrategy.LLM_ASSISTED:
            return await self._llm_fix(output, model, errors, context)
        # NOTE(review): FixStrategy.RAISE (and any unknown member) lands here
        # and is treated as "no repair available" -- confirm this is intended.
        return None

    def _json_repair(self, output: Any) -> Any:
        """Parse ``output`` as JSON, fixing common formatting problems.

        Dicts are returned unchanged. Strings have a surrounding markdown code
        fence removed and are then parsed with increasingly aggressive
        repairs: as-is, with trailing commas removed, and finally with single
        quotes turned into double quotes as well. Trying the strict parse
        first keeps valid JSON that contains apostrophes intact.

        Returns:
            The parsed value, or None when ``output`` is neither a dict nor a
            string or when no variant parses.
        """
        import re

        if isinstance(output, dict):
            return output
        if not isinstance(output, str):
            return None

        text = output.strip()
        # Strip a markdown code fence (```lang ... ```) around the payload.
        if text.startswith("```"):
            lines = text.split("\n")[1:]
            if lines and lines[-1].strip() == "```":
                lines = lines[:-1]
            text = "\n".join(lines)

        trailing_comma = r",(\s*[}\]])"
        candidates = (
            text,
            re.sub(trailing_comma, r"\1", text),
            re.sub(trailing_comma, r"\1", text.replace("'", '"')),
        )
        for candidate in candidates:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue
        return None

    def _field_fallback(self, output: Any, model: type, errors: list[str]) -> dict | None:
        """Keep the schema's known fields from ``output`` and default the rest.

        Keys that are not schema fields are dropped. Fields missing from
        ``output`` take their declared default or default factory; required
        fields without either are left out.

        Returns:
            The cleaned dict, or None when ``output`` is not a dict, nothing
            could be extracted, or inspecting the schema failed.
        """
        from pydantic_core import PydanticUndefined

        if not isinstance(output, dict):
            return None
        try:
            clean: dict = {}
            for key, finfo in model.model_fields.items():
                if key in output:
                    clean[key] = output[key]
                elif finfo.default is not PydanticUndefined:
                    clean[key] = finfo.default
                elif finfo.default_factory is not None:
                    clean[key] = finfo.default_factory()
        except Exception:
            # Best effort: a schema that cannot be inspected means "no repair".
            return None
        return clean or None

    async def _llm_fix(
        self, output: Any, model: type, errors: list[str], context: dict | None
    ) -> dict | None:
        """Placeholder for LLM-assisted repair.

        No LLM call is made here: a warning is logged when a prompt template
        is configured, and None is always returned.
        """
        if self.config.llm_fix_prompt_template:
            logger.warning("LLM-assisted fix requires llm_call callback (not implemented inline).")
        return None

    def _build_fallback(self, model: type) -> Any:
        """Build an instance of ``model`` populated only with default values.

        Declared defaults and default factories are used where present.
        Required fields annotated as str / int / float / bool / list / dict
        get an empty value of that type; other required fields are left
        unset, in which case constructing the model may raise.
        """
        from pydantic_core import PydanticUndefined

        kwargs: dict = {}
        for key, finfo in model.model_fields.items():
            if finfo.default is not PydanticUndefined:
                kwargs[key] = finfo.default
                continue
            if finfo.default_factory is not None:
                kwargs[key] = finfo.default_factory()
                continue

            annotation = finfo.annotation
            # ``list[int].__origin__`` is ``list``; bare types have no origin.
            origin = getattr(annotation, "__origin__", None)
            if annotation is str:
                kwargs[key] = ""
            elif annotation is int:
                kwargs[key] = 0
            elif annotation is float:
                kwargs[key] = 0.0
            elif annotation is bool:
                kwargs[key] = False
            elif annotation is list or origin is list:
                kwargs[key] = []
            elif annotation is dict or origin is dict:
                kwargs[key] = {}
        return model(**kwargs)

    async def enforce_batch(
        self,
        outputs: list[dict | str],
        schema_model: type,
        context: dict | None = None,
    ) -> list[EnforcerResult]:
        """Enforce the schema on every output concurrently.

        Results are returned in the same order as ``outputs``.
        """
        tasks = [self.enforce(out, schema_model, context) for out in outputs]
        return await asyncio.gather(*tasks)
# Public API of this module.
__all__ = [
    "SchemaEnforcer",
    "EnforcerConfig",
    "EnforcerResult",
    "EnforcerStats",
    "FixStrategy",
]