Coverage for src\funcall\__init__.py: 77%

185 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-18 23:51 +0900

import asyncio
import concurrent.futures
import dataclasses
import inspect
import json
import types
from collections.abc import Callable
from logging import getLogger
from typing import Generic, Literal, Required, TypedDict, TypeVar, Union, get_args, get_origin, get_type_hints

import litellm
from openai.types.responses import (
    FunctionToolParam,
    ResponseFunctionToolCall,
)
from pydantic import BaseModel

from .params_to_schema import params_to_schema

18 

# Named logger shared by everything in this module.
logger = getLogger(name="funcall")

# Type variable describing the payload carried by ``Context``.
T = TypeVar("T")

22 

23 

class Context(Generic[T]):
    """Holder for caller-supplied state made available to tool functions.

    Tool parameters annotated as ``Context`` (or ``Context[...]``) are left out
    of the generated tool schema. At call time they receive the ``context``
    object passed to ``Funcall``'s ``handle_*`` methods.
    """

    def __init__(self, value: T | None = None) -> None:
        """Store *value* as-is (``None`` when omitted)."""
        self.value: T | None = value

27 

28 

29class LiteLLMFunctionToolParam(TypedDict): 

30 name: Required[str] 

31 parameters: Required[dict[str, object] | None] 

32 strict: Required[bool | None] 

33 type: Required[Literal["function"]] 

34 description: str | None 

35 

36 

37class LiteLLMFunctionToolParam(TypedDict): 

38 type: Literal["function"] 

39 function: Required[LiteLLMFunctionToolParam] 

40 

41 

def generate_meta(func: Callable, target: Literal["openai", "litellm"] = "openai") -> FunctionToolParam:
    """Describe *func* as an LLM tool definition.

    Parameters annotated as ``Context`` (bare or ``Context[...]``) are left out
    of the schema because ``Funcall`` injects them at call time. Unannotated
    parameters are treated as ``str``. When exactly one visible parameter
    remains and it is a dataclass or pydantic ``BaseModel``, that model's fields
    are promoted to the top level of ``parameters``.

    Args:
        func: Callable to describe. ``func.__name__`` becomes the tool name and
            its stripped docstring (or ``""``) the description.
        target: ``"openai"`` returns the flat Responses-API shape
            ``{"type", "name", "description", "parameters", "strict": True}``.
            ``"litellm"`` returns the Chat-Completions shape
            ``{"type": "function", "function": {"name", "description", "parameters"}}``.
            The shape depends only on *target*, never on how many parameters
            *func* has.

    Returns:
        The tool definition dict. Its ``parameters`` object always carries
        ``"additionalProperties"``. It also includes ``"$defs"`` whenever the
        schema produced by ``params_to_schema`` contains one.

    Raises:
        ValueError: If *target* is neither ``"openai"`` nor ``"litellm"``.
    """
    if target not in ("openai", "litellm"):
        msg = f"Unsupported target {target!r}; expected 'openai' or 'litellm'"
        raise ValueError(msg)

    sig = inspect.signature(func)
    type_hints = get_type_hints(func)
    doc = func.__doc__.strip() if func.__doc__ else ""

    param_names: list[str] = []
    param_types: list = []
    context_param_count = 0
    for name in sig.parameters:
        hint = type_hints.get(name, str)
        # Context parameters are supplied by Funcall at call time, not by the model.
        if getattr(hint, "__origin__", None) is Context or hint is Context:
            context_param_count += 1
            continue
        param_names.append(name)
        param_types.append(hint)
    if context_param_count > 1:
        logger.warning("Multiple Context-type parameters detected in function '%s'. Only one context instance will be injected at runtime.", func.__name__)

    schema = params_to_schema(param_types)

    only_hint = param_types[0] if len(param_types) == 1 else None
    if isinstance(only_hint, type) and (dataclasses.is_dataclass(only_hint) or issubclass(only_hint, BaseModel)):
        # Single dataclass/BaseModel parameter: expose its fields directly.
        model_schema = schema["properties"]["param_0"]
        properties = model_schema["properties"]
        if target == "openai":
            # OpenAI strict mode requires every property to be listed as required.
            required = list(properties)
        else:
            # litellm: list every field only when the model declares any required field.
            required = list(properties) if model_schema.get("required") else []
        additional = model_schema.get("additionalProperties", False)
    else:
        # params_to_schema names properties positionally (param_0, param_1, ...);
        # map them back to the real parameter names.
        properties = {name: schema["properties"][f"param_{i}"] for i, name in enumerate(param_names)}
        required = list(param_names)
        additional = False

    parameters = {
        "type": "object",
        "properties": properties,
        "required": required,
        "additionalProperties": additional,
    }
    if "$defs" in schema:
        # Carry shared definitions along; any "$ref" entries inside the
        # properties presumably point at them.
        parameters["$defs"] = schema["$defs"]

    if target == "litellm":
        return {
            "type": "function",
            "function": {
                "name": func.__name__,
                "description": doc,
                "parameters": parameters,
            },
        }
    return {
        "type": "function",
        "name": func.__name__,
        "description": doc,
        "parameters": parameters,
        "strict": True,
    }

