Coverage for agentos/tests/test_schema_enforcer.py: 0%

83 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""Tests for agentos.validation.schema_enforcer.""" 

2 

3from __future__ import annotations 

4 

5import pytest 

6from pydantic import BaseModel, Field 

7from agentos.validation.schema_enforcer import ( 

8 SchemaEnforcer, 

9 EnforcerConfig, 

10 EnforcerResult, 

11 FixStrategy, 

12) 

13 

14 

class SimpleOutput(BaseModel):
    """Flat output schema used as the validation target in these tests."""

    # Required fields: no default is declared for either of them.
    name: str
    score: float
    # Optional field: defaults to "general" when omitted from the input.
    category: str = "general"

21 

22 

class NestedOutput(BaseModel):
    """Output schema with container-typed fields, used by the nested test."""

    # Required field.
    title: str
    # Container fields use factories so each instance gets its own object.
    items: list[dict] = Field(default_factory=list)
    meta: dict = Field(default_factory=dict)

29 

30 

@pytest.fixture
def enforcer() -> SchemaEnforcer:
    """Provide a fresh ``SchemaEnforcer`` built with its default configuration."""
    default_enforcer = SchemaEnforcer()
    return default_enforcer

34 

35 

@pytest.mark.asyncio
async def test_valid_output_passes(enforcer):
    """A payload that already satisfies the schema passes with no fix attempts."""
    payload = dict(name="task1", score=0.95, category="code")

    outcome = await enforcer.enforce(payload, SimpleOutput)

    assert outcome.is_valid
    assert outcome.fix_attempts == 0
    assert outcome.repaired_output.name == "task1"

44 

45 

@pytest.mark.asyncio
async def test_missing_field_fallback(enforcer):
    """An omitted field falls back to the default declared on the schema."""
    payload = dict(name="task2", score=0.88)  # "category" deliberately left out

    outcome = await enforcer.enforce(payload, SimpleOutput)

    assert outcome.is_valid
    assert outcome.repaired_output.category == "general"

53 

54 

@pytest.mark.asyncio
async def test_json_string_repair(enforcer):
    """A JSON string with a trailing comma is repaired and then validated."""
    raw_text = '{"name": "task3", "score": 0.75,}'  # trailing comma is the defect

    outcome = await enforcer.enforce(raw_text, SimpleOutput)

    assert outcome.is_valid
    assert outcome.repaired_output.name == "task3"

62 

63 

@pytest.mark.asyncio
async def test_json_markdown_codeblock_repair(enforcer):
    """JSON wrapped in a Markdown ```json fence is repaired and then validated."""
    fenced_text = '```json\n{"name": "task4", "score": 0.65}\n```'

    outcome = await enforcer.enforce(fenced_text, SimpleOutput)

    assert outcome.is_valid
    assert outcome.repaired_output.name == "task4"

71 

72 

@pytest.mark.asyncio
async def test_single_quote_json_repair(enforcer):
    """JSON-like text using single quotes is repaired and then validated."""
    raw_text = "{'name': 'task5', 'score': 0.55}"

    outcome = await enforcer.enforce(raw_text, SimpleOutput)

    assert outcome.is_valid
    assert outcome.repaired_output.name == "task5"

80 

81 

@pytest.mark.asyncio
async def test_extra_field_ok(enforcer):
    """A key that the schema does not declare does not cause validation to fail."""
    payload = dict(name="task6", score=0.45, extra_field="ignored")

    outcome = await enforcer.enforce(payload, SimpleOutput)

    assert outcome.is_valid

88 

89 

@pytest.mark.asyncio
async def test_completely_invalid_full_fallback(enforcer):
    """A payload with no usable fields is replaced by an all-defaults result."""
    payload = dict(wrong="oops")  # shares no key with SimpleOutput

    outcome = await enforcer.enforce(payload, SimpleOutput)

    assert outcome.is_valid
    # The required string field is expected to come back as an empty string.
    assert outcome.repaired_output.name == ""

97 

98 

@pytest.mark.asyncio
async def test_nested_output(enforcer):
    """A schema with list and dict fields validates and keeps nested values."""
    payload = dict(title="report", items=[{"a": 1}], meta={"page": 1})

    outcome = await enforcer.enforce(payload, NestedOutput)

    assert outcome.is_valid
    assert outcome.repaired_output.items == [{"a": 1}]

106 

107 

@pytest.mark.asyncio
async def test_stats_tracking(enforcer):
    """Enforcement statistics accumulate across successive calls."""
    # First payload is valid; the second shares no key with the schema.
    payloads = (
        {"name": "x", "score": 1.0},
        {"bad": True},
    )
    for payload in payloads:
        await enforcer.enforce(payload, SimpleOutput)

    assert enforcer.stats.total_checks == 2
    assert enforcer.stats.total_rejections == 1
    assert enforcer.stats.total_repairs >= 1

116 

117 

@pytest.mark.asyncio
async def test_enforce_batch(enforcer):
    """Batch enforcement returns one valid result per input payload."""
    samples = (("b1", 0.9), ("b2", 0.8), ("b3", 0.7))
    batch = [{"name": label, "score": value} for label, value in samples]

    outcomes = await enforcer.enforce_batch(batch, SimpleOutput)

    assert len(outcomes) == 3
    assert all(outcome.is_valid for outcome in outcomes)

129 

130 

@pytest.mark.asyncio
async def test_fix_strategy_order_respected():
    """An enforcer configured with a custom strategy order still yields a valid result."""
    custom_order = [FixStrategy.FIELD_FALLBACK, FixStrategy.JSON_REPAIR]
    custom_config = EnforcerConfig(strategy_order=custom_order, max_retries=1)
    custom_enforcer = SchemaEnforcer(custom_config)

    # Single-quoted text with a trailing comma.
    raw_text = "{'name': 's', 'score': 0.3,}"
    outcome = await custom_enforcer.enforce(raw_text, SimpleOutput)

    assert outcome.is_valid