Coverage for src\funcall\__init__.py: 84%

217 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-19 18:21 +0900

import asyncio
import concurrent.futures
import dataclasses
import inspect
import json
import types
from collections.abc import Callable
from logging import getLogger
from typing import Generic, Literal, Required, TypedDict, TypeVar, Union, get_args, get_origin, get_type_hints

import litellm
from openai.types.responses import (
    FunctionToolParam,
    ResponseFunctionToolCall,
)
from pydantic import BaseModel

from funcall.params_to_schema import params_to_schema

18 

19logger = getLogger("funcall") 

20 

21T = TypeVar("T") 

22 

23 

24class Context(Generic[T]): 

25 """Generic context container for dependency injection in function calls.""" 

26 

27 def __init__(self, value: T | None = None) -> None: 

28 self.value = value 

29 

30 

31class LiteLLMFunctionSpec(TypedDict): 

32 """Type definition for LiteLLM function specification.""" 

33 

34 name: Required[str] 

35 parameters: Required[dict[str, object] | None] 

36 strict: Required[bool | None] 

37 type: Required[Literal["function"]] 

38 description: str | None 

39 

40 

41class LiteLLMFunctionToolParam(TypedDict): 

42 """Type definition for LiteLLM function tool parameter.""" 

43 

44 type: Literal["function"] 

45 function: Required[LiteLLMFunctionSpec] 

46 

47 

def generate_function_metadata(
    func: Callable,
    target: Literal["openai", "litellm"] = "openai",
) -> FunctionToolParam | LiteLLMFunctionToolParam:
    """
    Generate function metadata for OpenAI or LiteLLM function calling.

    Context-typed parameters are left out of the schema. When exactly one other
    parameter remains and it is a dataclass or pydantic model, the model's
    fields become the tool's top-level parameters; otherwise every remaining
    parameter becomes one property of the tool's parameter object.

    Args:
        func: The function to generate metadata for. Its docstring, dedented
            with ``inspect.cleandoc``, becomes the tool description.
        target: Target platform ("openai" or "litellm").

    Returns:
        Function metadata in the appropriate format.
    """
    signature = inspect.signature(func)
    type_hints = get_type_hints(func)
    # cleandoc removes the uniform indentation of continuation lines that
    # str.strip() alone would leave inside multi-line docstrings.
    description = inspect.cleandoc(func.__doc__) if func.__doc__ else ""

    # Extract non-context parameters
    param_names, param_types, context_count = _extract_parameters(signature, type_hints)

    if context_count > 1:
        logger.warning(
            "Multiple Context-type parameters detected in function '%s'. Only one context instance will be injected at runtime.",
            func.__name__,
        )

    schema = params_to_schema(param_types)

    # A single dataclass/BaseModel parameter gets its fields flattened.
    if len(param_names) == 1:
        metadata = _generate_single_param_metadata(
            func,
            param_types[0],
            schema,
            description,
            target,
        )
        if metadata is not None:
            return metadata

    # Zero, several, or a single non-model parameter.
    return _generate_multi_param_metadata(func, param_names, schema, description, target)

91 

92 

def _extract_parameters(signature: inspect.Signature, type_hints: dict) -> tuple[list[str], list[type], int]:
    """
    Separate schema-visible parameters from Context-typed (injected) ones.

    Args:
        signature: Signature of the tool function.
        type_hints: Resolved annotations of the function; a parameter with no
            annotation is treated as ``str``.

    Returns:
        ``(names, hints, context_count)``: the names and hints of every
        non-Context parameter in signature order, plus the number of
        Context-typed parameters that were skipped.
    """
    every_param = [(pname, type_hints.get(pname, str)) for pname in signature.parameters]
    kept = [(pname, annotation) for pname, annotation in every_param if not _is_context_type(annotation)]

    names = [pname for pname, _ in kept]
    hints = [annotation for _, annotation in kept]
    skipped = len(every_param) - len(kept)
    return names, hints, skipped

111 

112 

def _is_context_type(hint: type) -> bool:
    """
    Check if a type hint marks a parameter for context injection.

    Matches ``Context``, ``Context[X]``, and an optional wrapper around either
    (``Context[X] | None`` / ``Optional[Context[X]]``). The optional forms are
    included so that a context parameter declared with a ``None`` default is
    still hidden from the tool schema and injected at call time.
    """
    if hint is Context or getattr(hint, "__origin__", None) is Context:
        return True
    if get_origin(hint) in (Union, types.UnionType):
        non_none = [arg for arg in get_args(hint) if arg is not type(None)]
        return len(non_none) == 1 and _is_context_type(non_none[0])
    return False