120 

121 

122def _convert_arg(value: object, hint: type) -> object: # noqa: PLR0911 

123 origin = getattr(hint, "__origin__", None) 

124 if origin in (list, set, tuple): 

125 args = get_args(hint) 

126 item_type = args[0] if args else str 

127 return [_convert_arg(v, item_type) for v in value] 

128 if origin is dict: 128 ↛ 129line 128 didn't jump to line 129 because the condition on line 128 was never true

129 return value 

130 if getattr(hint, "__origin__", None) is Union: 

131 args = get_args(hint) 

132 non_none = [a for a in args if a is not type(None)] 

133 return _convert_arg(value, non_none[0]) if len(non_none) == 1 else value 

134 if isinstance(hint, type) and BaseModel and issubclass(hint, BaseModel): 

135 if isinstance(value, dict): 135 ↛ 138line 135 didn't jump to line 138 because the condition on line 135 was always true

136 fields = hint.model_fields 

137 return hint(**{k: _convert_arg(v, fields[k].annotation) if k in fields else v for k, v in value.items()}) 

138 return value 

139 if dataclasses.is_dataclass(hint): 

140 if isinstance(value, dict): 140 ↛ 143line 140 didn't jump to line 143 because the condition on line 140 was always true

141 field_types = {f.name: f.type for f in dataclasses.fields(hint)} 

142 return hint(**{k: _convert_arg(v, field_types.get(k, type(v))) for k, v in value.items()}) 

143 return value 

144 return value 

145 

146 

147def _is_async_function(func: Callable) -> bool: 

148 """检查函数是否为异步函数""" 

149 return inspect.iscoroutinefunction(func) 

150 

151 

