Coverage for src\funcall\params_to_schema.py: 66%

64 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-18 18:13 +0900

1import dataclasses 

2from dataclasses import fields, is_dataclass 

3from typing import Any, get_args, get_origin 

4 

5from pydantic import BaseModel, create_model 

6 

7 

def _to_field_type(param: Any, model_cache: dict[type, Any]) -> tuple[Any, Any]:
    """Translate one parameter annotation into a ``(type, ...)`` pair for ``create_model``.

    Containers and unions are rebuilt with their member types translated
    recursively, so a dataclass appearing inside ``list[...]``, ``dict[...]``,
    ``tuple[...]`` or a union is also converted to a pydantic model.

    Args:
        param: A type or typing construct: a primitive type, a pydantic
            ``BaseModel`` subclass, a dataclass, ``list``/``dict``/``tuple``
            generics, ``Optional``/``Union`` or PEP 604 ``X | Y`` unions.
        model_cache: Per-call memo of dataclass -> generated pydantic model,
            so a dataclass referenced several times maps to one model.

    Returns:
        ``(annotation, ...)``; the ``...`` default marks the field as required.

    Raises:
        TypeError: If ``param`` is not a supported kind of annotation.
    """
    import types
    from typing import Union

    origin = get_origin(param)
    args = get_args(param)

    # typing.Union / Optional, plus PEP 604 ``X | Y`` (types.UnionType, 3.10+).
    union_origins = (Union, getattr(types, "UnionType", Union))
    if origin in union_origins:
        member_types = tuple(_to_field_type(arg, model_cache)[0] for arg in args)
        return (Union[member_types], ...)

    if origin is list:
        if not args:  # unparameterised typing.List
            return (list, ...)
        return (list[_to_field_type(args[0], model_cache)[0]], ...)

    if origin is dict:
        if not args:  # unparameterised typing.Dict
            return (dict, ...)
        key_type = _to_field_type(args[0], model_cache)[0]
        value_type = _to_field_type(args[1], model_cache)[0]
        return (dict[key_type, value_type], ...)

    # Tuples are emitted as lists: ``tuple[X, ...]`` becomes ``list[X]``, and a
    # fixed-length tuple becomes a list of the union of its item types
    # (positional structure is intentionally dropped).
    if origin is tuple:
        if len(args) == 2 and args[1] is Ellipsis:
            return (list[_to_field_type(args[0], model_cache)[0]], ...)
        if not args:
            return (list[Any], ...)
        item_types = tuple(_to_field_type(arg, model_cache)[0] for arg in args)
        if len(item_types) == 1:
            return (list[item_types[0]], ...)
        return (list[Union[item_types]], ...)

    if isinstance(param, type) and issubclass(param, BaseModel):
        return (param, ...)

    # Only dataclass *classes*; ``is_dataclass`` is also true for instances,
    # which fall through to the TypeError below.
    if isinstance(param, type) and is_dataclass(param):
        return (_dataclass_to_model(param, model_cache), ...)

    # Primitive / other plain classes are handed to pydantic unchanged.
    if isinstance(param, type):
        return (param, ...)

    err = f"Unsupported param type: {param}"
    raise TypeError(err)


def _dataclass_to_model(cls: type, model_cache: dict[type, Any]) -> Any:
    """Build (or reuse from ``model_cache``) a pydantic model mirroring dataclass ``cls``.

    Each dataclass field becomes a model field with the same name and type.
    A field's ``default`` is copied; a ``default_factory`` is called once and
    its result is used as the default. A truthy ``metadata["description"]``
    becomes the field's schema description.
    """
    from typing import get_type_hints

    from pydantic import Field

    cached = model_cache.get(cls)
    if cached is not None:
        return cached

    # Resolve string annotations (e.g. ``from __future__ import annotations``);
    # if a forward reference cannot be resolved, fall back to the raw ``f.type``.
    try:
        hints = get_type_hints(cls, include_extras=True)
    except (NameError, TypeError):
        hints = {}

    definitions: dict[str, Any] = {}
    for f in fields(cls):
        if f.default is not dataclasses.MISSING:
            default = f.default
        elif f.default_factory is not dataclasses.MISSING:
            default = f.default_factory()
        else:
            default = ...  # required
        # The description must be supplied when the model is created: pydantic
        # bakes field metadata into the model's schema at class-build time, so
        # assigning ``model_fields[...].description`` afterwards has no effect
        # on ``model_json_schema()``.
        description = f.metadata.get("description") or None
        definitions[f.name] = (
            hints.get(f.name, f.type),
            Field(default, description=description),
        )

    model = create_model(cls.__name__, **definitions)
    model_cache[cls] = model
    return model


def _forbid_additional_properties(node: Any) -> None:
    """Recursively add ``additionalProperties: false`` to every ``"type": "object"`` schema.

    Mutates ``node`` in place. Objects that already define
    ``additionalProperties`` (e.g. a value schema emitted for ``dict[str, int]``)
    are left untouched.
    """
    # NOTE(review): a dict-typed schema emitted as a bare ``{"type": "object"}``
    # (presumably bare ``dict`` / ``dict[str, Any]``) is also closed here, so it
    # would only accept ``{}`` -- confirm this is intended (e.g. strict mode).
    if isinstance(node, dict):
        if node.get("type") == "object":
            node.setdefault("additionalProperties", False)
        for value in node.values():
            _forbid_additional_properties(value)
    elif isinstance(node, list):
        for item in node:
            _forbid_additional_properties(item)


def params_to_schema(params: list[Any]) -> dict[str, Any]:
    """Build a JSON Schema describing a list of parameter types.

    Each entry may be a composite type, a dataclass, a pydantic model, or a
    primitive type, possibly nested (including inside lists, dicts, tuples and
    unions). The parameters become the required properties ``param_0``,
    ``param_1``, ... of a single object model named ``ParamsModel``. Every
    object schema in the result gets ``additionalProperties: false`` unless it
    already sets that key.

    Args:
        params: The parameter type annotations, in positional order.

    Returns:
        The JSON Schema dict produced by pydantic's ``model_json_schema``.

    Raises:
        TypeError: If an entry, or a type nested inside one, is unsupported.
    """
    # Shared across all params so a dataclass used more than once maps to a
    # single model (and hence a single ``$defs`` entry).
    model_cache: dict[type, Any] = {}
    # ``(type, ...)`` tuples declare required fields in every pydantic v2
    # release; older v2 versions treat a bare non-tuple value as a default,
    # not as the field's annotation.
    field_definitions = {
        f"param_{i}": _to_field_type(p, model_cache) for i, p in enumerate(params)
    }
    model = create_model("ParamsModel", **field_definitions)
    schema = model.model_json_schema()
    _forbid_additional_properties(schema)
    return schema