Coverage for agentos/tools/function_calling.py: 42%

99 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Function Calling Pipeline — Schema-validated tool invocation. 

3 

4Provides a complete function calling lifecycle: schema registration, LLM 

5tool_choice dispatch, argument validation, execution, and result formatting. 

6""" 

7 

8from __future__ import annotations 

9 

10import json 

11from dataclasses import dataclass, field 

12from typing import Any, Callable, Optional 

13 

14import jsonschema 

15 

16 

@dataclass
class ToolSchema:
    """Provider-neutral description of one callable tool.

    ``parameters`` is a JSON Schema for the tool's arguments. Required
    argument names may be listed inside it (the standard ``"required"``
    keyword), in the separate ``required`` field, or in both; the exports
    merge the two lists so neither source is silently dropped.
    """

    name: str
    description: str
    # JSON Schema describing the tool's arguments.
    parameters: dict[str, Any]
    # Extra required parameter names, merged with parameters["required"].
    required: list[str] = field(default_factory=list)

    def _merged_required(self) -> list[str]:
        """Return required names from ``parameters`` then ``required``, de-duplicated, order kept."""
        from_schema = self.parameters.get("required") or []
        # dict.fromkeys keeps first-seen order while dropping duplicates;
        # it also yields a fresh list, so exports never alias self.required.
        return list(dict.fromkeys([*from_schema, *self.required]))

    def _export_schema(self) -> dict[str, Any]:
        """Build the argument schema shared by both export formats.

        Starts from a shallow copy of ``parameters`` so that keywords other
        than ``type``/``properties`` (e.g. ``additionalProperties``, or the
        ``$defs`` that a ``$ref`` points to) survive export. The original
        ``self.parameters`` dict is never mutated.
        """
        schema = dict(self.parameters)
        schema.setdefault("type", "object")
        schema.setdefault("properties", {})
        # Replaced by the merged list below, or dropped when nothing is required.
        schema.pop("required", None)
        required = self._merged_required()
        if required:
            schema["required"] = required
        return schema

    def to_openai(self) -> dict[str, Any]:
        """Convert to OpenAI function definition format.

        The ``required`` key is emitted only when at least one argument is
        required (from either source).
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self._export_schema(),
            },
        }

    def to_anthropic(self) -> dict[str, Any]:
        """Convert to Anthropic tool format.

        ``input_schema`` always carries ``"type": "object"`` and a
        ``required`` list, which may be empty.
        """
        input_schema = self._export_schema()
        input_schema["type"] = "object"
        input_schema["required"] = self._merged_required()
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": input_schema,
        }

57 

58 

@dataclass
class ToolCall:
    """One tool invocation requested by a model, already decoded.

    Fields are positional in this order: ``ToolCall(id, name, arguments)``.
    """

    # Call identifier from the raw tool-call payload; echoed as ToolResult.call_id.
    id: str
    # Name of the registered tool to run.
    name: str
    # Decoded arguments, later passed to the handler as keyword arguments.
    arguments: dict[str, Any]

66 

67 

@dataclass
class ToolResult:
    """Outcome of running one tool call: a success flag plus output or error."""

    # Identifier of the ToolCall this result answers.
    call_id: str
    # Tool name, copied from the call.
    name: str
    # Whether the call succeeded.
    success: bool
    # Value returned by the handler, if any.
    output: Any = None
    # Human-readable failure reason, if any.
    error: Optional[str] = None
    # Elapsed time in milliseconds.
    latency_ms: float = 0.0

78 

79 

