Coverage for src\funcall\__init__.py: 86%
112 statements
coverage.py v7.9.1, created at 2025-06-18 18:13 +0900
1import contextlib
2import dataclasses
3import inspect
4import json
5from collections.abc import Callable
6from logging import getLogger
7from typing import Generic, TypeVar, Union, get_args, get_type_hints
9from openai.types.responses import (
10 FunctionToolParam,
11 ResponseFunctionToolCall,
12)
13from pydantic import BaseModel
15from .params_to_schema import params_to_schema
# Package-wide logger: every diagnostic emitted by funcall goes through this name.
logger = getLogger("funcall")

# Type parameter describing what a Context carries.
T = TypeVar("T")


class Context(Generic[T]):
    """Holder for a caller-supplied value that is not produced by the model.

    Parameters annotated as ``Context`` or ``Context[X]`` are omitted from the
    generated tool schema, and at call time they receive the ``context``
    argument given to ``Funcall.handle_function_call``.
    """

    def __init__(self, value: Union[T, None] = None) -> None:
        # The value is stored as given; no validation or copying happens here.
        self.value: Union[T, None] = value
def generate_meta(func: Callable) -> FunctionToolParam:
    """Build a strict-mode OpenAI function-tool definition describing *func*.

    Parameters annotated as ``Context`` / ``Context[X]`` are left out of the
    schema because their value is injected at call time. Parameters without a
    type hint are treated as ``str``.

    If exactly one schema parameter remains and it is a dataclass or a pydantic
    ``BaseModel`` class, that model's fields are promoted to the top-level
    ``properties`` (following a ``$ref`` into ``$defs`` when the schema uses
    one). Otherwise each parameter becomes its own property and every one of
    them is listed in ``required``.

    Args:
        func: The callable to describe. Its ``__name__`` becomes the tool name
            and its stripped docstring (or ``""``) the tool description.

    Returns:
        A ``FunctionToolParam`` dict with ``"strict": True``. Any ``$defs``
        produced by ``params_to_schema`` are copied under ``parameters``.
    """
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)
    doc = func.__doc__.strip() if func.__doc__ else ""

    param_names: list[str] = []
    param_types: list = []
    context_param_count = 0
    for name in sig.parameters:
        hint = type_hints.get(name, str)
        # Skip every Context-typed parameter: its value is injected at runtime,
        # not supplied by the model.
        if getattr(hint, "__origin__", None) is Context or hint is Context:
            context_param_count += 1
            continue
        param_names.append(name)
        param_types.append(hint)
    if context_param_count > 1:
        logger.warning("Multiple Context-type parameters detected in function '%s'. Only one context instance will be injected at runtime.", func.__name__)

    # params_to_schema is expected to key each parameter as "param_<index>"
    # under "properties" (that is how its result is read below).
    schema = params_to_schema(param_types)

    single_hint = param_types[0] if len(param_types) == 1 else None
    lift_single_model = (
        isinstance(single_hint, type)
        # On Python 3.9/3.10, parameterised generics such as ``list[int]`` also
        # satisfy isinstance(..., type), and issubclass() would raise TypeError.
        and getattr(single_hint, "__origin__", None) is None
        and (dataclasses.is_dataclass(single_hint) or issubclass(single_hint, BaseModel))
    )

    if lift_single_model:
        # A lone dataclass / BaseModel parameter: expose its fields directly.
        prop = schema["properties"]["param_0"]
        if "$ref" in prop:
            # Follow the local reference, e.g. "#/$defs/Model" -> schema["$defs"]["Model"].
            def_name = prop["$ref"].split("/", 2)[-1]
            model_schema = schema["$defs"][def_name]
        else:
            model_schema = prop
        properties = model_schema["properties"]
        required = model_schema.get("required", [])
        additional = model_schema.get("additionalProperties", False)
    else:
        # Several parameters, or one that is not a dataclass/BaseModel: one
        # property per parameter name, all required.
        properties = {name: schema["properties"][f"param_{index}"] for index, name in enumerate(param_names)}
        required = list(param_names)
        additional = False

    meta: FunctionToolParam = {
        "type": "function",
        "name": func.__name__,
        "description": doc,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
            "additionalProperties": additional,
        },
        "strict": True,
    }
    if "$defs" in schema:
        meta["parameters"]["$defs"] = schema["$defs"]
    return meta
