Coverage for polypandas/schema.py: 68%

181 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-02-24 14:21 -0500

1"""Schema inference and dtype conversion for pandas.""" 

2 

3from dataclasses import fields as dataclass_fields 

4from dataclasses import is_dataclass 

5from datetime import date, datetime 

6from decimal import Decimal 

7from typing import Any, Dict, List, Literal, Optional, Type, Union, get_args, get_origin 

8 

9from typing_extensions import get_type_hints 

10 

11from polypandas.exceptions import SchemaInferenceError, UnsupportedTypeError 

12 

13 

def is_optional(type_hint: Type) -> bool:
    """Return True when *type_hint* is Optional, i.e. a ``Union`` that includes ``None``."""
    if get_origin(type_hint) is not Union:
        return False
    return type(None) in get_args(type_hint)

21 

22 

def unwrap_optional(type_hint: Type) -> Type:
    """Strip the ``None`` member from an Optional hint and return the inner type.

    Non-Optional hints are returned unchanged.  For ``Union[X, Y, None]`` the
    first non-None member is returned.
    """
    none_type = type(None)
    if get_origin(type_hint) is Union and none_type in get_args(type_hint):
        for candidate in get_args(type_hint):
            if candidate is not none_type:
                return candidate
    return type_hint

31 

32 

def infer_literal_type(literal_type: Type) -> Type:
    """Infer the common base Python type of a ``Literal[...]`` type's values.

    Non-Literal types are returned unchanged.

    Args:
        literal_type: A ``Literal[...]`` type hint (or any other type).

    Returns:
        The shared Python type of the literal values; mixed int/float
        literals widen to ``float``.

    Raises:
        SchemaInferenceError: If the Literal is empty or mixes incompatible
            value types.
    """
    if get_origin(literal_type) is not Literal:
        return literal_type

    args = get_args(literal_type)
    if not args:
        raise SchemaInferenceError(f"Empty Literal type: {literal_type}")

    value_types = [type(arg) for arg in args]
    # Homogeneous values (covers all-str, all-bool, all-int, ...): done.
    if len(set(value_types)) == 1:
        return value_types[0]

    # Mixed numeric literals (e.g. Literal[1, 2.0]) widen to float.
    # bool is excluded here: `bool in (int, float)` compares type objects
    # and is False, so Literal[True, 1] still falls through to the error.
    if all(t in (int, float) for t in value_types):
        return float if float in value_types else int

    # Note: the previous all-str / all-bool fallbacks were removed — they
    # were unreachable, since any homogeneous Literal already returned above.
    raise SchemaInferenceError(
        f"Cannot infer unified type from Literal with mixed types: {literal_type}"
    )

57 

58 

def python_type_to_pandas_dtype(python_type: Type) -> Any:
    """Map a Python type hint onto a pandas dtype.

    The result is a dtype string such as ``'int64'``, ``'float64'``,
    ``'object'``, ``'bool'`` or ``'datetime64[ns]'``, suitable for a pandas
    DataFrame.

    Args:
        python_type: The Python type hint to convert.

    Returns:
        A string or dtype suitable for pandas DataFrame.

    Raises:
        UnsupportedTypeError: If no pandas dtype corresponds to the type.
        SchemaInferenceError: If a list generic carries no element type.
    """
    # Nullability does not change the dtype chosen here, so drop Optional.
    if is_optional(python_type):
        python_type = unwrap_optional(python_type)

    # Collapse Literal[...] to the common base type of its values.
    if get_origin(python_type) is Literal:
        python_type = infer_literal_type(python_type)

    origin = get_origin(python_type)
    args = get_args(python_type)

    scalar_map = {
        str: "object",
        int: "int64",
        float: "float64",
        bool: "bool",
        bytes: "object",
        bytearray: "object",
        date: "datetime64[ns]",
        datetime: "datetime64[ns]",
        Decimal: "object",
    }
    if python_type in scalar_map:
        return scalar_map[python_type]

    if origin in (list, List):
        if not args:
            raise SchemaInferenceError(f"Cannot infer array element type from {python_type}")
        # Sequence columns are stored as Python objects.
        return "object"

    if origin in (dict, Dict):
        # Mapping columns are stored as Python objects.
        return "object"

    # Nested models — dataclasses, Pydantic models, or any annotated class —
    # all become object columns.
    if is_dataclass(python_type) or hasattr(python_type, "model_fields"):
        return "object"
    if hasattr(python_type, "__annotations__"):
        return "object"

    raise UnsupportedTypeError(f"Cannot convert type {python_type} to pandas dtype")

