Coverage for src\funcall\params_to_schema.py: 77%

117 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-18 23:51 +0900

1import dataclasses 

2import types 

3from dataclasses import fields, is_dataclass 

4from typing import Any, get_args, get_origin 

5from typing import Any as TypingAny 

6from typing import Union as TypingUnion 

7 

8from pydantic import BaseModel, create_model 

9 

10 

def _create_union_type(union_types: tuple) -> type:
    """Return ``Union[...]`` built from the member types in *union_types*.

    Subscripting ``Union`` with the tuple is the normal path; an explicit
    ``__getitem__`` call is tried as a fallback if the subscript raises
    ``TypeError``. A one-element tuple collapses to that element, as
    ``Union`` itself does.
    """
    members = union_types
    try:
        result = TypingUnion[members]  # noqa: UP007
    except TypeError:
        result = TypingUnion.__getitem__(members)
    return result

17 

18 

def _handle_tuple_type(args: tuple) -> type:
    """Map the type arguments of a ``tuple[...]`` annotation to a ``list[...]`` type.

    - no arguments                  -> ``list[Any]``
    - ``tuple[T, ...]``             -> ``list[T]``
    - ``tuple[T1, T2, ...]``        -> ``list[Union[T1, T2, ...]]``
      (a single member collapses to ``list[T1]``)

    Every member type is first passed through ``to_field_type``.
    """
    if not args:
        return list[TypingAny]

    # A homogeneous, variable-length tuple is written tuple[T, ...].
    is_homogeneous = len(args) == 2 and args[1] is Ellipsis
    if is_homogeneous:
        return list[to_field_type(args[0])]

    converted = tuple(map(to_field_type, args))
    element = converted[0] if len(converted) == 1 else _create_union_type(converted)
    return list[element]

36 

37 

def _dataclass_to_pydantic_model(dataclass_type: type) -> type:
    """Convert a dataclass into a dynamically created Pydantic model.

    Each dataclass field becomes a model field with the same name and
    annotation (``field.type`` is passed through unchanged):

    - a plain ``default`` becomes the model field's default;
    - a ``default_factory`` is forwarded as the model field's
      ``default_factory`` (it is not called here);
    - otherwise the model field is required.

    ``metadata["description"]`` is attached to the model field when the model
    is created, and ``_add_field_descriptions`` is then applied as well.

    Args:
        dataclass_type: The dataclass class to convert.

    Returns:
        A new Pydantic model class named after the dataclass.
    """
    # Local import: the module-level pydantic import only brings in
    # BaseModel and create_model.
    from pydantic import Field

    model_fields = {}
    for dc_field in fields(dataclass_type):
        field_kwargs = {}
        if dc_field.default is not dataclasses.MISSING:
            field_kwargs["default"] = dc_field.default
        elif dc_field.default_factory is not dataclasses.MISSING:
            # Forward the factory itself; passing it as `default` would make
            # the callable (e.g. the `list` class) the field's value.
            field_kwargs["default_factory"] = dc_field.default_factory
        # With neither default nor default_factory, Field() leaves the field required.

        description = dc_field.metadata.get("description")
        if description is not None:
            # Set at creation time so the description is part of the model's
            # schema rather than patched onto an already-built model.
            field_kwargs["description"] = description

        model_fields[dc_field.name] = (dc_field.type, Field(**field_kwargs))

    model = create_model(dataclass_type.__name__, **model_fields)

    # Also copy descriptions onto model.model_fields.
    _add_field_descriptions(model, dataclass_type)

    return model

60 

61 

