Coverage for agentos/tests/test_schema_enforcer.py: 0%
83 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""Tests for agentos.validation.schema_enforcer."""
3from __future__ import annotations
5import pytest
6from pydantic import BaseModel, Field
7from agentos.validation.schema_enforcer import (
8 SchemaEnforcer,
9 EnforcerConfig,
10 EnforcerResult,
11 FixStrategy,
12)
class SimpleOutput(BaseModel):
    """Flat output schema used by the tests.

    ``name`` and ``score`` are required; ``category`` is optional and
    defaults to ``"general"``.
    """

    name: str
    score: float
    category: str = Field(default="general")
class NestedOutput(BaseModel):
    """Output schema with container-typed fields, used to test nested validation."""

    title: str  # the only required field
    # default_factory gives every instance its own fresh empty container,
    # so instances never share a mutable default.
    items: list[dict] = Field(default_factory=list)
    meta: dict = Field(default_factory=dict)
@pytest.fixture
def enforcer() -> SchemaEnforcer:
    """Provide a ``SchemaEnforcer`` constructed without an explicit config.

    The fixture uses pytest's default (function) scope, so every test gets a
    new instance. ``test_stats_tracking`` relies on its counters starting fresh.
    """
    fresh_instance = SchemaEnforcer()
    return fresh_instance
@pytest.mark.asyncio
async def test_valid_output_passes(enforcer):
    """Output that already matches the schema passes with zero fix attempts."""
    checked = await enforcer.enforce(
        {"name": "task1", "score": 0.95, "category": "code"},
        SimpleOutput,
    )
    assert checked.is_valid
    assert checked.fix_attempts == 0
    assert checked.repaired_output.name == "task1"
@pytest.mark.asyncio
async def test_missing_field_fallback(enforcer):
    """A missing optional field falls back to its schema default.

    Only ``category`` is omitted, while ``name`` is supplied and must be kept.
    Checking ``name`` prevents a false pass: a full default fallback (see
    ``test_completely_invalid_full_fallback``) would also report ``is_valid``
    and ``category == "general"``, but it would replace ``name`` with ``""``.
    """
    output = {"name": "task2", "score": 0.88}
    result = await enforcer.enforce(output, SimpleOutput)
    assert result.is_valid
    assert result.repaired_output.category == "general"
    # The supplied field must survive rather than being reset by a full fallback.
    assert result.repaired_output.name == "task2"
@pytest.mark.asyncio
async def test_json_string_repair(enforcer):
    """A JSON string containing a trailing comma is repaired, then validated."""
    raw_json = '{"name": "task3", "score": 0.75,}'  # trailing comma before "}"
    repaired = await enforcer.enforce(raw_json, SimpleOutput)
    assert repaired.is_valid
    assert repaired.repaired_output.name == "task3"
@pytest.mark.asyncio
async def test_json_markdown_codeblock_repair(enforcer):
    """JSON wrapped in a Markdown ```json fenced block is unwrapped and validated."""
    fenced = "```json\n" + '{"name": "task4", "score": 0.65}' + "\n```"
    outcome = await enforcer.enforce(fenced, SimpleOutput)
    assert outcome.is_valid
    assert outcome.repaired_output.name == "task4"
@pytest.mark.asyncio
async def test_single_quote_json_repair(enforcer):
    """A JSON-like string that uses single quotes is repaired and validated."""
    outcome = await enforcer.enforce(
        "{'name': 'task5', 'score': 0.55}",
        SimpleOutput,
    )
    assert outcome.is_valid
    assert outcome.repaired_output.name == "task5"
@pytest.mark.asyncio
async def test_extra_field_ok(enforcer):
    """An unknown extra field does not affect validation.

    ``is_valid`` alone cannot prove this. Even completely invalid input is
    reported valid after a full default fallback, which sets ``name`` to ``""``
    (see ``test_completely_invalid_full_fallback``). Asserting that the
    supplied ``name`` survives shows the payload was accepted as-is rather than
    replaced by defaults.
    """
    output = {"name": "task6", "score": 0.45, "extra_field": "ignored"}
    result = await enforcer.enforce(output, SimpleOutput)
    assert result.is_valid
    assert result.repaired_output.name == "task6"
@pytest.mark.asyncio
async def test_completely_invalid_full_fallback(enforcer):
    """Input sharing no fields with the schema falls back entirely to defaults."""
    junk = {"wrong": "oops"}
    outcome = await enforcer.enforce(junk, SimpleOutput)
    assert outcome.is_valid
    # `name` has no schema default; the expected fallback value is an empty string.
    assert outcome.repaired_output.name == ""
@pytest.mark.asyncio
async def test_nested_output(enforcer):
    """A schema with list/dict fields validates and keeps the nested values."""
    checked = await enforcer.enforce(
        {"title": "report", "items": [{"a": 1}], "meta": {"page": 1}},
        NestedOutput,
    )
    assert checked.is_valid
    assert checked.repaired_output.items == [{"a": 1}]
@pytest.mark.asyncio
async def test_stats_tracking(enforcer):
    """Counters on ``enforcer.stats`` accumulate across ``enforce`` calls.

    One schema-conforming payload and one non-conforming payload are checked.
    The test expects exactly one rejection and at least one repair.
    """
    for payload in ({"name": "x", "score": 1.0}, {"bad": True}):
        await enforcer.enforce(payload, SimpleOutput)
    counters = enforcer.stats
    assert counters.total_checks == 2
    assert counters.total_rejections == 1
    assert counters.total_repairs >= 1
@pytest.mark.asyncio
async def test_enforce_batch(enforcer):
    """``enforce_batch`` returns one valid result per input payload."""
    batch = [
        {"name": f"b{idx}", "score": score}
        for idx, score in ((1, 0.9), (2, 0.8), (3, 0.7))
    ]
    results = await enforcer.enforce_batch(batch, SimpleOutput)
    assert len(results) == 3
    for res in results:
        assert res.is_valid
@pytest.mark.asyncio
async def test_fix_strategy_order_respected():
    """A custom ``strategy_order`` passed via ``EnforcerConfig`` takes effect.

    The input mixes two defects: single quotes and a trailing comma.
    FIELD_FALLBACK is listed before JSON_REPAIR, and ``max_retries`` is 1.

    NOTE(review): only ``is_valid`` is asserted, so this does not actually
    observe which strategy ran or in what order. Consider asserting on whatever
    the result exposes about the applied strategy, if such a field exists.
    """
    custom_order = [FixStrategy.FIELD_FALLBACK, FixStrategy.JSON_REPAIR]
    config = EnforcerConfig(strategy_order=custom_order, max_retries=1)
    outcome = await SchemaEnforcer(config).enforce(
        "{'name': 's', 'score': 0.3,}",
        SimpleOutput,
    )
    assert outcome.is_valid