Coverage for src\funcall\__init__.py: 84%
217 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-19 18:21 +0900
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-19 18:21 +0900
1import asyncio
2import concurrent.futures
3import dataclasses
4import inspect
5import json
6from collections.abc import Callable
7from logging import getLogger
8from typing import Generic, Literal, Required, TypedDict, TypeVar, Union, get_args, get_type_hints
10import litellm
11from openai.types.responses import (
12 FunctionToolParam,
13 ResponseFunctionToolCall,
14)
15from pydantic import BaseModel
17from funcall.params_to_schema import params_to_schema
# Package logger; the warnings emitted by this module go through it.
logger = getLogger("funcall")

# Type variable for the value wrapped by ``Context``.
T = TypeVar("T")
class Context(Generic[T]):
    """Wrapper holding an optional value of type ``T``.

    Parameters annotated with ``Context`` (or ``Context[...]``) are treated by
    this module as injection points and are left out of generated schemas.
    """

    def __init__(self, value: T | None = None) -> None:
        # ``value`` defaults to None so an empty context can be constructed.
        self.value = value
31class LiteLLMFunctionSpec(TypedDict):
32 """Type definition for LiteLLM function specification."""
34 name: Required[str]
35 parameters: Required[dict[str, object] | None]
36 strict: Required[bool | None]
37 type: Required[Literal["function"]]
38 description: str | None
class LiteLLMFunctionToolParam(TypedDict):
    """Typed layout of one LiteLLM tool entry: a ``type`` tag plus the nested function spec."""

    # Constant tag; the only accepted value is "function".
    type: Literal["function"]
    # Nested function specification.
    function: Required[LiteLLMFunctionSpec]
def generate_function_metadata(
    func: Callable,
    target: Literal["openai", "litellm"] = "openai",
) -> FunctionToolParam | LiteLLMFunctionToolParam:
    """
    Build the tool definition that describes ``func`` for OpenAI or LiteLLM.

    Args:
        func: Callable whose signature, type hints and docstring are inspected
        target: Output layout to produce ("openai" or "litellm")

    Returns:
        Tool metadata dictionary in the layout selected by ``target``
    """
    # Context-typed parameters are left out of the schema and only counted.
    names, hints, n_contexts = _extract_parameters(
        inspect.signature(func),
        get_type_hints(func),
    )

    doc = func.__doc__
    description = doc.strip() if doc else ""

    if n_contexts > 1:
        logger.warning(
            "Multiple Context-type parameters detected in function '%s'. Only one context instance will be injected at runtime.",
            func.__name__,
        )

    schema = params_to_schema(hints)

    # A lone dataclass/BaseModel parameter gets its fields lifted to the top
    # level; the helper returns None when the parameter is not such a type.
    if len(names) == 1:
        single = _generate_single_param_metadata(func, hints[0], schema, description, target)
        if single:
            return single

    # General case: one schema property per parameter.
    return _generate_multi_param_metadata(func, names, schema, description, target)
def _extract_parameters(signature: inspect.Signature, type_hints: dict) -> tuple[list[str], list[type], int]:
    """Split a signature into schema-visible parameters and a Context count.

    Returns the names and type hints of the non-Context parameters in
    declaration order (a parameter without a hint is treated as ``str``),
    followed by the number of Context-typed parameters that were skipped.
    """
    hinted = [(name, type_hints.get(name, str)) for name in signature.parameters]
    visible = [(name, hint) for name, hint in hinted if not _is_context_type(hint)]

    param_names = [name for name, _ in visible]
    param_types = [hint for _, hint in visible]
    context_count = len(hinted) - len(visible)
    return param_names, param_types, context_count
def _is_context_type(hint: type) -> bool:
    """Return True for ``Context`` itself or a parameterised ``Context[...]``."""
    if hint is Context:
        return True
    # Parameterised generics expose the unsubscripted class as ``__origin__``.
    return getattr(hint, "__origin__", None) is Context
