Coverage for src\funcall\__init__.py: 86%

112 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-18 18:13 +0900

import contextlib
import dataclasses
import inspect
import json
import types
from collections.abc import Callable
from logging import getLogger
from typing import Generic, TypeVar, Union, get_args, get_origin, get_type_hints

from openai.types.responses import (
    FunctionToolParam,
    ResponseFunctionToolCall,
)
from pydantic import BaseModel

from .params_to_schema import params_to_schema

16 

# Logger used by this module (named "funcall"); used to report tool-schema
# and tool-call issues.
logger = getLogger(name="funcall")

# Type variable for the value carried by ``Context``.
T = TypeVar("T")

20 

21 

class Context(Generic[T]):
    """Marker type for caller-supplied state handed to tool functions.

    Function parameters annotated as ``Context`` or ``Context[X]`` are left out
    of the tool schema built by ``generate_meta``. ``Funcall.handle_function_call``
    fills them with its ``context`` argument, which it passes through unchanged
    rather than wrapping it in a ``Context``.
    """

    def __init__(self, value: T | None = None) -> None:
        # The wrapped object is stored as given; it is None when omitted.
        self.value: T | None = value

25 

26 

def _build_tool_param(
    func: Callable,
    doc: str,
    properties: dict,
    required: list,
    additional: object,
    schema: dict,
) -> FunctionToolParam:
    """Assemble a strict function-tool dict, copying ``$defs`` from ``schema`` when present."""
    meta: FunctionToolParam = {
        "type": "function",
        "name": func.__name__,
        "description": doc,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
            "additionalProperties": additional,
        },
        "strict": True,
    }
    if "$defs" in schema:
        # Properties may "$ref" into "$defs", so the definitions travel with them.
        meta["parameters"]["$defs"] = schema["$defs"]
    return meta


def generate_meta(func: Callable) -> FunctionToolParam:
    """Build a strict OpenAI function-tool definition describing ``func``.

    The tool name is ``func.__name__`` and the description is the stripped
    docstring (empty if there is none). Parameters annotated as ``Context`` or
    ``Context[X]`` are excluded from the schema. Unannotated parameters are
    described as ``str``.

    If exactly one schema parameter remains and it is a dataclass or a pydantic
    ``BaseModel`` class, that model's fields are promoted to the top level of
    the tool parameters instead of being nested under the parameter name.
    Otherwise every parameter becomes a required top-level property and
    ``additionalProperties`` is False.

    Args:
        func: The callable to describe.

    Returns:
        The tool definition dict, with ``"strict": True``.
    """
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)
    doc = func.__doc__.strip() if func.__doc__ else ""

    param_names: list[str] = []
    param_types: list[object] = []
    context_param_count = 0
    for name in sig.parameters:
        hint = type_hints.get(name, str)
        # Context parameters are injected at call time, never supplied by the model.
        if getattr(hint, "__origin__", None) is Context or hint is Context:
            context_param_count += 1
            continue
        param_names.append(name)
        param_types.append(hint)
    if context_param_count > 1:
        logger.warning("Multiple Context-type parameters detected in function '%s'. Only one context instance will be injected at runtime.", func.__name__)

    schema = params_to_schema(param_types)

    # Single dataclass / BaseModel parameter: promote its fields to the top level.
    if len(param_types) == 1:
        hint = param_types[0]
        # get_origin() excludes parameterized generics such as list[int], for which
        # isinstance(hint, type) is True on Python 3.9/3.10 and issubclass() would raise.
        is_model_class = (
            isinstance(hint, type)
            and get_origin(hint) is None
            and (dataclasses.is_dataclass(hint) or issubclass(hint, BaseModel))
        )
        if is_model_class:
            prop = schema["properties"]["param_0"]
            if "$ref" in prop:
                # "#/$defs/<Name>": everything after the second "/" is the $defs key.
                def_name = prop["$ref"].split("/", 2)[-1]
                model_schema = schema["$defs"][def_name]
            else:
                model_schema = prop
            return _build_tool_param(
                func,
                doc,
                model_schema["properties"],
                model_schema.get("required", []),
                model_schema.get("additionalProperties", False),
                schema,
            )

    # Multiple parameters, or a single non-model parameter: params_to_schema names
    # them positionally ("param_0", "param_1", ...); map them back to real names.
    properties = {name: schema["properties"][f"param_{i}"] for i, name in enumerate(param_names)}
    return _build_tool_param(func, doc, properties, list(param_names), False, schema)

99 

100 

101def _convert_arg(value: object, hint: type) -> object: 

102 result = value 

103 origin = getattr(hint, "__origin__", None) 

