Coverage for src\funcall\__init__.py: 77%
185 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 23:51 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 23:51 +0900
import asyncio
import concurrent.futures
import dataclasses
import inspect
import json
import types
from collections.abc import Callable
from logging import getLogger
from typing import Generic, Literal, Required, TypedDict, TypeVar, Union, get_args, get_type_hints

import litellm
from openai.types.responses import (
    FunctionToolParam,
    ResponseFunctionToolCall,
)
from pydantic import BaseModel

from .params_to_schema import params_to_schema
19logger = getLogger("funcall")
21T = TypeVar("T")
24class Context(Generic[T]):
25 def __init__(self, value: T | None = None) -> None:
26 self.value = value
29class LiteLLMFunctionToolParam(TypedDict):
30 name: Required[str]
31 parameters: Required[dict[str, object] | None]
32 strict: Required[bool | None]
33 type: Required[Literal["function"]]
34 description: str | None
37class LiteLLMFunctionToolParam(TypedDict):
38 type: Literal["function"]
39 function: Required[LiteLLMFunctionToolParam]
def _wrap_tool_meta(name: str, description: str, parameters: dict, target: str) -> dict:
    """Wrap a JSON-schema ``parameters`` object in the tool envelope for *target*.

    ``"openai"`` produces the flat Responses-API shape with ``strict`` enabled.
    ``"litellm"`` produces the Chat-Completions shape with a nested ``function`` key.

    Raises:
        ValueError: If *target* is neither ``"openai"`` nor ``"litellm"``.
    """
    if target == "openai":
        return {
            "type": "function",
            "name": name,
            "description": description,
            "parameters": parameters,
            "strict": True,
        }
    if target == "litellm":
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": parameters,
            },
        }
    msg = f"Unsupported target: {target!r}"
    raise ValueError(msg)


def generate_meta(func: Callable, target: Literal["openai", "litellm"] = "openai") -> FunctionToolParam:
    """Build the tool definition that describes *func* to an LLM.

    Parameters annotated as ``Context`` / ``Context[...]`` are omitted from the
    schema. If exactly one parameter remains and it is a dataclass or pydantic
    ``BaseModel``, that model's fields are promoted to the top-level
    ``properties``. Otherwise each parameter becomes its own required property.

    Args:
        func: The function to describe. Its ``__name__`` and stripped docstring
            become the tool's name and description.
        target: ``"openai"`` for the Responses-API shape (``strict: True``) or
            ``"litellm"`` for the Chat-Completions shape (nested ``function`` key).
            Both parameter layouts honour this setting.

    Returns:
        The tool-definition dict for *target*.

    Raises:
        ValueError: If *target* is not ``"openai"`` or ``"litellm"``.
    """
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)
    doc = func.__doc__.strip() if func.__doc__ else ""

    param_names = []
    param_types = []
    context_param_count = 0
    for name in sig.parameters:
        hint = type_hints.get(name, str)
        # Context parameters are injected at call time, so the model never sees them.
        if getattr(hint, "__origin__", None) is Context or hint is Context:
            context_param_count += 1
            continue
        param_names.append(name)
        param_types.append(hint)
    if context_param_count > 1:
        logger.warning("Multiple Context-type parameters detected in function '%s'. Only one context instance will be injected at runtime.", func.__name__)

    schema = params_to_schema(param_types)

    only_type = param_types[0] if len(param_types) == 1 else None
    if isinstance(only_type, type) and (dataclasses.is_dataclass(only_type) or issubclass(only_type, BaseModel)):
        # A lone dataclass/BaseModel parameter: expose its fields directly at the top level.
        model_schema = schema["properties"]["param_0"]
        properties = model_schema["properties"]
        if target == "litellm":
            # Kept from the previous behaviour: list every field only when the model has
            # at least one required field, otherwise list none.
            required = list(properties.keys()) if model_schema.get("required", []) else []
        else:
            # OpenAI strict mode requires every property to appear in "required".
            required = list(properties.keys())
        additional = model_schema.get("additionalProperties", False)
    else:
        # Several parameters (or a non-model one): each parameter is a property.
        properties = {name: schema["properties"][f"param_{i}"] for i, name in enumerate(param_names)}
        required = list(param_names)
        additional = False

    parameters = {
        "type": "object",
        "properties": properties,
        "required": required,
        "additionalProperties": additional,
    }
    # Keep shared definitions so any "$ref" inside the properties still resolves.
    if "$defs" in schema:
        parameters["$defs"] = schema["$defs"]
    return _wrap_tool_meta(func.__name__, doc, parameters, target)