def _convert_arg(value: object, hint: type) -> object:
    """Coerce a JSON-decoded *value* toward the Python type described by *hint*.

    Conversions performed:

    * ``list[X]`` / ``set[X]`` / ``frozenset[X]`` / ``tuple[...]``: each element
      is converted recursively and the annotated container type is rebuilt.
      Fixed-length tuples convert each position with its own type.
    * ``dict[K, V]``: values are converted recursively against ``V``; keys are kept.
    * ``Optional[X]`` / ``X | None``: converted as ``X``. Other unions are left alone.
    * pydantic ``BaseModel`` subclasses: ``hint(**value)`` when *value* is a dict.
    * dataclasses: fields are converted recursively, then ``hint(**fields)``.

    Anything else, including a value whose JSON shape does not match the hint
    (e.g. ``None`` for an optional list), is returned unchanged.

    Args:
        value: Value produced by ``json.loads`` for one argument or field.
        hint: Resolved type annotation of the target parameter or field.

    Returns:
        The converted value, or *value* itself when no conversion applies.
    """
    import types  # PEP 604 ``X | Y`` unions report their origin as types.UnionType
    from typing import get_origin

    origin = get_origin(hint)

    if origin in (list, set, frozenset, tuple):
        if not isinstance(value, (list, tuple)):
            # Not a JSON array (e.g. None for an optional field): nothing to convert.
            return value
        args = get_args(hint)
        if origin is tuple and args and args[-1] is not Ellipsis and len(args) == len(value):
            # Fixed-length tuple such as tuple[int, str]: one type per position.
            items = [_convert_arg(item, item_hint) for item, item_hint in zip(value, args)]
        else:
            item_hint = args[0] if args else str
            items = [_convert_arg(item, item_hint) for item in value]
        if origin is list:
            return items
        try:
            return origin(items)
        except TypeError:
            # Unhashable elements cannot form a set; keep the list rather than
            # failing the whole call.
            return items

    if origin is dict:
        args = get_args(hint)
        if isinstance(value, dict) and len(args) == 2:
            value_hint = args[1]
            return {key: _convert_arg(item, value_hint) for key, item in value.items()}
        return value

    if origin is Union or origin is types.UnionType:
        # Only Optional[X] is unwrapped; wider unions are ambiguous.
        non_none = [arg for arg in get_args(hint) if arg is not type(None)]
        return _convert_arg(value, non_none[0]) if len(non_none) == 1 else value

    # Only plain classes may reach issubclass(): on Python 3.9/3.10 parameterised
    # generics (e.g. Sequence[int]) pass isinstance(..., type) and issubclass()
    # would raise TypeError for them.
    if isinstance(hint, type) and origin is None and issubclass(hint, BaseModel):
        return hint(**value) if isinstance(value, dict) else value

    if dataclasses.is_dataclass(hint):
        if not isinstance(value, dict):
            return value
        try:
            field_hints = get_type_hints(hint)
        except (NameError, TypeError):
            # Unresolvable annotations: fall back to shallow construction.
            field_hints = {}
        converted = {
            key: _convert_arg(item, field_hints[key]) if key in field_hints else item
            for key, item in value.items()
        }
        return hint(**converted)

    return value
class Funcall:
    """Expose Python callables as OpenAI function tools and dispatch tool calls to them.

    Functions are registered by their ``__name__``. ``get_tools`` describes
    them with ``generate_meta``, and ``handle_function_call`` runs the one a
    ``ResponseFunctionToolCall`` names.
    """

    def __init__(self, functions: list | None = None) -> None:
        """Register *functions*; each is later looked up by its ``__name__``."""
        # Copy the input so later mutation of the caller's list (or exhaustion
        # of an iterator) cannot make ``functions`` (used by get_tools) and
        # ``function_map`` (used by handle_function_call) disagree.
        self.functions = list(functions) if functions is not None else []
        self.function_map = {func.__name__: func for func in self.functions}

    def get_tools(self) -> list[FunctionToolParam]:
        """Return one tool definition per registered function, in registration order."""
        return [generate_meta(func) for func in self.functions]

    def handle_function_call(self, item: ResponseFunctionToolCall, context: object = None) -> object:
        """Invoke the registered function named by *item* and return its result.

        Arguments are decoded from ``item.arguments`` (JSON). They are matched to
        the function's parameters by name and converted with ``_convert_arg``.
        Parameters annotated as ``Context`` / ``Context[X]`` receive *context*
        exactly as passed; it is not wrapped in ``Context``.

        A dataclass / pydantic ``BaseModel`` parameter whose name is absent from
        the arguments is built from the whole arguments object. This matches the
        schema ``generate_meta`` emits when such a model is the only parameter.
        If that construction fails, the parameter is left unset and a warning is
        logged.

        Args:
            item: The tool call; ``item.name`` selects the function.
            context: Value injected into every Context-typed parameter.

        Returns:
            Whatever the called function returns.

        Raises:
            ValueError: If no function named ``item.name`` is registered, or
                (as ``json.JSONDecodeError``) if the arguments are not valid JSON.
        """
        if item.name not in self.function_map:
            msg = f"Function {item.name} not found"
            raise ValueError(msg)

        func = self.function_map[item.name]
        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        # An empty arguments string means "no arguments" rather than a JSON error.
        arguments = json.loads(item.arguments) if item.arguments else {}
        # Accept a bare JSON array when the function takes exactly one parameter.
        if isinstance(arguments, list) and len(sig.parameters) == 1:
            arguments = {next(iter(sig.parameters)): arguments}

        call_kwargs: dict[str, object] = {}
        for name in sig.parameters:
            hint = type_hints.get(name, str)
            if getattr(hint, "__origin__", None) is Context or hint is Context:
                call_kwargs[name] = context
            elif name in arguments:
                call_kwargs[name] = _convert_arg(arguments[name], hint)
            elif (
                isinstance(hint, type)
                # Parameterised generics (e.g. list[int]) pass isinstance(..., type)
                # on Python 3.9/3.10; issubclass() would raise TypeError for them.
                and getattr(hint, "__origin__", None) is None
                and issubclass(hint, BaseModel)
            ) or dataclasses.is_dataclass(hint):
                # The schema promoted this model's fields to the top level, so the
                # whole arguments object holds this parameter's fields.
                if not isinstance(arguments, dict):
                    logger.warning(
                        "Cannot build parameter '%s' of function '%s': tool arguments are not a JSON object",
                        name,
                        item.name,
                    )
                    continue
                try:
                    call_kwargs[name] = _convert_arg(arguments, hint)
                except Exception:
                    # Best effort: leave the parameter unset so its default (if
                    # any) applies, but record why instead of failing silently.
                    logger.warning(
                        "Could not build parameter '%s' of function '%s' from the tool arguments",
                        name,
                        item.name,
                        exc_info=True,
                    )
        return func(**call_kwargs)
# Public names exported by ``from funcall import *``.
__all__ = [
    "Context",
    "Funcall",
    "generate_meta",
]