Coverage for tests / test_registry.py: 91%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-29 02:55 +0800

1""" 

2工具注册模块测试 

3 

4测试覆盖: 

5- 工具注册装饰器 

6- Schema 生成 

7- 参数校验 

8- 工具执行 

9""" 

10import pytest 

11import json 

12from pydantic import BaseModel, Field 

13from qrclaw.tools.registry import ( 

14 register, 

15 execute, 

16 get_schemas, 

17 need_confirm, 

18 _tools, 

19) 

20 

21 

# Argument model shared by the tests in this module.
class TestArgs(BaseModel):
    # This is a data model, not a test class. Its name matches pytest's
    # default "Test*" class pattern, so opt out of collection explicitly to
    # avoid a collection warning about the class having an __init__.
    # NOTE(review): deliberately no docstring here -- pydantic presumably
    # copies a model docstring into the generated schema; confirm before
    # adding one.
    __test__ = False

    text: str = Field(description="测试文本")
    count: int = Field(default=1, description="重复次数")

26 

27 

class TestTool:
    """Tests for tool registration and the per-tool confirmation flag."""

    def setup_method(self):
        """Snapshot the shared registry, then empty it before each test."""
        # Keep a shallow copy so teardown_method can put the entries back.
        self._saved_tools = dict(_tools)
        _tools.clear()

    def teardown_method(self):
        """Restore the registry to what it held before the test ran."""
        # Drop anything the test registered, then reinstate the snapshot so
        # other test modules see the registry unchanged.
        _tools.clear()
        _tools.update(self._saved_tools)

    def test_register_tool(self):
        """A decorated function is stored in the registry under its name."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text * count

        # Verify the registration took effect.
        assert "test_tool" in _tools
        assert _tools["test_tool"]["fn"] == test_tool
        assert _tools["test_tool"]["confirm"] is False

    def test_register_tool_with_confirm(self):
        """A tool registered with confirm=True reports that it needs confirmation."""
        @register(description="高风险工具", args_model=TestArgs, confirm=True)
        def dangerous_tool(text: str, count: int = 1) -> str:
            return text * count

        assert need_confirm("dangerous_tool") is True

    def test_need_confirm_default(self):
        """A tool registered without the confirm argument needs no confirmation."""
        @register(description="测试工具", args_model=TestArgs)
        def normal_tool(text: str, count: int = 1) -> str:
            return text * count

        assert need_confirm("normal_tool") is False

    def test_need_confirm_nonexistent(self):
        """An unknown tool name is reported as not needing confirmation."""
        assert need_confirm("nonexistent_tool") is False

65 

66 

class TestSchema:
    """Tests for the schema generated for each registered tool."""

    def setup_method(self):
        """Snapshot the shared registry, then empty it before each test."""
        # Keep a shallow copy so teardown_method can put the entries back.
        self._saved_tools = dict(_tools)
        _tools.clear()

    def teardown_method(self):
        """Restore the registry to what it held before the test ran."""
        # Drop anything the test registered, then reinstate the snapshot so
        # other test modules see the registry unchanged.
        _tools.clear()
        _tools.update(self._saved_tools)

    def test_get_schemas(self):
        """get_schemas() returns one schema per registered tool."""
        @register(description="工具1", args_model=TestArgs)
        def tool1(text: str, count: int = 1) -> str:
            return text

        @register(description="工具2", args_model=TestArgs)
        def tool2(text: str, count: int = 1) -> str:
            return text

        schemas = get_schemas()

        assert len(schemas) == 2
        assert any(s["function"]["name"] == "tool1" for s in schemas)
        assert any(s["function"]["name"] == "tool2" for s in schemas)

    def test_schema_structure(self):
        """The stored schema has the expected type, name, description and properties."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text

        schema = _tools["test_tool"]["schema"]

        assert schema["type"] == "function"
        assert schema["function"]["name"] == "test_tool"
        assert schema["function"]["description"] == "测试工具"
        assert "parameters" in schema["function"]
        assert "properties" in schema["function"]["parameters"]
        assert "text" in schema["function"]["parameters"]["properties"]
        assert "count" in schema["function"]["parameters"]["properties"]

    def test_schema_required_fields(self):
        """Only fields without a default are listed as required."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text

        schema = _tools["test_tool"]["schema"]
        params = schema["function"]["parameters"]

        # `text` is mandatory; `count` has a default value.
        assert "text" in params.get("required", [])
        assert "count" not in params.get("required", [])

118 

119 

# Empty argument model used by the test that registers a tool taking no
# parameters.
# NOTE(review): the docstring below is left untouched because pydantic
# presumably surfaces a model docstring in the generated schema -- confirm
# before rewording or translating it.
class NoArgs(BaseModel):
    """无参数工具的模型"""

123 

124 

class TestExecution:
    """Tests for running registered tools through execute()."""

    def setup_method(self):
        """Snapshot the shared registry, then empty it before each test."""
        # Keep a shallow copy so teardown_method can put the entries back.
        self._saved_tools = dict(_tools)
        _tools.clear()

    def teardown_method(self):
        """Restore the registry to what it held before the test ran."""
        # Drop anything the test registered, then reinstate the snapshot so
        # other test modules see the registry unchanged.
        _tools.clear()
        _tools.update(self._saved_tools)

    def test_execute_tool_success(self):
        """A valid JSON payload is passed to the tool and its result returned."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text * count

        result = execute("test_tool", json.dumps({"text": "hello", "count": 3}))

        assert result == "hellohellohello"

    def test_execute_tool_default_args(self):
        """Omitted optional arguments fall back to the model's defaults."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text * count

        result = execute("test_tool", json.dumps({"text": "hi"}))

        assert result == "hi"  # count falls back to its default of 1

    def test_execute_tool_invalid_json(self):
        """A payload that is not valid JSON yields an error message."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text

        result = execute("test_tool", "not a json")

        assert "错误" in result
        assert "JSON" in result

    def test_execute_tool_missing_required(self):
        """Leaving out a required argument yields an error message."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text

        result = execute("test_tool", json.dumps({}))

        assert "错误" in result or "校验失败" in result

    def test_execute_tool_wrong_type(self):
        """An argument of the wrong type yields an error message."""
        @register(description="测试工具", args_model=TestArgs)
        def test_tool(text: str, count: int = 1) -> str:
            return text * count

        result = execute("test_tool", json.dumps({"text": "hi", "count": "not a number"}))

        assert "错误" in result or "校验失败" in result

    def test_execute_nonexistent_tool(self):
        """Executing an unregistered tool name yields a 'tool not found' error."""
        result = execute("nonexistent", json.dumps({}))

        assert "错误" in result
        assert "找不到工具" in result

    def test_execute_no_args_tool(self):
        """A tool that takes no parameters runs with an empty JSON object."""
        @register(description="无参数工具", args_model=NoArgs)
        def no_args_tool() -> str:
            return "success"

        result = execute("no_args_tool", json.dumps({}))

        assert result == "success"

    def test_execute_tool_exception(self):
        """An exception raised inside the tool comes back as an error message."""
        @register(description="会报错的工具", args_model=TestArgs)
        def error_tool(text: str, count: int = 1) -> str:
            raise ValueError("故意报错")

        result = execute("error_tool", json.dumps({"text": "test"}))

        # execute() is expected to catch the exception and return error text.
        assert "错误" in result or "校验失败" in result