122def _convert_arg(value: object, hint: type) -> object: # noqa: PLR0911
123 origin = getattr(hint, "__origin__", None)
124 if origin in (list, set, tuple):
125 args = get_args(hint)
126 item_type = args[0] if args else str
127 return [_convert_arg(v, item_type) for v in value]
128 if origin is dict: 128 ↛ 129line 128 didn't jump to line 129 because the condition on line 128 was never true
129 return value
130 if getattr(hint, "__origin__", None) is Union:
131 args = get_args(hint)
132 non_none = [a for a in args if a is not type(None)]
133 return _convert_arg(value, non_none[0]) if len(non_none) == 1 else value
134 if isinstance(hint, type) and BaseModel and issubclass(hint, BaseModel):
135 if isinstance(value, dict): 135 ↛ 138line 135 didn't jump to line 138 because the condition on line 135 was always true
136 fields = hint.model_fields
137 return hint(**{k: _convert_arg(v, fields[k].annotation) if k in fields else v for k, v in value.items()})
138 return value
139 if dataclasses.is_dataclass(hint):
140 if isinstance(value, dict): 140 ↛ 143line 140 didn't jump to line 143 because the condition on line 140 was always true
141 field_types = {f.name: f.type for f in dataclasses.fields(hint)}
142 return hint(**{k: _convert_arg(v, field_types.get(k, type(v))) for k, v in value.items()})
143 return value
144 return value
147def _is_async_function(func: Callable) -> bool:
148 """检查函数是否为异步函数"""
149 return inspect.iscoroutinefunction(func)
class Funcall:
    """Registry that exposes plain Python functions to an LLM as tools.

    ``get_tools`` builds the tool definitions sent to the model. The
    ``handle_*`` methods execute a tool call returned by the model. They look
    the function up by name, decode its JSON arguments, convert them to the
    annotated parameter types, and pass ``context`` to every ``Context``-typed
    parameter.
    """

    def __init__(self, functions: list | None = None) -> None:
        """Register *functions* (default: none); calls are dispatched by ``__name__``."""
        if functions is None:
            functions = []
        self.functions = functions
        # If two functions share a name, the later one wins in the lookup map.
        self.function_map = {func.__name__: func for func in functions}

    def get_tools(self, target: Literal["openai", "litellm"] = "openai") -> list[FunctionToolParam]:
        """Return one tool definition per registered function, shaped for *target* ("openai" or "litellm")."""
        return [generate_meta(func, target) for func in self.functions]

    def _prepare_function_call(self, func_name: str, args: str, context: object = None) -> tuple[Callable, dict]:
        """Resolve *func_name* and build the keyword arguments for calling it.

        Args:
            func_name: Name of a registered function.
            args: JSON-encoded arguments as sent by the model.
            context: Value passed to every ``Context``-typed parameter.

        Returns:
            ``(func, kwargs)`` ready for ``func(**kwargs)``.

        Raises:
            ValueError: If *func_name* is not registered, or *args* is not valid JSON.
        """
        if func_name not in self.function_map:
            msg = f"Function {func_name} not found"
            raise ValueError(msg)

        func = self.function_map[func_name]
        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        kwargs = json.loads(args)

        def is_context(param_name: str) -> bool:
            hint = type_hints.get(param_name, str)
            return getattr(hint, "__origin__", None) is Context or hint is Context

        non_context_params = [name for name in sig.parameters if not is_context(name)]

        # With one real parameter, the schema may have flattened it (a lone dataclass /
        # BaseModel's fields are promoted to the top level). Re-wrap the decoded
        # arguments under the parameter's name unless they are already exactly {name: ...}.
        if len(non_context_params) == 1 and (not isinstance(kwargs, dict) or set(kwargs.keys()) != set(non_context_params)):
            kwargs = {non_context_params[0]: kwargs}

        new_kwargs = {}
        for name in sig.parameters:
            if is_context(name):
                new_kwargs[name] = context
            elif name in kwargs:
                new_kwargs[name] = _convert_arg(kwargs[name], type_hints.get(name, str))
        return func, new_kwargs

    @staticmethod
    def _run_coroutine_sync(func: Callable, kwargs: dict) -> object:
        """Run the async *func* with *kwargs* to completion from synchronous code.

        The coroutine is awaited exactly once, and any exception it raises
        propagates unchanged. If no event loop is running in this thread, it
        runs via ``asyncio.run``. If a loop is already running here (where
        ``asyncio.run`` would refuse), it runs on a fresh loop in a worker
        thread, and this call blocks the running loop until it finishes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(func(**kwargs))
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, func(**kwargs)).result()

    @staticmethod
    async def _call_async(func: Callable, kwargs: dict) -> object:
        """Await *func* if it is async; otherwise run it in the default thread pool.

        Running sync functions off-loop keeps them from blocking the event loop.
        """
        if _is_async_function(func):
            return await func(**kwargs)
        return await asyncio.to_thread(func, **kwargs)

    def handle_openai_function_call(self, item: ResponseFunctionToolCall, context: object = None):
        """Execute an OpenAI Responses-API function call synchronously and return its result.

        Raises:
            TypeError: If *item* is not a ``ResponseFunctionToolCall``.
        """
        if not isinstance(item, ResponseFunctionToolCall):
            msg = "item must be an instance of ResponseFunctionToolCall"
            raise TypeError(msg)

        func, kwargs = self._prepare_function_call(item.name, item.arguments, context)

        if _is_async_function(func):
            logger.warning("Function %s is async but being called synchronously. Consider using handle_openai_function_call_async.", item.name)
            return self._run_coroutine_sync(func, kwargs)
        return func(**kwargs)

    async def handle_openai_function_call_async(self, item: ResponseFunctionToolCall, context: object = None):
        """Execute an OpenAI Responses-API function call asynchronously and return its result.

        Raises:
            TypeError: If *item* is not a ``ResponseFunctionToolCall``.
        """
        if not isinstance(item, ResponseFunctionToolCall):
            msg = "item must be an instance of ResponseFunctionToolCall"
            raise TypeError(msg)

        func, kwargs = self._prepare_function_call(item.name, item.arguments, context)
        return await self._call_async(func, kwargs)

    def handle_litellm_function_call(self, item: litellm.ChatCompletionMessageToolCall, context: object = None):
        """Execute a LiteLLM tool call synchronously and return its result.

        Raises:
            TypeError: If *item* is not a ``litellm.ChatCompletionMessageToolCall``.
        """
        if not isinstance(item, litellm.ChatCompletionMessageToolCall):
            msg = "item must be an instance of litellm.ChatCompletionMessageToolCall"
            raise TypeError(msg)

        func, kwargs = self._prepare_function_call(item.function.name, item.function.arguments, context)

        if _is_async_function(func):
            logger.warning("Function %s is async but being called synchronously. Consider using handle_litellm_function_call_async.", item.function.name)
            return self._run_coroutine_sync(func, kwargs)
        return func(**kwargs)

    async def handle_litellm_function_call_async(self, item: litellm.ChatCompletionMessageToolCall, context: object = None):
        """Execute a LiteLLM tool call asynchronously and return its result.

        Raises:
            TypeError: If *item* is not a ``litellm.ChatCompletionMessageToolCall``.
        """
        if not isinstance(item, litellm.ChatCompletionMessageToolCall):
            msg = "item must be an instance of litellm.ChatCompletionMessageToolCall"
            raise TypeError(msg)

        func, kwargs = self._prepare_function_call(item.function.name, item.function.arguments, context)
        return await self._call_async(func, kwargs)

    def handle_function_call(self, item: ResponseFunctionToolCall | litellm.ChatCompletionMessageToolCall, context: object = None):
        """Dispatch an OpenAI or LiteLLM tool call to the matching synchronous handler.

        Raises:
            TypeError: If *item* is neither supported tool-call type.
        """
        if isinstance(item, ResponseFunctionToolCall):
            return self.handle_openai_function_call(item, context)
        if isinstance(item, litellm.ChatCompletionMessageToolCall):
            return self.handle_litellm_function_call(item, context)
        msg = "item must be an instance of ResponseFunctionToolCall or litellm.ChatCompletionMessageToolCall"
        raise TypeError(msg)

    async def handle_function_call_async(self, item: ResponseFunctionToolCall | litellm.ChatCompletionMessageToolCall, context: object = None):
        """Dispatch an OpenAI or LiteLLM tool call to the matching asynchronous handler.

        Raises:
            TypeError: If *item* is neither supported tool-call type.
        """
        if isinstance(item, ResponseFunctionToolCall):
            return await self.handle_openai_function_call_async(item, context)
        if isinstance(item, litellm.ChatCompletionMessageToolCall):
            return await self.handle_litellm_function_call_async(item, context)
        msg = "item must be an instance of ResponseFunctionToolCall or litellm.ChatCompletionMessageToolCall"
        raise TypeError(msg)
# Public API re-exported by `from funcall import *`.
__all__ = [
    "Context",
    "Funcall",
    "generate_meta",
]