118def _is_optional_type(hint: type) -> bool:
119 """判断类型是否为 Optional/Union[..., None]"""
120 origin = getattr(hint, "__origin__", None)
121 if origin is Union: 121 ↛ 122line 121 didn't jump to line 122 because the condition on line 121 was never true
122 args = get_args(hint)
123 return any(a is type(None) for a in args)
124 return False
def _generate_single_param_metadata(
    func: Callable,
    param_type: type,
    schema: dict,
    description: str,
    target: str,
) -> FunctionToolParam | LiteLLMFunctionToolParam | None:
    """Generate metadata for a function whose only parameter is a dataclass or BaseModel.

    The fields of that single parameter become the top-level properties of the
    tool's parameter schema.

    Args:
        func: Function being described (only its ``__name__`` is used)
        param_type: Type hint of the single non-context parameter
        schema: Schema whose ``["properties"]["param_0"]`` entry describes the parameter
        description: Description text placed in the metadata
        target: "litellm" selects the LiteLLM layout; any other value yields the OpenAI layout

    Returns:
        Tool metadata, or None when ``param_type`` is neither a dataclass nor a
        BaseModel subclass
    """
    if not isinstance(param_type, type):
        return None
    is_model = issubclass(param_type, BaseModel)
    if not (is_model or dataclasses.is_dataclass(param_type)):
        return None

    prop = schema["properties"]["param_0"]
    properties = prop["properties"]
    required = prop.get("required", [])
    additional_properties = prop.get("additionalProperties", False)

    base_params = {
        "type": "object",
        "properties": properties,
        "additionalProperties": additional_properties,
    }

    if target == "litellm":
        # Field metadata keyed by field name; BaseModel takes precedence.
        if is_model:
            fields = param_type.model_fields
        else:
            fields = {f.name: f for f in dataclasses.fields(param_type)}

        litellm_required = []
        for key in properties:
            field = fields.get(key)
            if field is None:
                # No field info for this property: fall back to the schema's own list.
                is_optional = key not in required
            elif is_model:
                # ``FieldInfo.is_required`` is a method in pydantic v2; it must be
                # called, otherwise defaulted fields are never seen as optional.
                is_optional = _is_optional_type(field.annotation) or not field.is_required()
            else:
                # A dataclass field is optional when it has a default value or factory.
                has_default = field.default is not dataclasses.MISSING or field.default_factory is not dataclasses.MISSING
                is_optional = _is_optional_type(field.type) or has_default
            if not is_optional:
                litellm_required.append(key)

        return {
            "type": "function",
            "function": {
                "name": func.__name__,
                "description": description,
                "parameters": {
                    **base_params,
                    "required": litellm_required,
                },
            },
        }

    # OpenAI format: strict mode, every property listed as required.
    metadata: FunctionToolParam = {
        "type": "function",
        "name": func.__name__,
        "description": description,
        "parameters": {
            **base_params,
            "required": list(properties.keys()),
        },
        "strict": True,
    }
    return metadata
def _generate_multi_param_metadata(
    func: Callable,
    param_names: list[str],
    schema: dict,
    description: str,
    target: str,
) -> FunctionToolParam | LiteLLMFunctionToolParam:
    """Generate metadata with one schema property per function parameter.

    Args:
        func: Function being described
        param_names: Names of the non-context parameters, in declaration order
        schema: Schema whose ``["properties"]["param_<i>"]`` entry describes the i-th name
        description: Description text placed in the metadata
        target: "litellm" selects the LiteLLM layout; any other value yields the OpenAI layout

    Returns:
        Tool metadata in the selected layout
    """
    # Re-key the positional ``param_<i>`` schema entries by real parameter name.
    properties = {name: schema["properties"][f"param_{i}"] for i, name in enumerate(param_names)}

    base_params = {
        "type": "object",
        "properties": properties,
        "additionalProperties": False,
    }

    if target == "litellm":
        parameters = inspect.signature(func).parameters
        type_hints = get_type_hints(func)

        litellm_required = []
        for name in param_names:
            hint = type_hints.get(name, str)
            # Identity check against the sentinel: ``!=`` would call the default
            # value's own ``__ne__``, which may not return a plain bool.
            has_default = parameters[name].default is not inspect.Parameter.empty
            if not (_is_optional_type(hint) or has_default):
                litellm_required.append(name)

        return {
            "type": "function",
            "function": {
                "name": func.__name__,
                "description": description,
                "parameters": {
                    **base_params,
                    "required": litellm_required,
                },
            },
        }

    # OpenAI format: strict mode, every parameter listed as required.
    metadata: FunctionToolParam = {
        "type": "function",
        "name": func.__name__,
        "description": description,
        "parameters": {
            **base_params,
            "required": list(param_names),
        },
        "strict": True,
    }

    return metadata
