Coverage for agentos/validation/schema_enforcer.py: 28%

156 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""AgentOS v1.3.9 - Schema Enforcer 模块。 

2 

3对 Agent 输出执行 Pydantic schema 校验,校验失败时自动修复/重试。 

4支持 JSON 修复、字段回退、LLM 辅助修正三种修复策略。 

5""" 

6 

7from __future__ import annotations 

8 

9import asyncio 

10import json 

11import logging 

12from dataclasses import dataclass, field 

13from enum import Enum, auto 

14from typing import Any, Callable 

15 

16logger = logging.getLogger(__name__) 

17 

18 

class FixStrategy(Enum):
    """Repair strategies that ``SchemaEnforcer`` can apply to invalid output."""

    # Dispatched to ``SchemaEnforcer._json_repair``.
    JSON_REPAIR = auto()

    # Dispatched to ``SchemaEnforcer._field_fallback``; also reported as the
    # strategy used when the all-defaults fallback object is returned.
    FIELD_FALLBACK = auto()

    # Dispatched to ``SchemaEnforcer._llm_fix``.
    LLM_ASSISTED = auto()

    # No branch in ``SchemaEnforcer._apply_fix`` handles this member, so it
    # yields no repair there.
    # NOTE(review): presumably meant to request raise-on-failure -- confirm.
    RAISE = auto()

26 

27 

@dataclass
class EnforcerResult:
    """Outcome of one ``SchemaEnforcer.enforce`` call."""

    # True when the output validated, either directly or after a repair.
    is_valid: bool

    # The value exactly as it was handed to ``enforce``.
    original_output: Any

    # The validated (or fallback) object on success; ``None`` otherwise.
    repaired_output: Any | None = None

    # Validation / repair error messages collected along the way.
    errors: list[str] = field(default_factory=list)

    # Strategy that produced ``repaired_output``; ``None`` if no fix was used.
    fix_strategy_used: FixStrategy | None = None

    # 1-based retry round in which the fix succeeded; 0 if no fix was applied.
    fix_attempts: int = 0

38 

39 

@dataclass
class EnforcerConfig:
    """Tunable settings for ``SchemaEnforcer``."""

    # Number of rounds ``enforce`` makes over ``strategy_order``.
    max_retries: int = 3

    # Strategies tried, in order, within every round.
    strategy_order: list[FixStrategy] = field(
        default_factory=lambda: [
            FixStrategy.JSON_REPAIR,
            FixStrategy.FIELD_FALLBACK,
            FixStrategy.LLM_ASSISTED,
        ]
    )

    # Prompt template for the LLM-assisted fix; empty string means unset.
    llm_fix_prompt_template: str = ""

    # When True, ``enforce`` builds an object from defaults as a last resort.
    default_value_fallback: bool = True

    # Enables log output from ``SchemaEnforcer.enforce``.
    log_rejections: bool = True

51 

52 

@dataclass
class EnforcerStats:
    """Running counters kept by a ``SchemaEnforcer`` instance."""

    # Incremented once per ``enforce`` call.
    total_checks: int = 0

    # Incremented when the initial validation of an output fails.
    total_rejections: int = 0

    # Incremented whenever a repair (strategy or default fallback) succeeds.
    total_repairs: int = 0

    # Successful repairs keyed by strategy name (or "FALLBACK").
    repairs_by_strategy: dict[str, int] = field(default_factory=dict)

61 

62 

class SchemaEnforcer:
    """Validate agent output against a Pydantic schema and auto-repair failures.

    Core flow:
        1. Try ``schema_model.model_validate`` on the raw output.
        2. On failure, run the strategies in ``config.strategy_order`` for up
           to ``config.max_retries`` rounds, re-validating each repaired value.
        3. If nothing validates and ``config.default_value_fallback`` is set,
           return a best-effort object built from field defaults.
    """

    def __init__(self, config: EnforcerConfig | None = None):
        """Create an enforcer using *config*, or ``EnforcerConfig()`` if omitted."""
        self.config = config or EnforcerConfig()
        self.stats = EnforcerStats()

    async def enforce(
        self,
        output: dict | str | Any,
        schema_model: type,
        context: dict | None = None,
    ) -> EnforcerResult:
        """Validate one output against *schema_model*, repairing it if needed.

        Args:
            output: Raw agent output (dict, JSON-ish string, or anything
                ``schema_model.model_validate`` accepts).
            schema_model: Class exposing ``model_validate`` (and, for the
                fallback strategies, ``model_fields``).
            context: Optional extra data forwarded to the LLM-assisted fix.

        Returns:
            An ``EnforcerResult``. ``is_valid`` is True when the output
            validated directly, after a repair, or via the default-value
            fallback; otherwise ``errors`` lists what went wrong.
        """
        self.stats.total_checks += 1
        errors: list[str] = []

        try:
            validated = schema_model.model_validate(output)
        except Exception as exc:  # any validation failure starts the repair path
            errors.append(str(exc))
            self.stats.total_rejections += 1
        else:
            return EnforcerResult(is_valid=True, original_output=output, repaired_output=validated)

        # ``result.errors`` is the same list object as ``errors``, so messages
        # recorded below are visible on the returned result.
        result = EnforcerResult(is_valid=False, original_output=output, errors=errors)

        for attempt in range(self.config.max_retries):
            for strategy in self.config.strategy_order:
                try:
                    repaired = await self._apply_fix(strategy, output, schema_model, errors, context)
                    if repaired is None:
                        continue
                    validated = schema_model.model_validate(repaired)
                except Exception as fix_error:
                    # Rounds repeat the same strategies, so identical failures
                    # are recorded once rather than once per round.
                    self._note_error(errors, f"[{strategy.name}] {fix_error}")
                    continue

                self._count_repair(strategy.name)
                result.is_valid = True
                result.repaired_output = validated
                result.fix_strategy_used = strategy
                result.fix_attempts = attempt + 1
                if self.config.log_rejections:
                    logger.info(
                        "Schema fixed via %s (attempt %d/%d)",
                        strategy.name,
                        attempt + 1,
                        self.config.max_retries,
                    )
                return result

        model_name = getattr(schema_model, "__name__", repr(schema_model))

        if self.config.default_value_fallback:
            try:
                fallback = self._build_fallback(schema_model)
            except Exception as fallback_error:
                # Best effort only: keep going, but leave a trace of why the
                # fallback could not be built instead of swallowing it.
                self._note_error(errors, f"[FALLBACK] {fallback_error}")
            else:
                self._count_repair("FALLBACK")
                result.is_valid = True
                result.repaired_output = fallback
                result.fix_strategy_used = FixStrategy.FIELD_FALLBACK
                result.fix_attempts = self.config.max_retries
                if self.config.log_rejections:
                    logger.info("Schema for %s satisfied with default-value fallback", model_name)
                return result

        if self.config.log_rejections:
            logger.warning(
                "Schema enforcement failed for %s after %d repair round(s)",
                model_name,
                self.config.max_retries,
            )
        return result

    @staticmethod
    def _note_error(errors: list[str], message: str) -> None:
        """Append *message* to *errors* unless an identical entry is present."""
        if message not in errors:
            errors.append(message)

    def _count_repair(self, key: str) -> None:
        """Increment the overall and per-strategy repair counters for *key*."""
        self.stats.total_repairs += 1
        by_strategy = self.stats.repairs_by_strategy
        by_strategy[key] = by_strategy.get(key, 0) + 1

    async def _apply_fix(
        self, strategy: FixStrategy, output: Any, model: type, errors: list[str], context: dict | None
    ) -> Any:
        """Run the repair routine for *strategy*; return ``None`` if it has none."""
        if strategy == FixStrategy.JSON_REPAIR:
            return self._json_repair(output)
        if strategy == FixStrategy.FIELD_FALLBACK:
            return self._field_fallback(output, model, errors)
        if strategy == FixStrategy.LLM_ASSISTED:
            return await self._llm_fix(output, model, errors, context)
        return None

    def _json_repair(self, output: Any) -> Any:
        """Parse *output* as JSON, repairing common formatting problems.

        Dicts are returned unchanged. Strings have a surrounding markdown code
        fence removed and are then parsed with progressively more aggressive
        repairs: as-is, with trailing commas removed, and finally with single
        quotes swapped for double quotes. Trying the least invasive form first
        keeps valid JSON (for example text containing apostrophes) intact.

        Returns:
            The parsed JSON value (normally a dict), or ``None`` when *output*
            is neither a dict nor a string, or cannot be parsed.
        """
        if isinstance(output, dict):
            return output
        if not isinstance(output, str):
            return None

        import re

        text = output.strip()
        if text.startswith("```"):
            # Drop the opening fence line (it may carry a language tag) and
            # the closing fence when it sits on its own line.
            fenced = text.split("\n")[1:]
            if fenced and fenced[-1].strip() == "```":
                fenced = fenced[:-1]
            text = "\n".join(fenced)

        without_trailing_commas = re.sub(r",(\s*[}\]])", r"\1", text)
        candidates = (
            text,
            without_trailing_commas,
            without_trailing_commas.replace("'", '"'),
        )
        for candidate in candidates:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue
        return None

    def _field_fallback(self, output: Any, model: type, errors: list[str]) -> dict | None:
        """Keep the known fields of *output* and fill missing ones with defaults.

        String output is first parsed with ``_json_repair`` so that JSON text
        can benefit from this strategy as well. Keys not declared in
        ``model.model_fields`` are dropped; fields without a value and without
        a default are left out.

        Returns:
            The cleaned dict, or ``None`` when *output* is not dict-like, the
            cleaned dict is empty, or anything goes wrong while building it.
        """
        from pydantic_core import PydanticUndefined

        try:
            if isinstance(output, str):
                output = self._json_repair(output)
            if not isinstance(output, dict):
                return None
            clean: dict = {}
            for key, finfo in model.model_fields.items():
                if key in output:
                    clean[key] = output[key]
                elif finfo.default is not PydanticUndefined:
                    clean[key] = finfo.default
                elif finfo.default_factory is not None:
                    clean[key] = finfo.default_factory()
            return clean or None
        except Exception:
            # Best effort: a failure here just means this strategy has no fix.
            return None

    async def _llm_fix(
        self, output: Any, model: type, errors: list[str], context: dict | None
    ) -> dict | None:
        """Placeholder for an LLM-assisted repair; always returns ``None``.

        When a prompt template is configured, a warning is logged because the
        required ``llm_call`` callback is not wired in here.
        """
        if self.config.llm_fix_prompt_template:
            logger.warning("LLM-assisted fix requires llm_call callback (not implemented inline).")
        return None

    def _build_fallback(self, model: type) -> Any:
        """Instantiate *model* from declared defaults and type-based zero values.

        Fields with a default or default factory use it. Required fields
        annotated as ``str``, ``int``, ``float``, ``bool``, ``list`` or
        ``dict`` get an empty/zero value. Any other required field is omitted,
        in which case constructing the model may raise.
        """
        from pydantic_core import PydanticUndefined

        kwargs: dict = {}
        for key, finfo in model.model_fields.items():
            if finfo.default is not PydanticUndefined:
                kwargs[key] = finfo.default
                continue
            if finfo.default_factory is not None:
                kwargs[key] = finfo.default_factory()
                continue

            annotation = finfo.annotation
            # ``list[int]`` / ``dict[str, int]`` expose the bare container
            # type through ``__origin__``.
            origin = getattr(annotation, "__origin__", None)
            if annotation is str:
                kwargs[key] = ""
            elif annotation is int:
                kwargs[key] = 0
            elif annotation is float:
                kwargs[key] = 0.0
            elif annotation is bool:
                kwargs[key] = False
            elif annotation is list or origin is list:
                kwargs[key] = []
            elif annotation is dict or origin is dict:
                kwargs[key] = {}
        return model(**kwargs)

    async def enforce_batch(
        self,
        outputs: list[dict | str],
        schema_model: type,
        context: dict | None = None,
    ) -> list[EnforcerResult]:
        """Run ``enforce`` on every item of *outputs* concurrently.

        Results are returned in the same order as *outputs*.
        """
        tasks = [self.enforce(out, schema_model, context) for out in outputs]
        return await asyncio.gather(*tasks)

229 

230 async def enforce_batch( 

231 self, 

232 outputs: list[dict | str], 

233 schema_model: type, 

234 context: dict | None = None, 

235 ) -> list[EnforcerResult]: 

236 """批量校验,利用异步并发。""" 

237 tasks = [self.enforce(out, schema_model, context) for out in outputs] 

238 return await asyncio.gather(*tasks) 

239 

240 

241__all__ = [ 

242 "SchemaEnforcer", 

243 "EnforcerConfig", 

244 "EnforcerResult", 

245 "EnforcerStats", 

246 "FixStrategy", 

247]