116 

117 

118def _is_optional_type(hint: type) -> bool: 

119 """判断类型是否为 Optional/Union[..., None]""" 

120 origin = getattr(hint, "__origin__", None) 

121 if origin is Union: 121 ↛ 122line 121 didn't jump to line 122 because the condition on line 121 was never true

122 args = get_args(hint) 

123 return any(a is type(None) for a in args) 

124 return False 

125 

126 

def _generate_single_param_metadata(
    func: Callable,
    param_type: type,
    schema: dict,
    description: str,
    target: str,
) -> FunctionToolParam | LiteLLMFunctionToolParam | None:
    """
    Generate metadata for functions with a single dataclass/BaseModel parameter.

    The parameter's own object properties are lifted to the top level of the
    tool's parameter schema instead of being nested under the parameter name.

    Args:
        func: The tool function; its ``__name__`` becomes the tool name.
        param_type: Type hint of the function's only non-context parameter.
        schema: Output of ``params_to_schema`` for that parameter; the
            parameter's object schema is read from ``properties.param_0``.
        description: Tool description.
        target: ``"litellm"`` selects the LiteLLM format; any other value
            produces the OpenAI (strict) format.

    Returns:
        Tool metadata, or None when ``param_type`` is not a dataclass or
        BaseModel subclass (the caller then falls back to the generic format).
    """
    if not isinstance(param_type, type):
        return None
    is_dataclass_type = dataclasses.is_dataclass(param_type)
    is_model_type = not is_dataclass_type and issubclass(param_type, BaseModel)
    if not (is_dataclass_type or is_model_type):
        return None

    prop = schema["properties"]["param_0"]
    properties = prop["properties"]
    required = prop.get("required", [])
    additional_properties = prop.get("additionalProperties", False)

    base_params = {
        "type": "object",
        "properties": properties,
        "additionalProperties": additional_properties,
    }

    if target == "litellm":
        if is_model_type:
            model_fields = param_type.model_fields
        else:
            model_fields = {f.name: f for f in dataclasses.fields(param_type)}

        litellm_required = []
        for key in properties:
            field = model_fields.get(key)
            if field is None:
                # No field with this name: fall back to the schema's required list.
                is_optional = key not in required
            elif is_model_type:
                # FieldInfo.is_required is a method and must be called; comparing
                # the bound method itself to False never detected defaulted fields.
                is_optional = _is_optional_type(field.annotation) or not field.is_required()
            else:
                # A dataclass field can be defaulted through default OR default_factory.
                has_default = field.default is not dataclasses.MISSING or field.default_factory is not dataclasses.MISSING
                is_optional = _is_optional_type(field.type) or has_default
            if not is_optional:
                litellm_required.append(key)

        return {
            "type": "function",
            "function": {
                "name": func.__name__,
                "description": description,
                "parameters": {
                    **base_params,
                    "required": litellm_required,
                },
            },
        }

    # OpenAI strict mode requires every property to be listed as required.
    metadata: FunctionToolParam = {
        "type": "function",
        "name": func.__name__,
        "description": description,
        "parameters": {
            **base_params,
            "required": list(properties.keys()),
        },
        "strict": True,
    }
    return metadata

194 

195 

def _generate_multi_param_metadata(
    func: Callable,
    param_names: list[str],
    schema: dict,
    description: str,
    target: str,
) -> FunctionToolParam | LiteLLMFunctionToolParam:
    """
    Generate metadata with one top-level property per (non-context) parameter.

    Used for zero parameters, several parameters, or a single parameter that
    is not a dataclass/BaseModel. ``params_to_schema`` names its properties
    positionally (``param_0``, ``param_1``, ...); they are renamed here to the
    real parameter names.

    Args:
        func: The tool function; its ``__name__`` becomes the tool name.
        param_names: Non-context parameter names, in the same order as the
            types passed to ``params_to_schema``.
        schema: Output of ``params_to_schema``.
        description: Tool description.
        target: ``"litellm"`` selects the LiteLLM format; any other value
            produces the OpenAI (strict) format.

    Returns:
        Tool metadata in the requested format.
    """
    properties = {name: schema["properties"][f"param_{i}"] for i, name in enumerate(param_names)}

    base_params = {
        "type": "object",
        "properties": properties,
        "additionalProperties": False,
    }

    if target == "litellm":
        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        litellm_required = []
        for name in param_names:
            # `is not` (identity) avoids invoking __ne__ on arbitrary default
            # objects, which can be ambiguous or raise (e.g. array-like values).
            has_default = sig.parameters[name].default is not inspect.Parameter.empty
            if not (_is_optional_type(type_hints.get(name, str)) or has_default):
                litellm_required.append(name)
        return {
            "type": "function",
            "function": {
                "name": func.__name__,
                "description": description,
                "parameters": {
                    **base_params,
                    "required": litellm_required,
                },
            },
        }

    # OpenAI strict mode requires every property to be listed as required.
    metadata: FunctionToolParam = {
        "type": "function",
        "name": func.__name__,
        "description": description,
        "parameters": {
            **base_params,
            "required": list(param_names),
        },
        "strict": True,
    }

    return metadata