def _add_field_descriptions(model: type, dataclass_type: type) -> None:
    """Copy dataclass field descriptions onto the model's field entries.

    For every field of *dataclass_type* whose ``metadata`` contains a
    ``"description"`` key, the matching entry in ``model.model_fields`` (if
    the model has that attribute and that entry) gets its ``description``
    attribute set. Everything else is left untouched.

    NOTE(review): in Pydantic v2, mutating a model's FieldInfo after the class
    is built may not be reflected in ``model_json_schema()`` — confirm.
    """
    for dc_field in fields(dataclass_type):
        has_description = hasattr(dc_field, "metadata") and "description" in dc_field.metadata
        if not has_description:
            continue
        text = dc_field.metadata["description"]
        if not hasattr(model, "model_fields"):
            continue
        registry = model.model_fields
        if dc_field.name in registry:
            registry[dc_field.name].description = text

69 

70 

def to_field_type(param: type) -> type:  # noqa: C901, PLR0911
    """
    Convert a parameter annotation into a type Pydantic can build a field from.

    Supported inputs:
    - ``None`` -> ``NoneType``
    - Pydantic ``BaseModel`` subclasses (returned unchanged)
    - dataclasses (converted by ``_dataclass_to_pydantic_model``)
    - ``Union[...]`` / ``X | Y`` (members converted recursively)
    - ``list[T]`` (item type converted; an unparameterized ``typing.List``
      becomes ``list[Any]``)
    - ``tuple[...]`` (converted to a list type by ``_handle_tuple_type``)
    - ``Literal[...]`` (returned unchanged)
    - ``Annotated[T, ...]`` (``T`` converted, metadata kept)
    - any other plain class except ``dict`` (returned unchanged)

    Raises:
        TypeError: For ``dict`` / ``dict[K, V]`` (use a BaseModel or dataclass
            instead) and for any other unsupported annotation.
    """
    # Local import: these special forms are only needed for the generic branch.
    from typing import Annotated, Literal

    if param is None:
        return type(None)

    origin = get_origin(param)
    args = get_args(param)

    # Pydantic models are already valid field types.
    if isinstance(param, type) and issubclass(param, BaseModel):
        return param

    # Dataclasses are rebuilt as Pydantic models.
    if is_dataclass(param):
        return _dataclass_to_pydantic_model(param)

    # Handle generic types
    if origin is not None:
        # Union/Optional, including the 3.10+ `X | Y` form (types.UnionType).
        if origin is TypingUnion or (hasattr(types, "UnionType") and origin is types.UnionType):
            union_types = tuple(to_field_type(a) for a in args)
            return _create_union_type(union_types)

        if origin is list:
            item_type = to_field_type(args[0]) if args else TypingAny
            return list[item_type]

        if origin is dict:
            msg = f"Dict type {param} is not supported directly, use pydantic BaseModel or dataclass instead."
            raise TypeError(msg)

        if origin is tuple:
            return _handle_tuple_type(args)

        # Literal values are handed to Pydantic unchanged.
        if origin is Literal:
            return param

        # Convert the wrapped type, keep the metadata (e.g. pydantic.Field(...)).
        if origin is Annotated:
            inner, *metadata = args
            return Annotated[(to_field_type(inner), *metadata)]

    # Basic type handling
    if isinstance(param, type):
        if param is dict:
            msg = "Dict type is not supported directly, use pydantic BaseModel or dataclass instead."
            raise TypeError(msg)
        return param

    msg = f"Unsupported param type: {param} (type: {type(param)})"
    raise TypeError(msg)

120 

121 

def params_to_schema(params: list[Any]) -> dict[str, Any]:
    """
    Build a JSON schema describing a list of parameter types.

    Each entry of ``params`` may be a basic type, a dataclass, a Pydantic
    model, or a supported generic (``list[...]``, ``tuple[...]``, unions, ...),
    possibly nested. Entry ``i`` becomes a required property ``param_{i}`` of
    a generated ``ParamsModel``.

    The schema is post-processed for OpenAI function calling. Object schemas
    get ``additionalProperties: false`` and list all their properties as
    required. Local ``$defs`` references are inlined.

    Args:
        params: List of parameter types.

    Returns:
        The JSON schema as a dict.

    Raises:
        TypeError: If ``params`` is not a list, or an entry cannot be
            converted by ``to_field_type``.
    """
    if not isinstance(params, list):
        msg = "params must be a list"
        raise TypeError(msg)

    # One required field per parameter, named by its position.
    # An empty list yields a model with no fields.
    model_fields = {f"param_{i}": (to_field_type(p), ...) for i, p in enumerate(params)}
    model = create_model("ParamsModel", **model_fields)

    # Serialization-mode schema. The mode does not by itself avoid $refs;
    # those are handled by the inlining step below.
    schema = model.model_json_schema(mode="serialization")

    _normalize_schema(schema)

    # Nested models are emitted under "$defs" and referenced via "$ref";
    # inline them so the schema is self-contained.
    if "$defs" in schema:
        schema = _inline_definitions(schema)

    return schema

