Coverage for tools / registry.py: 87%

86 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-29 02:55 +0800

1import json 

2from typing import Type 

3from pydantic import BaseModel 

4from qrclaw.logger import get_logger 

5 

6logger = get_logger("qrclaw.tools.registry") 

7 

8# 存所有注册的工具 

9# 结构:{ "工具名": {"fn": 函数本身, "model": Pydantic模型, "schema": 给LLM看的描述} } 

10_tools: dict = {} 

11 

12 

13def register(description: str, args_model: Type[BaseModel], confirm: bool = False): 

14 """ 

15 装饰器:把函数注册成一个工具。 

16 

17 confirm=True 表示执行前需要用户确认(高风险工具) 

18 """ 

19 def decorator(fn): 

20 schema = _build_schema(fn.__name__, description, args_model) 

21 _tools[fn.__name__] = { 

22 "fn": fn, 

23 "model": args_model, 

24 "schema": schema, 

25 "confirm": confirm, 

26 } 

27 logger.debug(f"注册工具: {fn.__name__} (需要确认: {confirm})") 

28 return fn 

29 return decorator 

30 

31 

32def need_confirm(name: str) -> bool: 

33 """判断工具是否需要用户确认""" 

34 return _tools.get(name, {}).get("confirm", False) 

35 

36 

37def _resolve_refs(schema: dict, defs: dict) -> dict: 

38 """递归展开 $ref 并清理 title,Gemini 不支持 $ref""" 

39 if "$ref" in schema: 

40 ref_name = schema["$ref"].split("/")[-1] 

41 resolved = _resolve_refs(defs.get(ref_name, {}), defs) 

42 other = {k: v for k, v in schema.items() if k != "$ref"} 

43 return {**resolved, **other} 

44 result = {} 

45 for k, v in schema.items(): 

46 if k in ("$defs", "title"): 

47 continue # 去掉 $defs 和所有层级的 title 

48 elif isinstance(v, dict): 

49 result[k] = _resolve_refs(v, defs) 

50 elif isinstance(v, list): 

51 result[k] = [_resolve_refs(i, defs) if isinstance(i, dict) else i for i in v] 

52 else: 

53 result[k] = v 

54 return result 

55 

56 

57def _build_schema(name: str, description: str, args_model: Type[BaseModel]) -> dict: 

58 """从 Pydantic 模型生成 OpenAI Tool Schema,兼容 Gemini""" 

59 pydantic_schema = args_model.model_json_schema() 

60 defs = pydantic_schema.get("$defs", {}) 

61 

62 # 展开 $ref,内联所有引用,同时递归清理 title 

63 resolved = _resolve_refs(pydantic_schema, defs) 

64 

65 properties = dict(resolved.get("properties", {})) 

66 

67 # 无参数工具:省略 parameters 字段(Gemini 不接受空 properties) 

68 if not properties: 

69 return { 

70 "type": "function", 

71 "function": { 

72 "name": name, 

73 "description": description, 

74 }, 

75 } 

76 

77 required = pydantic_schema.get("required", []) 

78 parameters: dict = { 

79 "type": "object", 

80 "properties": properties, 

81 } 

82 # required 为空时省略,避免某些 API 报错 

83 if required: 

84 parameters["required"] = required 

85 

86 return { 

87 "type": "function", 

88 "function": { 

89 "name": name, 

90 "description": description, 

91 "parameters": parameters, 

92 }, 

93 } 

94 

95 

96def get_schemas() -> list[dict]: 

97 """返回所有工具的 schema 列表,发给 LLM 用""" 

98 schemas = [item["schema"] for item in _tools.values()] 

99 logger.debug(f"获取工具 schemas,共 {len(schemas)} 个工具") 

100 return schemas 

101 

102 

103def execute(name: str, arguments: str) -> str: 

104 """ 

105 执行工具。 

106 用 Pydantic 模型校验参数,不合法直接报错,不会传脏数据给工具函数。 

107 同时执行安全切面检查。 

108  

109 Raises: 

110 PermissionError: 当权限检查失败时抛出,由上层 agent.py 捕获处理 

111 """ 

112 if name not in _tools: 

113 error_msg = f"错误:找不到工具 {name}" 

114 logger.error(error_msg) 

115 return error_msg 

116 

117 try: 

118 raw_args = json.loads(arguments) 

119 logger.debug(f"工具 {name} 原始参数: {raw_args}") 

120 

121 # 用 Pydantic 校验并解析参数 

122 validated = _tools[name]["model"](**raw_args) 

123 validated_args = validated.model_dump() 

124 logger.debug(f"工具 {name} 校验后参数: {validated_args}") 

125 

126 # === AOP 安全拦截 (Security Hook) === 

127 try: 

128 # 局部导入避免循环引用 

129 from qrclaw.agent import get_workspace 

130 from qrclaw.security import security_manager 

131 

132 # 获取当前上下文的工作空间 

133 ws = get_workspace() 

134 

135 # 仅当在 agent 运行上下文中时才检查 

136 # 如果是 CLI 直接调试工具或单元测试,可能没有 workspace,此时视为 Full Access 

137 if ws: 

138 security_manager.check_access( 

139 agent_id=ws.agent_id, 

140 tool_name=name, 

141 args=validated_args, 

142 workspace_root=ws.root 

143 ) 

144 except PermissionError: 

145 # 权限拒绝:透传给上层 agent.py 处理 

146 logger.warning(f"工具 {name} 权限检查失败") 

147 raise # 关键:重新抛出异常,而不是返回字符串 

148 except ImportError: 

149 # 可能是环境问题,忽略 

150 pass 

151 except Exception as e: 

152 # 安全检查本身出错,为了安全起见,选择拦截并报错 (Fail Closed) 

153 error_msg = f"系统错误:执行安全检查时发生异常 ({e})" 

154 logger.error(error_msg, exc_info=True) 

155 return error_msg 

156 # ==================================== 

157 

158 # 执行工具 

159 result = _tools[name]["fn"](**validated_args) 

160 logger.info(f"工具 {name} 执行成功") 

161 

162 return result 

163 except PermissionError: 

164 # 权限拒绝:透传给上层 agent.py 处理 

165 raise 

166 except json.JSONDecodeError as e: 

167 error_msg = f"错误:参数不是合法的 JSON: {e}" 

168 logger.error(error_msg) 

169 return error_msg 

170 except Exception as e: 

171 error_msg = f"错误:参数校验失败 {e}" 

172 logger.error(error_msg, exc_info=True) 

173 return error_msg