class Funcall:
    """Registry that exposes Python callables as LLM tools and dispatches tool calls to them.

    ``get_tools`` produces the tool schemas via ``generate_meta``. The
    ``handle_*`` methods do the reverse: they look up the requested function by
    name, decode its JSON arguments, convert them to the annotated parameter
    types, inject the caller-supplied ``context`` into ``Context``-typed
    parameters, and invoke the function.
    """

    def __init__(self, functions: list | None = None) -> None:
        """Register *functions* (none by default).

        Functions are looked up by ``__name__`` when a call is dispatched. If
        two share a name, the later one in the list wins.
        """
        if functions is None:
            functions = []
        self.functions = functions
        self.function_map = {func.__name__: func for func in functions}

    def get_tools(self, target: Literal["openai", "litellm"] = "openai") -> list[FunctionToolParam]:
        """Return the tool schema of every registered function, formatted for *target*."""
        return [generate_meta(func, target) for func in self.functions]

    @staticmethod
    def _is_context_hint(hint: object) -> bool:
        """Return True when *hint* is ``Context`` or a parameterized ``Context[...]``."""
        return hint is Context or getattr(hint, "__origin__", None) is Context

    def _prepare_function_call(self, func_name: str, args: str, context: object = None) -> tuple[Callable, dict]:
        """Resolve *func_name* and build the keyword arguments to call it with.

        Args:
            func_name: Name of a registered function.
            args: JSON-encoded arguments produced by the model. An empty or
                whitespace-only string is treated as ``{}``.
            context: Object injected into every ``Context``-typed parameter.

        Returns:
            ``(func, kwargs)`` ready for ``func(**kwargs)``.

        Raises:
            ValueError: If no function named *func_name* is registered, or if
                *args* is not valid JSON (``json.JSONDecodeError``).
        """
        if func_name not in self.function_map:
            msg = f"Function {func_name} not found"
            raise ValueError(msg)

        func = self.function_map[func_name]
        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        # Some providers send "" rather than "{}" for calls without arguments.
        kwargs = json.loads(args) if args and args.strip() else {}

        # Unannotated parameters are treated as str, matching generate_meta.
        param_hints = {name: type_hints.get(name, str) for name in sig.parameters}
        non_context_params = [name for name, hint in param_hints.items() if not self._is_context_hint(hint)]

        # With a single visible parameter, generate_meta may have flattened a
        # dataclass/BaseModel into top-level fields. Re-wrap the payload under
        # the parameter name unless it is already keyed exactly that way.
        if len(non_context_params) == 1 and (not isinstance(kwargs, dict) or set(kwargs) != set(non_context_params)):
            kwargs = {non_context_params[0]: kwargs}

        new_kwargs = {}
        for name, hint in param_hints.items():
            if self._is_context_hint(hint):
                new_kwargs[name] = context
            elif name in kwargs:
                new_kwargs[name] = _convert_arg(kwargs[name], hint)
        return func, new_kwargs

    @staticmethod
    def _run_coroutine_sync(func: Callable, kwargs: dict) -> object:
        """Run the async *func* to completion from synchronous code and return its result.

        The coroutine is created exactly once on each path. An exception raised
        by *func* (including ``RuntimeError``) therefore propagates unchanged,
        and *func* is never executed twice.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread: run on a fresh loop.
            return asyncio.run(func(**kwargs))
        # This thread's loop is busy running our caller and cannot be re-entered.
        # Run the coroutine on its own loop in a worker thread and wait for it.
        # NOTE: this blocks the running loop until the coroutine finishes.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, func(**kwargs)).result()

    @staticmethod
    async def _run_async(func: Callable, kwargs: dict) -> object:
        """Await *func* if it is async; otherwise run it in a worker thread so the loop is not blocked."""
        if _is_async_function(func):
            return await func(**kwargs)
        # to_thread uses the loop's default executor and also propagates contextvars.
        return await asyncio.to_thread(func, **kwargs)

    def handle_openai_function_call(self, item: ResponseFunctionToolCall, context: object = None):
        """Synchronously execute an OpenAI Responses API function call and return the function's result.

        Raises:
            TypeError: If *item* is not a ``ResponseFunctionToolCall``.
            ValueError: If the function is unknown or its arguments are not valid JSON.
        """
        if not isinstance(item, ResponseFunctionToolCall):
            msg = "item must be an instance of ResponseFunctionToolCall"
            raise TypeError(msg)

        func, kwargs = self._prepare_function_call(item.name, item.arguments, context)
        if _is_async_function(func):
            logger.warning("Function %s is async but being called synchronously. Consider using handle_openai_function_call_async.", item.name)
            return self._run_coroutine_sync(func, kwargs)
        return func(**kwargs)

    async def handle_openai_function_call_async(self, item: ResponseFunctionToolCall, context: object = None):
        """Asynchronously execute an OpenAI Responses API function call and return the function's result.

        Raises:
            TypeError: If *item* is not a ``ResponseFunctionToolCall``.
            ValueError: If the function is unknown or its arguments are not valid JSON.
        """
        if not isinstance(item, ResponseFunctionToolCall):
            msg = "item must be an instance of ResponseFunctionToolCall"
            raise TypeError(msg)

        func, kwargs = self._prepare_function_call(item.name, item.arguments, context)
        return await self._run_async(func, kwargs)

    def handle_litellm_function_call(self, item: litellm.ChatCompletionMessageToolCall, context: object = None):
        """Synchronously execute a LiteLLM tool call and return the function's result.

        Raises:
            TypeError: If *item* is not a ``litellm.ChatCompletionMessageToolCall``.
            ValueError: If the function is unknown or its arguments are not valid JSON.
        """
        if not isinstance(item, litellm.ChatCompletionMessageToolCall):
            msg = "item must be an instance of litellm.ChatCompletionMessageToolCall"
            raise TypeError(msg)

        func, kwargs = self._prepare_function_call(item.function.name, item.function.arguments, context)
        if _is_async_function(func):
            logger.warning("Function %s is async but being called synchronously. Consider using handle_litellm_function_call_async.", item.function.name)
            return self._run_coroutine_sync(func, kwargs)
        return func(**kwargs)

    async def handle_litellm_function_call_async(self, item: litellm.ChatCompletionMessageToolCall, context: object = None):
        """Asynchronously execute a LiteLLM tool call and return the function's result.

        Raises:
            TypeError: If *item* is not a ``litellm.ChatCompletionMessageToolCall``.
            ValueError: If the function is unknown or its arguments are not valid JSON.
        """
        if not isinstance(item, litellm.ChatCompletionMessageToolCall):
            msg = "item must be an instance of litellm.ChatCompletionMessageToolCall"
            raise TypeError(msg)

        func, kwargs = self._prepare_function_call(item.function.name, item.function.arguments, context)
        return await self._run_async(func, kwargs)

    def handle_function_call(self, item: ResponseFunctionToolCall | litellm.ChatCompletionMessageToolCall, context: object = None):
        """Synchronously dispatch *item* to the OpenAI or LiteLLM handler based on its type.

        Raises:
            TypeError: If *item* is neither supported tool-call type.
        """
        if isinstance(item, ResponseFunctionToolCall):
            return self.handle_openai_function_call(item, context)
        if isinstance(item, litellm.ChatCompletionMessageToolCall):
            return self.handle_litellm_function_call(item, context)
        msg = "item must be an instance of ResponseFunctionToolCall or litellm.ChatCompletionMessageToolCall"
        raise TypeError(msg)

    async def handle_function_call_async(self, item: ResponseFunctionToolCall | litellm.ChatCompletionMessageToolCall, context: object = None):
        """Asynchronously dispatch *item* to the OpenAI or LiteLLM handler based on its type.

        Raises:
            TypeError: If *item* is neither supported tool-call type.
        """
        if isinstance(item, ResponseFunctionToolCall):
            return await self.handle_openai_function_call_async(item, context)
        if isinstance(item, litellm.ChatCompletionMessageToolCall):
            return await self.handle_litellm_function_call_async(item, context)
        msg = "item must be an instance of ResponseFunctionToolCall or litellm.ChatCompletionMessageToolCall"
        raise TypeError(msg)