157 

158 

def _inline_definitions(schema: dict) -> dict:
    """
    Inline every local ``#/$defs/<name>`` reference and drop the ``$defs`` section.

    Keys next to a ``$ref`` (e.g. a field-level ``description`` or
    ``default``) are kept. They take precedence over the same keys in the
    referenced definition.

    Some references are left as-is:
    - non-local refs;
    - refs to unknown names;
    - self-referencing (cyclic) refs.

    If a cyclic ref remains, the original ``$defs`` section is kept so the
    schema stays resolvable.

    The input is not modified. A schema without ``$defs`` is returned
    unchanged.

    Args:
        schema: A JSON schema dict.

    Returns:
        The schema with local references inlined.
    """
    if "$defs" not in schema:
        return schema

    definitions = schema["$defs"]
    prefix = "#/$defs/"
    kept_cyclic_ref = False

    def replace_refs(obj: Any, expanding: frozenset) -> Any:  # noqa: ANN401
        # `expanding` holds the definition names currently being inlined on
        # this path; meeting one again means the definition refers to itself.
        nonlocal kept_cyclic_ref
        if isinstance(obj, dict):
            if "$ref" in obj:
                ref_path = obj["$ref"]
                is_local = isinstance(ref_path, str) and ref_path.startswith(prefix)
                def_name = ref_path[len(prefix):] if is_local else None
                if def_name is None or def_name not in definitions:
                    return obj
                if def_name in expanding:
                    kept_cyclic_ref = True
                    return obj
                resolved = replace_refs(definitions[def_name], expanding | {def_name})
                siblings = {k: replace_refs(v, expanding) for k, v in obj.items() if k != "$ref"}
                return {**resolved, **siblings}
            # Recursively process all dict values
            return {k: replace_refs(v, expanding) for k, v in obj.items()}
        if isinstance(obj, list):
            # Recursively process all list items
            return [replace_refs(item, expanding) for item in obj]
        return obj

    body = {k: v for k, v in schema.items() if k != "$defs"}
    inlined = replace_refs(body, frozenset())
    if kept_cyclic_ref:
        # Some $ref could not be inlined; keep the definitions it points to.
        inlined["$defs"] = definitions
    return inlined

190 

191 

192def _normalize_schema(schema: dict | list) -> None: 

193 """ 

194 Normalize schema, add additionalProperties: false and fix required fields 

195 

196 Args: 

197 schema: The schema to normalize 

198 """ 

199 if isinstance(schema, dict): 

200 if schema.get("type") == "object": 

201 schema.setdefault("additionalProperties", False) 

202 # OpenAI Function Calling: required must contain all properties 

203 if "properties" in schema: 203 ↛ 207line 203 didn't jump to line 207 because the condition on line 203 was always true

204 schema["required"] = list(schema["properties"].keys()) 

205 

206 # Recursively handle nested objects 

207 for value in schema.values(): 

208 if isinstance(value, (dict, list)): 

209 _normalize_schema(value) 

210 

211 elif isinstance(schema, list): 211 ↛ exitline 211 didn't return from function '_normalize_schema' because the condition on line 211 was always true

212 for item in schema: 

213 if isinstance(item, (dict, list)): 

214 _normalize_schema(item)