class ToolRegistry:
    """
    Registry of callable tools with schema validation.

    Each tool is a ``ToolSchema`` paired with a handler callable.
    ``execute`` validates a ``ToolCall``'s arguments, calls the handler with
    them as keyword arguments, and reports the outcome as a ``ToolResult``
    instead of raising.

    Example::

        registry = ToolRegistry()
        registry.register(
            ToolSchema(name="get_weather", description="Get weather", parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}}
            }, required=["city"]),
            handler=lambda city: f"Weather in {city}: sunny"
        )
    """

    def __init__(self):
        # Two maps keyed by tool name; register()/unregister() update both together.
        self._tools: dict[str, ToolSchema] = {}
        self._handlers: dict[str, Callable[..., Any]] = {}

    def register(
        self,
        schema: ToolSchema,
        handler: Callable[..., Any],
    ) -> None:
        """Register a tool with its schema and handler function.

        Args:
            schema: Tool definition; ``schema.name`` is the registry key.
            handler: Callable that execute() invokes as ``handler(**arguments)``.

        Raises:
            ValueError: If a tool with the same name is already registered.
        """
        name = schema.name
        if name in self._tools:
            raise ValueError(f"Tool '{name}' already registered")
        self._tools[name] = schema
        self._handlers[name] = handler

    def unregister(self, name: str) -> None:
        """Remove a tool from the registry; unknown names are ignored."""
        self._tools.pop(name, None)
        self._handlers.pop(name, None)

    def get_schema(self, name: str) -> Optional[ToolSchema]:
        """Return the schema registered under ``name``, or None."""
        return self._tools.get(name)

    def list_schemas(self) -> list[ToolSchema]:
        """Return all registered schemas, in dict insertion order."""
        return list(self._tools.values())

    def to_openai_tools(self) -> list[dict[str, Any]]:
        """Export all tools as OpenAI function definitions."""
        return [t.to_openai() for t in self._tools.values()]

    def to_anthropic_tools(self) -> list[dict[str, Any]]:
        """Export all tools as Anthropic tool definitions."""
        return [t.to_anthropic() for t in self._tools.values()]

    def validate_arguments(self, name: str, arguments: dict) -> list[str]:
        """Validate arguments against a tool's schema.

        Args:
            name: Registered tool name.
            arguments: Decoded call arguments; must be a mapping.

        Returns:
            Human-readable error strings; an empty list means the arguments
            are valid. Problems are reported here rather than raised, including
            non-mapping arguments and an invalid registered JSON Schema.
        """
        from collections.abc import Mapping

        schema = self._tools.get(name)
        if schema is None:
            return [f"Unknown tool: {name}"]

        # A model may emit JSON that decodes to a list, number or null. Such
        # values cannot be keyword arguments, and the membership test below
        # would raise TypeError on them, so reject them up front.
        if not isinstance(arguments, Mapping):
            return [f"Arguments must be an object, got {type(arguments).__name__}"]

        # Check the separate ToolSchema.required list (it may not be part of
        # schema.parameters, so jsonschema alone would not enforce it).
        errors: list[str] = [
            f"Missing required argument: {arg_name}"
            for arg_name in schema.required
            if arg_name not in arguments
        ]

        # JSON Schema validation
        try:
            jsonschema.validate(instance=arguments, schema=schema.parameters)
        except jsonschema.ValidationError as e:
            errors.append(f"Schema validation: {e.message}")
        except jsonschema.SchemaError as e:
            # The registered schema itself is malformed. Report it so that
            # execute() still returns a ToolResult instead of propagating.
            errors.append(f"Invalid schema for tool '{name}': {e.message}")

        return errors

    def execute(self, call: ToolCall) -> ToolResult:
        """
        Validate and execute a tool call.

        Args:
            call: Parsed tool call with name and arguments.

        Returns:
            ToolResult with success/failure and output. Validation failures,
            a missing handler and any ``Exception`` raised by the handler are
            all reported through ``success=False`` and ``error``. ``latency_ms``
            is measured from the start of this method.
        """
        import time

        t0 = time.perf_counter()

        def _result(
            success: bool, *, output: Any = None, error: Optional[str] = None
        ) -> ToolResult:
            # Single place that stamps id/name/latency onto every outcome.
            return ToolResult(
                call_id=call.id,
                name=call.name,
                success=success,
                output=output,
                error=error,
                latency_ms=(time.perf_counter() - t0) * 1000,
            )

        errors = self.validate_arguments(call.name, call.arguments)
        if errors:
            return _result(False, error="; ".join(errors))

        handler = self._handlers.get(call.name)
        if handler is None:
            # Defensive: register() always stores a handler alongside the schema.
            return _result(False, error=f"No handler for tool: {call.name}")

        try:
            output = handler(**call.arguments)
        except Exception as exc:
            # Tool-boundary catch: handler failures become error results.
            return _result(False, error=f"{type(exc).__name__}: {exc}")
        return _result(True, output=output)

    def execute_batch(self, calls: list[ToolCall]) -> list[ToolResult]:
        """Execute calls one after another; returns one result per call, in order."""
        return [self.execute(c) for c in calls]

    def parse_tool_calls(
        self, raw_tool_calls: list[dict[str, Any]]
    ) -> list[ToolCall]:
        """Parse raw LLM tool_call dicts into ToolCall objects.

        Accepted entry shapes:

        * OpenAI-style: ``{"id", "function": {"name", "arguments"}}``.
        * Flat: ``name``/``arguments`` at the top level.
        * Anthropic ``tool_use`` blocks: decoded arguments under ``input``.

        ``arguments`` may be a JSON string or an already-decoded value.
        Missing, ``null`` or undecodable arguments become ``{}``. A decoded
        value that is not an object is kept as-is, so validate_arguments()
        rejects the call with an explicit error.
        """
        parsed: list[ToolCall] = []
        for tc in raw_tool_calls:
            fn = tc.get("function", tc)
            if "arguments" in fn:
                args_raw = fn["arguments"]
            else:
                # Anthropic tool_use blocks carry decoded arguments in "input".
                args_raw = fn.get("input", "{}")

            args: Any
            if args_raw is None:
                args = {}
            elif isinstance(args_raw, str):
                try:
                    args = json.loads(args_raw)
                except json.JSONDecodeError:
                    # Best effort: malformed JSON is treated as "no arguments";
                    # required-argument checks will still flag the call.
                    args = {}
            else:
                args = args_raw

            parsed.append(ToolCall(
                id=tc.get("id", ""),
                name=fn.get("name", ""),
                arguments=args,
            ))
        return parsed

    @property
    def tool_count(self) -> int:
        """Number of registered tools."""
        return len(self._tools)