104 if origin is list or origin is set or origin is tuple: 104 ↛ 105line 104 didn't jump to line 105 because the condition on line 104 was never true

105 args = get_args(hint) 

106 item_type = args[0] if args else str 

107 result = [_convert_arg(v, item_type) for v in value] 

108 elif origin is dict: 108 ↛ 109line 108 didn't jump to line 109 because the condition on line 108 was never true

109 result = value 

110 elif getattr(hint, "__origin__", None) is Union: 110 ↛ 112line 110 didn't jump to line 112 because the condition on line 110 was never true

111 # 只处理 Optional[X],否则直接返回 

112 args = get_args(hint) 

113 non_none = [a for a in args if a is not type(None)] 

114 result = _convert_arg(value, non_none[0]) if len(non_none) == 1 else value 

115 elif (isinstance(hint, type) and BaseModel and issubclass(hint, BaseModel)) or dataclasses.is_dataclass(hint): 115 ↛ 116line 115 didn't jump to line 116 because the condition on line 115 was never true

116 result = hint(**value) if isinstance(value, dict) else value 

117 else: 

118 result = value 

119 return result 

120 

121 

class Funcall:
    """Dispatch OpenAI function tool calls to registered Python callables.

    ``get_tools`` builds one tool definition per callable via ``generate_meta``.
    ``handle_function_call`` decodes a tool call's JSON arguments, converts them
    toward the callable's annotated parameter types, and invokes it.
    """

    def __init__(self, functions: list | None = None) -> None:
        """Register ``functions`` (none when omitted).

        Callables are looked up by ``__name__``. If two share a name, the later
        one in ``functions`` wins. The list is stored as given, not copied.
        """
        if functions is None:
            functions = []
        self.functions = functions
        self.function_map = {func.__name__: func for func in functions}

    def get_tools(self) -> list[FunctionToolParam]:
        """Return a tool definition for every registered function, in registration order."""
        return [generate_meta(func) for func in self.functions]

    @staticmethod
    def _is_context_hint(hint: object) -> bool:
        """Return True for a bare ``Context`` or a parameterized ``Context[X]`` annotation."""
        return hint is Context or getattr(hint, "__origin__", None) is Context

    def handle_function_call(self, item: ResponseFunctionToolCall, context: object = None) -> object:
        """Invoke the registered function named by ``item`` and return its result.

        Args:
            item: The tool call. ``item.name`` selects the function and
                ``item.arguments`` holds its JSON-encoded arguments. An empty or
                whitespace-only string is treated as ``{}``.
            context: Object passed, unchanged, to every ``Context``-annotated
                parameter.

        Returns:
            Whatever the called function returns.

        Raises:
            ValueError: If no function named ``item.name`` is registered.
                Malformed JSON arguments raise ``json.JSONDecodeError``, which is
                a ``ValueError`` subclass.
        """
        func = self.function_map.get(item.name)
        if func is None:
            msg = f"Function {item.name} not found"
            raise ValueError(msg)

        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        raw_arguments = item.arguments
        kwargs = json.loads(raw_arguments) if raw_arguments and raw_arguments.strip() else {}

        # Accept a bare JSON array when exactly one parameter is model-supplied
        # (Context parameters are injected, so they do not count).
        if isinstance(kwargs, list):
            value_params = [
                name for name in sig.parameters if not self._is_context_hint(type_hints.get(name, str))
            ]
            if len(value_params) == 1:
                kwargs = {value_params[0]: kwargs}

        new_kwargs = {}
        for name in sig.parameters:
            hint = type_hints.get(name, str)
            if self._is_context_hint(hint):
                new_kwargs[name] = context
            elif name in kwargs:
                new_kwargs[name] = _convert_arg(kwargs[name], hint)
            elif dataclasses.is_dataclass(hint) or (
                # get_origin() guard: parameterized generics pass isinstance(..., type)
                # on Python 3.9/3.10 but make issubclass() raise TypeError.
                isinstance(hint, type) and get_origin(hint) is None and issubclass(hint, BaseModel)
            ):
                # Lifted schema (single model parameter): the model's fields arrive
                # as top-level arguments, so build the model from all of them.
                try:
                    new_kwargs[name] = hint(**kwargs)
                except Exception:
                    # Best effort, as before: leave the parameter unset so a default
                    # (if any) applies, but record the failure instead of hiding it.
                    logger.warning(
                        "Could not build %r for parameter '%s' of function '%s' from the tool call arguments",
                        hint,
                        name,
                        func.__name__,
                        exc_info=True,
                    )
        return func(**new_kwargs)

161 

162 

# Names exported by ``from funcall import *``.
__all__ = [
    "Context",
    "Funcall",
    "generate_meta",
]