251def _convert_argument_type(value: object, hint: type) -> object:
252 """
253 Convert argument values to match expected types.
255 Args:
256 value: The value to convert
257 hint: The type hint to convert to
259 Returns:
260 Converted value
261 """
262 origin = getattr(hint, "__origin__", None)
263 result = value
265 # Handle collection types
266 if origin in (list, set, tuple):
267 args = get_args(hint)
268 item_type = args[0] if args else str
269 result = [_convert_argument_type(v, item_type) for v in value]
270 elif origin is dict: 270 ↛ 271line 270 didn't jump to line 271 because the condition on line 270 was never true
271 result = value
272 elif origin is Union:
273 args = get_args(hint)
274 non_none_types = [a for a in args if a is not type(None)]
275 result = _convert_argument_type(value, non_none_types[0]) if len(non_none_types) == 1 else value
276 elif isinstance(hint, type) and BaseModel and issubclass(hint, BaseModel):
277 if isinstance(value, dict): 277 ↛ 282line 277 didn't jump to line 282 because the condition on line 277 was always true
278 fields = hint.model_fields
279 converted_data = {k: _convert_argument_type(v, fields[k].annotation) if k in fields else v for k, v in value.items()}
280 result = hint(**converted_data)
281 else:
282 result = value
283 elif dataclasses.is_dataclass(hint):
284 if isinstance(value, dict): 284 ↛ 289line 284 didn't jump to line 289 because the condition on line 284 was always true
285 field_types = {f.name: f.type for f in dataclasses.fields(hint)}
286 converted_data = {k: _convert_argument_type(v, field_types.get(k, type(v))) for k, v in value.items()}
287 result = hint(**converted_data)
288 else:
289 result = value
291 return result
294def _is_async_function(func: Callable) -> bool:
295 """Check if a function is asynchronous."""
296 return inspect.iscoroutinefunction(func)
class Funcall:
    """Handler for function calling in LLM interactions.

    Holds a list of callables, exposes their tool definitions and dispatches
    incoming tool calls (OpenAI or LiteLLM) to the matching callable.
    """

    def __init__(self, functions: list[Callable] | None = None) -> None:
        """
        Initialize the function call handler.

        Args:
            functions: List of functions to register
        """
        self.functions = functions or []
        # Lookup table by ``__name__``; a later function with the same name
        # replaces an earlier one.
        self.function_registry = {func.__name__: func for func in self.functions}

    def get_tools(self, target: Literal["openai", "litellm"] = "openai") -> list[FunctionToolParam]:
        """
        Get tool definitions for the specified target platform.

        Args:
            target: Target platform ("openai" or "litellm")

        Returns:
            One tool definition per registered function, in registration order
        """
        return [generate_function_metadata(func, target) for func in self.functions]

    def _prepare_function_execution(
        self,
        func_name: str,
        args: str,
        context: object = None,
    ) -> tuple[Callable, dict]:
        """
        Prepare function call arguments and context injection.

        Args:
            func_name: Name of the function to call
            args: JSON string of function arguments
            context: Context object to inject

        Returns:
            Tuple of (function, prepared_kwargs)

        Raises:
            ValueError: If no function is registered under ``func_name``
            json.JSONDecodeError: If ``args`` is not valid JSON
        """
        if func_name not in self.function_registry:
            msg = f"Function {func_name} not found"
            raise ValueError(msg)

        func = self.function_registry[func_name]
        signature = inspect.signature(func)
        type_hints = get_type_hints(func)
        arguments = json.loads(args)

        # Parameters the model can supply (everything except Context-typed ones).
        non_context_params = [name for name in signature.parameters if not _is_context_type(type_hints.get(name, str))]

        # With exactly one such parameter, a payload that is not a dict keyed by
        # exactly that name is treated as the value of that parameter itself.
        if len(non_context_params) == 1 and (not isinstance(arguments, dict) or set(arguments.keys()) != set(non_context_params)):
            arguments = {non_context_params[0]: arguments}

        # Build the final kwargs: inject the context, convert supplied values.
        # Parameters missing from the payload are left out so defaults apply.
        prepared_kwargs = {}
        for param_name in signature.parameters:
            hint = type_hints.get(param_name, str)

            if _is_context_type(hint):
                prepared_kwargs[param_name] = context
            elif param_name in arguments:
                prepared_kwargs[param_name] = _convert_argument_type(arguments[param_name], hint)

        return func, prepared_kwargs

    def _execute_sync_in_async_context(self, func: Callable, kwargs: dict) -> object:
        """
        Run the coroutine function ``func`` to completion from synchronous code.

        When no event loop is running in the current thread the coroutine is
        driven with ``asyncio.run``. When a loop is already running here, it
        cannot be re-entered, so the coroutine is run on a fresh loop in a
        worker thread while this thread blocks for the result.

        Args:
            func: Coroutine function to execute
            kwargs: Keyword arguments passed to ``func``

        Returns:
            The value the coroutine returned
        """
        # Only the loop detection is guarded, so a RuntimeError raised by
        # ``func`` itself propagates instead of triggering a second execution.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(func(**kwargs))

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            # ``asyncio.run`` actually drives the coroutine; submitting ``func``
            # directly would only create a coroutine object and return it.
            future = executor.submit(asyncio.run, func(**kwargs))
            return future.result()

    def call_function(
        self,
        name: str,
        arguments: str,
        context: object = None,
    ) -> object:
        """
        Call a function by name with JSON arguments synchronously.

        Args:
            name: Name of the function to call
            arguments: JSON string of function arguments
            context: Context object to inject (optional)

        Returns:
            Function execution result

        Raises:
            ValueError: If function is not found
            json.JSONDecodeError: If arguments are not valid JSON
        """
        func, kwargs = self._prepare_function_execution(name, arguments, context)

        if _is_async_function(func):
            logger.warning(
                "Function %s is async but being called synchronously. Consider using call_function_async.",
                name,
            )
            return self._execute_sync_in_async_context(func, kwargs)

        return func(**kwargs)

    async def call_function_async(
        self,
        name: str,
        arguments: str,
        context: object = None,
    ) -> object:
        """
        Call a function by name with JSON arguments asynchronously.

        Args:
            name: Name of the function to call
            arguments: JSON string of function arguments
            context: Context object to inject (optional)

        Returns:
            Function execution result

        Raises:
            ValueError: If function is not found
            json.JSONDecodeError: If arguments are not valid JSON
        """
        func, kwargs = self._prepare_function_execution(name, arguments, context)

        if _is_async_function(func):
            return await func(**kwargs)

        # Run the sync function in a worker thread so the event loop is not blocked.
        return await asyncio.to_thread(func, **kwargs)

    def handle_openai_function_call(
        self,
        call: ResponseFunctionToolCall,
        context: object = None,
    ) -> object:
        """
        Handle OpenAI function call synchronously.

        Args:
            call: OpenAI function tool call
            context: Context object to inject

        Returns:
            Function execution result

        Raises:
            TypeError: If ``call`` is not a ResponseFunctionToolCall
        """
        if not isinstance(call, ResponseFunctionToolCall):
            msg = "call must be an instance of ResponseFunctionToolCall"
            raise TypeError(msg)

        return self.call_function(call.name, call.arguments, context)

    async def handle_openai_function_call_async(
        self,
        call: ResponseFunctionToolCall,
        context: object = None,
    ) -> object:
        """
        Handle OpenAI function call asynchronously.

        Args:
            call: OpenAI function tool call
            context: Context object to inject

        Returns:
            Function execution result

        Raises:
            TypeError: If ``call`` is not a ResponseFunctionToolCall
        """
        if not isinstance(call, ResponseFunctionToolCall):
            msg = "call must be an instance of ResponseFunctionToolCall"
            raise TypeError(msg)

        return await self.call_function_async(call.name, call.arguments, context)

    def handle_litellm_function_call(
        self,
        call: litellm.ChatCompletionMessageToolCall,
        context: object = None,
    ) -> object:
        """
        Handle LiteLLM function call synchronously.

        Args:
            call: LiteLLM function tool call
            context: Context object to inject

        Returns:
            Function execution result

        Raises:
            TypeError: If ``call`` is not a litellm.ChatCompletionMessageToolCall
        """
        if not isinstance(call, litellm.ChatCompletionMessageToolCall):
            msg = "call must be an instance of litellm.ChatCompletionMessageToolCall"
            raise TypeError(msg)

        return self.call_function(
            call.function.name,
            call.function.arguments,
            context,
        )

    async def handle_litellm_function_call_async(
        self,
        call: litellm.ChatCompletionMessageToolCall,
        context: object = None,
    ) -> object:
        """
        Handle LiteLLM function call asynchronously.

        Args:
            call: LiteLLM function tool call
            context: Context object to inject

        Returns:
            Function execution result

        Raises:
            TypeError: If ``call`` is not a litellm.ChatCompletionMessageToolCall
        """
        if not isinstance(call, litellm.ChatCompletionMessageToolCall):
            msg = "call must be an instance of litellm.ChatCompletionMessageToolCall"
            raise TypeError(msg)

        return await self.call_function_async(
            call.function.name,
            call.function.arguments,
            context,
        )

    def handle_function_call(
        self,
        call: ResponseFunctionToolCall | litellm.ChatCompletionMessageToolCall,
        context: object = None,
    ) -> object:
        """
        Handle function call synchronously (unified interface).

        Args:
            call: Function tool call (OpenAI or LiteLLM)
            context: Context object to inject

        Returns:
            Function execution result

        Raises:
            TypeError: If ``call`` is neither supported call type
        """
        if isinstance(call, ResponseFunctionToolCall):
            return self.handle_openai_function_call(call, context)
        if isinstance(call, litellm.ChatCompletionMessageToolCall):
            return self.handle_litellm_function_call(call, context)
        msg = "call must be an instance of ResponseFunctionToolCall or litellm.ChatCompletionMessageToolCall"
        raise TypeError(msg)

    async def handle_function_call_async(
        self,
        call: ResponseFunctionToolCall | litellm.ChatCompletionMessageToolCall,
        context: object = None,
    ) -> object:
        """
        Handle function call asynchronously (unified interface).

        Args:
            call: Function tool call (OpenAI or LiteLLM)
            context: Context object to inject

        Returns:
            Function execution result

        Raises:
            TypeError: If ``call`` is neither supported call type
        """
        if isinstance(call, ResponseFunctionToolCall):
            return await self.handle_openai_function_call_async(call, context)
        if isinstance(call, litellm.ChatCompletionMessageToolCall):
            return await self.handle_litellm_function_call_async(call, context)
        msg = "call must be an instance of ResponseFunctionToolCall or litellm.ChatCompletionMessageToolCall"
        raise TypeError(msg)
# Backward-compatible alias for the old interface.
generate_meta = generate_function_metadata


__all__ = ["Context", "Funcall", "generate_function_metadata"]