249 

250 

251def _convert_argument_type(value: object, hint: type) -> object: 

252 """ 

253 Convert argument values to match expected types. 

254 

255 Args: 

256 value: The value to convert 

257 hint: The type hint to convert to 

258 

259 Returns: 

260 Converted value 

261 """ 

262 origin = getattr(hint, "__origin__", None) 

263 result = value 

264 

265 # Handle collection types 

266 if origin in (list, set, tuple): 

267 args = get_args(hint) 

268 item_type = args[0] if args else str 

269 result = [_convert_argument_type(v, item_type) for v in value] 

270 elif origin is dict: 270 ↛ 271line 270 didn't jump to line 271 because the condition on line 270 was never true

271 result = value 

272 elif origin is Union: 

273 args = get_args(hint) 

274 non_none_types = [a for a in args if a is not type(None)] 

275 result = _convert_argument_type(value, non_none_types[0]) if len(non_none_types) == 1 else value 

276 elif isinstance(hint, type) and BaseModel and issubclass(hint, BaseModel): 

277 if isinstance(value, dict): 277 ↛ 282line 277 didn't jump to line 282 because the condition on line 277 was always true

278 fields = hint.model_fields 

279 converted_data = {k: _convert_argument_type(v, fields[k].annotation) if k in fields else v for k, v in value.items()} 

280 result = hint(**converted_data) 

281 else: 

282 result = value 

283 elif dataclasses.is_dataclass(hint): 

284 if isinstance(value, dict): 284 ↛ 289line 284 didn't jump to line 289 because the condition on line 284 was always true

285 field_types = {f.name: f.type for f in dataclasses.fields(hint)} 

286 converted_data = {k: _convert_argument_type(v, field_types.get(k, type(v))) for k, v in value.items()} 

287 result = hint(**converted_data) 

288 else: 

289 result = value 

290 

291 return result 

292 

293 

def _is_async_function(func: Callable) -> bool:
    """Return whether ``inspect`` classifies *func* as a coroutine function."""
    is_coroutine = inspect.iscoroutinefunction(func)
    return is_coroutine
296 return inspect.iscoroutinefunction(func) 

297 

298 