117 

118 

119def _get_model_field_types(model: Type) -> Dict[str, Type]: 

120 """Get field name -> type mapping from a model.""" 

121 if is_dataclass(model): 

122 type_hints = get_type_hints(model) 

123 return {f.name: type_hints.get(f.name, f.type) for f in dataclass_fields(model)} 

124 if hasattr(model, "model_fields"): 

125 return {name: info.annotation for name, info in model.model_fields.items()} 

126 if hasattr(model, "__annotations__"): 

127 return dict(model.__annotations__) 

128 raise SchemaInferenceError(f"Cannot infer schema from {model}") 

129 

130 

def infer_schema(
    model: Type,
    schema: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Produce a pandas dtype dict (column name -> dtype) for *model*.

    Args:
        model: The model type (dataclass, Pydantic model, TypedDict).
        schema: An explicit dtype dict; when given, it is returned untouched.

    Returns:
        A dict mapping column names to pandas dtypes (strings or dtypes).

    Raises:
        SchemaInferenceError: If the schema cannot be inferred from *model*.
    """
    # An explicit schema short-circuits inference entirely.
    # (isinstance already rejects None, so no separate None check is needed.)
    if isinstance(schema, dict):
        return schema

    try:
        field_types = _get_model_field_types(model)
    except Exception as exc:
        raise SchemaInferenceError(f"Cannot infer schema from {model}: {exc}") from exc

    dtypes: Dict[str, Any] = {}
    for field_name, field_type in field_types.items():
        try:
            dtypes[field_name] = python_type_to_pandas_dtype(field_type)
        except UnsupportedTypeError as exc:
            raise SchemaInferenceError(f"Cannot infer type for field {field_name}: {exc}") from exc

    return dtypes

163 

164 

def infer_dtypes_for_dataframe(model: Type) -> Optional[Dict[str, Any]]:
    """Infer a dtypes dict suitable for ``pd.DataFrame(..., dtype=...)``.

    Thin wrapper around :func:`infer_schema`.  The ``Optional`` return type
    is kept for interface stability, but this implementation always returns
    the inferred dict — it never returns None.  (The previous docstring
    incorrectly claimed None was returned when pandas is unavailable.)

    Args:
        model: The model type (dataclass, Pydantic model, TypedDict).

    Returns:
        A dict mapping column names to pandas dtypes.

    Raises:
        SchemaInferenceError: If the schema cannot be inferred (propagated
            from :func:`infer_schema`).
    """
    return infer_schema(model)

171 

172 

173def _is_struct_like(type_hint: Type) -> bool: 

174 """True if the type is a dataclass, Pydantic model, or TypedDict (nested struct).""" 

175 if is_dataclass(type_hint): 

176 return True 

177 if get_origin(type_hint) in (list, List): 

178 args = get_args(type_hint) 

179 if args and (is_dataclass(args[0]) or hasattr(args[0], "model_fields")): 

180 return True 

181 if hasattr(type_hint, "model_fields"): 

182 return True 

183 if hasattr(type_hint, "__annotations__") and get_origin(type_hint) not in ( 

184 list, 

185 List, 

186 dict, 

187 Dict, 

188 ): 

189 return True 

190 return False 

191 

192 

def has_nested_structs(model: Type) -> bool:
    """Return True if any field of *model* is a nested struct (dataclass,
    Pydantic model, annotated class) or a list of such structs."""
    try:
        hints = _get_model_field_types(model)
    except Exception:
        # Not a supported model shape -> trivially no nested structs.
        return False

    def field_has_struct(hint: Type) -> bool:
        # Optional wrapping is irrelevant to struct-ness; inspect the inner type.
        inner = unwrap_optional(hint)
        if _is_struct_like(inner):
            return True
        if get_origin(inner) in (list, List):
            elements = get_args(inner)
            return bool(elements) and _is_struct_like(elements[0])
        return False

    return any(field_has_struct(hint) for hint in hints.values())

208 

209 

def infer_pyarrow_schema(model: Type) -> Optional[Any]:
    """Infer a PyArrow Schema from the model type.

    Returns None if PyArrow is not installed or schema cannot be inferred.
    Requires the optional dependency: pip install polypandas[pyarrow]
    """
    # PyArrow is optional: degrade to None instead of raising when absent.
    try:
        import pyarrow as pa
    except ImportError:
        return None

    def python_type_to_pa(python_type: Type, nullable: bool = True) -> Any:
        # Map one Python type hint to a PyArrow DataType, recursing into
        # containers and nested models.  Returns None for unsupported types.
        # NOTE(review): the `nullable` parameter is never read in this body;
        # nullability is attached via pa.field(...) at each call site instead.
        if is_optional(python_type):
            python_type = unwrap_optional(python_type)
        origin = get_origin(python_type)
        if origin is Literal:
            # Collapse Literal[...] to its base value type; clearing `origin`
            # also forces `args` to () below, skipping the container branches.
            python_type = infer_literal_type(python_type)
            origin = None
        args = get_args(python_type) if origin else ()

        # Scalar primitives.
        if python_type in (str, int, float, bool):
            type_map = {str: pa.string(), int: pa.int64(), float: pa.float64(), bool: pa.bool_()}
            return type_map[python_type]
        if python_type is bytes or python_type is bytearray:
            return pa.binary()
        if python_type is date:
            return pa.date32()
        if python_type is datetime:
            # Microsecond precision timestamps.
            return pa.timestamp("us")
        if python_type is Decimal:
            # Fixed default precision/scale for Decimal columns.
            return pa.decimal128(38, 9)

        # Containers: List[T] and Dict[K, V].
        # NOTE(review): if the element/value type maps to None (unsupported),
        # that None is passed straight into pa.list_/pa.map_ — presumably
        # raising inside pyarrow rather than returning None; confirm.
        if origin in (list, List) and args:
            inner = python_type_to_pa(args[0], nullable=True)
            return pa.list_(inner)
        if origin in (dict, Dict) and len(args) >= 2:
            # Map keys are non-nullable per the Arrow map type contract.
            k = python_type_to_pa(args[0], nullable=False)
            v = python_type_to_pa(args[1], nullable=True)
            return pa.map_(k, v)

        # Nested dataclass -> Arrow struct; resolved hints handle forward refs,
        # falling back to the raw field type when resolution fails.
        if is_dataclass(python_type):
            fields = []
            type_hints = get_type_hints(python_type)
            for f in dataclass_fields(python_type):
                ft = type_hints.get(f.name, f.type)
                n = is_optional(ft)
                pa_type = python_type_to_pa(ft, nullable=n)
                fields.append(pa.field(f.name, pa_type, nullable=n))
            return pa.struct(fields)

        # Nested Pydantic v2 model -> Arrow struct; a field is nullable when it
        # has a default (not required) or is annotated Optional.
        if hasattr(python_type, "model_fields"):
            fields = []
            for name, info in python_type.model_fields.items():
                ft = info.annotation
                n = not info.is_required() or is_optional(ft)
                pa_type = python_type_to_pa(ft, nullable=n)
                fields.append(pa.field(name, pa_type, nullable=n))
            return pa.struct(fields)

        # TypedDict / plain annotated class -> Arrow struct.  Only TypedDicts
        # define __required_keys__; other classes get the empty-set default,
        # making every field nullable.
        if hasattr(python_type, "__annotations__"):
            required: set = getattr(python_type, "__required_keys__", set())
            fields = []
            for name, ft in python_type.__annotations__.items():
                n = name not in required or is_optional(ft)
                pa_type = python_type_to_pa(ft, nullable=n)
                fields.append(pa.field(name, pa_type, nullable=n))
            return pa.struct(fields)

        # Unsupported type: signal with None (handled by the caller below).
        return None

    try:
        field_types = _get_model_field_types(model)
    except Exception:
        # Unsupported model shape -> no schema.
        return None

    pa_fields = []
    for field_name, field_type in field_types.items():
        nullable = is_optional(field_type)
        pa_type = python_type_to_pa(field_type, nullable=nullable)
        if pa_type is None:
            # Any unmappable top-level field invalidates the whole schema.
            return None
        pa_fields.append(pa.field(field_name, pa_type, nullable=nullable))

    return pa.schema(pa_fields)