Coverage for tests / test_registry.py: 91%
115 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-29 02:55 +0800
1"""
2工具注册模块测试
4测试覆盖:
5- 工具注册装饰器
6- Schema 生成
7- 参数校验
8- 工具执行
9"""
10import pytest
11import json
12from pydantic import BaseModel, Field
13from qrclaw.tools.registry import (
14 register,
15 execute,
16 get_schemas,
17 need_confirm,
18 _tools,
19)
# Argument model shared by the tools registered in these tests.
# NOTE: no class docstring on purpose -- pydantic copies a model docstring
# into the generated JSON schema ``description``, which the schema tests read.
class TestArgs(BaseModel):
    # The class name matches pytest's ``Test*`` collection pattern, and
    # BaseModel defines ``__init__``, so pytest would try to collect this model
    # and emit a PytestCollectionWarning. Opt out explicitly. Pydantic skips
    # dunder names, so this does not become a model field.
    __test__ = False

    # Required text argument.
    text: str = Field(description="测试文本")
    # Optional repeat count; defaults to 1.
    count: int = Field(default=1, description="重复次数")
class TestTool:
    """Tests for tool registration and the ``confirm`` flag."""

    def setup_method(self):
        """Snapshot the global registry, then start each test with it empty."""
        # ``_tools`` is module-level state shared with every other test module
        # (and with any tools registered at import time). Save it so that
        # teardown_method can put it back instead of leaving it wiped.
        self._saved_tools = dict(_tools)
        _tools.clear()

    def teardown_method(self):
        """Drop the tools this test registered and restore the snapshot."""
        _tools.clear()
        _tools.update(self._saved_tools)

    def test_register_tool(self):
        """Registering a tool stores its function, with confirm off by default."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text * count

        # The tool is stored under its function name.
        assert "test_tool" in _tools
        assert _tools["test_tool"]["fn"] == test_tool
        assert _tools["test_tool"]["confirm"] is False

    def test_register_tool_with_confirm(self):
        """A tool registered with confirm=True requires confirmation."""
        @register(description="高风险工具", args_model=TestArgs, confirm=True)
        def dangerous_tool(text: str, count: int = 1) -> str:
            return text * count

        assert _tools["dangerous_tool"]["confirm"] is True
        assert need_confirm("dangerous_tool") is True

    def test_need_confirm_default(self):
        """A tool registered without confirm does not require confirmation."""
        @register(description="测试工具", args_model=TestArgs)
        def normal_tool(text: str, count: int = 1) -> str:
            return text * count

        assert need_confirm("normal_tool") is False

    def test_need_confirm_nonexistent(self):
        """An unregistered tool name reports that no confirmation is needed."""
        assert need_confirm("nonexistent_tool") is False
class TestSchema:
    """Tests for the function-calling schemas generated at registration."""

    def setup_method(self):
        """Snapshot the global registry, then start each test with it empty."""
        # Restored in teardown_method so tools registered elsewhere survive.
        self._saved_tools = dict(_tools)
        _tools.clear()

    def teardown_method(self):
        """Drop the tools this test registered and restore the snapshot."""
        _tools.clear()
        _tools.update(self._saved_tools)

    def test_get_schemas(self):
        """get_schemas returns exactly one schema per registered tool."""
        @register(description="工具1", args_model=TestArgs)
        def tool1(text: str, count: int = 1) -> str:
            return text

        @register(description="工具2", args_model=TestArgs)
        def tool2(text: str, count: int = 1) -> str:
            return text

        schemas = get_schemas()

        assert len(schemas) == 2
        # Exact name set: catches duplicated or stray entries, not just presence.
        names = {s["function"]["name"] for s in schemas}
        assert names == {"tool1", "tool2"}

    def test_schema_structure(self):
        """A stored schema has the function/name/description/parameters layout."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text

        schema = _tools["test_tool"]["schema"]

        assert schema["type"] == "function"
        assert schema["function"]["name"] == "test_tool"
        assert schema["function"]["description"] == "测试工具"
        assert "parameters" in schema["function"]
        assert "properties" in schema["function"]["parameters"]
        assert "text" in schema["function"]["parameters"]["properties"]
        assert "count" in schema["function"]["parameters"]["properties"]

    def test_schema_required_fields(self):
        """Only fields without a default are listed as required."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text

        schema = _tools["test_tool"]["schema"]
        params = schema["function"]["parameters"]

        # ``text`` is required; ``count`` has a default value.
        assert "text" in params.get("required", [])
        assert "count" not in params.get("required", [])
# Argument model for a tool that takes no parameters.
# The docstring below is kept verbatim: pydantic surfaces a model docstring
# as the JSON schema ``description``, so it is runtime-visible text.
class NoArgs(BaseModel):
    """无参数工具的模型"""
class TestExecution:
    """Tests for ``execute``: argument parsing, validation and error reporting."""

    def setup_method(self):
        """Snapshot the global registry, then start each test with it empty."""
        # Restored in teardown_method so tools registered elsewhere (other
        # test modules, import-time registrations) survive this class.
        self._saved_tools = dict(_tools)
        _tools.clear()

    def teardown_method(self):
        """Drop the tools this test registered and restore the snapshot."""
        _tools.clear()
        _tools.update(self._saved_tools)

    def test_execute_tool_success(self):
        """Valid JSON arguments are passed through to the tool."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text * count

        result = execute("test_tool", json.dumps({"text": "hello", "count": 3}))

        assert result == "hellohellohello"

    def test_execute_tool_default_args(self):
        """An omitted optional argument falls back to its default."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text * count

        result = execute("test_tool", json.dumps({"text": "hi"}))

        assert result == "hi"  # count falls back to its default of 1

    def test_execute_tool_invalid_json(self):
        """A non-JSON argument string produces a JSON error message."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text

        result = execute("test_tool", "not a json")

        assert "错误" in result
        assert "JSON" in result

    def test_execute_tool_missing_required(self):
        """Omitting a required argument produces an error message."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text

        result = execute("test_tool", json.dumps({}))

        assert "错误" in result or "校验失败" in result

    def test_execute_tool_wrong_type(self):
        """An argument of the wrong type produces an error message."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text * count

        result = execute("test_tool", json.dumps({"text": "hi", "count": "not a number"}))

        assert "错误" in result or "校验失败" in result

    def test_execute_nonexistent_tool(self):
        """Executing an unregistered tool produces a 'tool not found' error."""
        result = execute("nonexistent", json.dumps({}))

        assert "错误" in result
        assert "找不到工具" in result

    def test_execute_no_args_tool(self):
        """A tool whose argument model has no fields runs with ``{}``."""
        @register(description="无参数工具", args_model=NoArgs)
        def no_args_tool() -> str:
            return "success"

        result = execute("no_args_tool", json.dumps({}))

        assert result == "success"

    def test_execute_tool_exception(self):
        """An exception raised inside the tool is reported as an error string."""
        @register(description="会报错的工具", args_model=TestArgs)
        def error_tool(text: str, count: int = 1) -> str:
            raise ValueError("故意报错")

        result = execute("error_tool", json.dumps({"text": "test"}))

        # execute must catch the exception and return an error message
        # rather than letting it propagate.
        assert "错误" in result or "校验失败" in result