class Funcall:
    """
    Handler for function calling in LLM interactions.

    Registers plain Python callables, produces tool definitions for them
    (OpenAI or LiteLLM format), and dispatches tool calls returned by a model
    back to the matching registered function.
    """

    def __init__(self, functions: list[Callable] | None = None) -> None:
        """
        Initialize the function call handler.

        Args:
            functions: Functions to register, looked up by ``__name__``. The
                list is copied so that later changes to the caller's list
                cannot leave ``functions`` and ``function_registry`` out of
                sync. If two functions share a name, the later one wins in the
                registry.
        """
        self.functions = list(functions) if functions else []
        self.function_registry = {func.__name__: func for func in self.functions}

    def get_tools(
        self,
        target: Literal["openai", "litellm"] = "openai",
    ) -> list[FunctionToolParam | LiteLLMFunctionToolParam]:
        """
        Get tool definitions for the specified target platform.

        Args:
            target: Target platform ("openai" or "litellm").

        Returns:
            One tool definition per registered function, in registration order.
        """
        return [generate_function_metadata(func, target) for func in self.functions]

    def _prepare_function_execution(
        self,
        func_name: str,
        args: str,
        context: object = None,
    ) -> tuple[Callable, dict]:
        """
        Prepare function call arguments and context injection.

        Args:
            func_name: Name of the function to call.
            args: JSON string of function arguments. An empty or blank string
                (or None) is treated as ``{}``, since providers may send that
                for zero-argument calls.
            context: Object passed to every Context-typed parameter.

        Returns:
            Tuple of (function, prepared_kwargs).

        Raises:
            ValueError: If the function is not registered, if ``args`` is not
                valid JSON (``json.JSONDecodeError``), or if the decoded value
                is not a JSON object for a function with several parameters.
        """
        if func_name not in self.function_registry:
            msg = f"Function {func_name} not found"
            raise ValueError(msg)

        func = self.function_registry[func_name]
        signature = inspect.signature(func)
        type_hints = get_type_hints(func)
        arguments = json.loads(args) if args and args.strip() else {}

        # Find non-context parameters
        non_context_params = [name for name in signature.parameters if not _is_context_type(type_hints.get(name, str))]

        # A lone dataclass/BaseModel parameter is advertised with its fields
        # flattened to the top level, so the model sends those fields directly;
        # re-wrap them under the parameter name. Non-object payloads for a
        # single parameter are wrapped as well. A single plain parameter keeps
        # its {name: value} shape, so `{}` lets its default apply.
        if len(non_context_params) == 1:
            only_param = non_context_params[0]
            only_hint = type_hints.get(only_param, str)
            is_model = isinstance(only_hint, type) and (dataclasses.is_dataclass(only_hint) or issubclass(only_hint, BaseModel))
            if not isinstance(arguments, dict) or (is_model and set(arguments) != {only_param}):
                arguments = {only_param: arguments}

        if not isinstance(arguments, dict):
            if non_context_params:
                msg = f"Arguments for function {func_name} must be a JSON object"
                raise ValueError(msg)
            # Nothing to bind (e.g. "null" for a zero-argument function).
            arguments = {}

        # Prepare final kwargs with type conversion and context injection
        prepared_kwargs = {}
        for param_name in signature.parameters:
            hint = type_hints.get(param_name, str)

            if _is_context_type(hint):
                prepared_kwargs[param_name] = context
            elif param_name in arguments:
                prepared_kwargs[param_name] = _convert_argument_type(arguments[param_name], hint)
            # Parameters absent from `arguments` are omitted so their defaults apply.

        return func, prepared_kwargs

    def _execute_sync_in_async_context(self, func: Callable, kwargs: dict) -> object:
        """
        Run coroutine function *func* to completion from synchronous code.

        With no event loop running in this thread, the coroutine runs through
        ``asyncio.run``. When a loop is already running, ``asyncio.run`` cannot
        be used in this thread, so the coroutine runs in a fresh event loop on
        a worker thread and this call blocks until it finishes.

        Args:
            func: An ``async def`` function.
            kwargs: Keyword arguments for *func*.

        Returns:
            The coroutine's result.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread.
            return asyncio.run(func(**kwargs))

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            # Create and run the coroutine inside the worker thread's own loop.
            future = executor.submit(lambda: asyncio.run(func(**kwargs)))
            return future.result()

    def call_function(
        self,
        name: str,
        arguments: str,
        context: object = None,
    ) -> object:
        """
        Call a function by name with JSON arguments synchronously.

        Async functions are run to completion (with a warning) rather than
        returning an un-awaited coroutine.

        Args:
            name: Name of the function to call.
            arguments: JSON string of function arguments.
            context: Context object to inject (optional).

        Returns:
            Function execution result.

        Raises:
            ValueError: If function is not found.
            json.JSONDecodeError: If arguments are not valid JSON.
        """
        func, kwargs = self._prepare_function_execution(name, arguments, context)

        if _is_async_function(func):
            logger.warning(
                "Function %s is async but being called synchronously. Consider using call_function_async.",
                name,
            )
            return self._execute_sync_in_async_context(func, kwargs)

        return func(**kwargs)

    async def call_function_async(
        self,
        name: str,
        arguments: str,
        context: object = None,
    ) -> object:
        """
        Call a function by name with JSON arguments asynchronously.

        Async functions are awaited directly; sync functions run in the loop's
        default executor so they do not block the event loop.

        Args:
            name: Name of the function to call.
            arguments: JSON string of function arguments.
            context: Context object to inject (optional).

        Returns:
            Function execution result.

        Raises:
            ValueError: If function is not found.
            json.JSONDecodeError: If arguments are not valid JSON.
        """
        func, kwargs = self._prepare_function_execution(name, arguments, context)

        if _is_async_function(func):
            return await func(**kwargs)

        # Inside a coroutine a loop is always running; get_running_loop avoids
        # the deprecated get_event_loop lookup.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(**kwargs))

    def handle_openai_function_call(
        self,
        call: ResponseFunctionToolCall,
        context: object = None,
    ) -> object:
        """
        Handle OpenAI function call synchronously.

        Args:
            call: OpenAI function tool call.
            context: Context object to inject.

        Returns:
            Function execution result.

        Raises:
            TypeError: If ``call`` is not a ResponseFunctionToolCall.
        """
        if not isinstance(call, ResponseFunctionToolCall):
            msg = "call must be an instance of ResponseFunctionToolCall"
            raise TypeError(msg)

        return self.call_function(call.name, call.arguments, context)

    async def handle_openai_function_call_async(
        self,
        call: ResponseFunctionToolCall,
        context: object = None,
    ) -> object:
        """
        Handle OpenAI function call asynchronously.

        Args:
            call: OpenAI function tool call.
            context: Context object to inject.

        Returns:
            Function execution result.

        Raises:
            TypeError: If ``call`` is not a ResponseFunctionToolCall.
        """
        if not isinstance(call, ResponseFunctionToolCall):
            msg = "call must be an instance of ResponseFunctionToolCall"
            raise TypeError(msg)

        return await self.call_function_async(call.name, call.arguments, context)

    def handle_litellm_function_call(
        self,
        call: litellm.ChatCompletionMessageToolCall,
        context: object = None,
    ) -> object:
        """
        Handle LiteLLM function call synchronously.

        Args:
            call: LiteLLM function tool call.
            context: Context object to inject.

        Returns:
            Function execution result.

        Raises:
            TypeError: If ``call`` is not a litellm.ChatCompletionMessageToolCall.
        """
        if not isinstance(call, litellm.ChatCompletionMessageToolCall):
            msg = "call must be an instance of litellm.ChatCompletionMessageToolCall"
            raise TypeError(msg)

        return self.call_function(
            call.function.name,
            call.function.arguments,
            context,
        )

    async def handle_litellm_function_call_async(
        self,
        call: litellm.ChatCompletionMessageToolCall,
        context: object = None,
    ) -> object:
        """
        Handle LiteLLM function call asynchronously.

        Args:
            call: LiteLLM function tool call.
            context: Context object to inject.

        Returns:
            Function execution result.

        Raises:
            TypeError: If ``call`` is not a litellm.ChatCompletionMessageToolCall.
        """
        if not isinstance(call, litellm.ChatCompletionMessageToolCall):
            msg = "call must be an instance of litellm.ChatCompletionMessageToolCall"
            raise TypeError(msg)

        return await self.call_function_async(
            call.function.name,
            call.function.arguments,
            context,
        )

    def handle_function_call(
        self,
        call: ResponseFunctionToolCall | litellm.ChatCompletionMessageToolCall,
        context: object = None,
    ) -> object:
        """
        Handle function call synchronously (unified interface).

        Args:
            call: Function tool call (OpenAI or LiteLLM).
            context: Context object to inject.

        Returns:
            Function execution result.

        Raises:
            TypeError: If ``call`` is neither supported tool-call type.
        """
        if isinstance(call, ResponseFunctionToolCall):
            return self.handle_openai_function_call(call, context)
        if isinstance(call, litellm.ChatCompletionMessageToolCall):
            return self.handle_litellm_function_call(call, context)
        msg = "call must be an instance of ResponseFunctionToolCall or litellm.ChatCompletionMessageToolCall"
        raise TypeError(msg)

    async def handle_function_call_async(
        self,
        call: ResponseFunctionToolCall | litellm.ChatCompletionMessageToolCall,
        context: object = None,
    ) -> object:
        """
        Handle function call asynchronously (unified interface).

        Args:
            call: Function tool call (OpenAI or LiteLLM).
            context: Context object to inject.

        Returns:
            Function execution result.

        Raises:
            TypeError: If ``call`` is neither supported tool-call type.
        """
        if isinstance(call, ResponseFunctionToolCall):
            return await self.handle_openai_function_call_async(call, context)
        if isinstance(call, litellm.ChatCompletionMessageToolCall):
            return await self.handle_litellm_function_call_async(call, context)
        msg = "call must be an instance of ResponseFunctionToolCall or litellm.ChatCompletionMessageToolCall"
        raise TypeError(msg)

581 

# Backward-compatible alias kept for the old interface name.
generate_meta = generate_function_metadata


__all__ = [
    "Context",
    "Funcall",
    "generate_function_metadata",
]