Coverage for polypandas/schema.py: 68%
181 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-02-24 14:21 -0500
« prev ^ index » next coverage.py v7.6.1, created at 2026-02-24 14:21 -0500
1"""Schema inference and dtype conversion for pandas."""
3from dataclasses import fields as dataclass_fields
4from dataclasses import is_dataclass
5from datetime import date, datetime
6from decimal import Decimal
7from typing import Any, Dict, List, Literal, Optional, Type, Union, get_args, get_origin
9from typing_extensions import get_type_hints
11from polypandas.exceptions import SchemaInferenceError, UnsupportedTypeError
def is_optional(type_hint: Type) -> bool:
    """Check whether a type hint is a union that includes ``None``.

    Recognizes ``Optional[X]`` / ``Union[X, None]`` and, on interpreters that
    provide ``types.UnionType`` (Python 3.10+), the PEP 604 spelling
    ``X | None`` as well.

    Args:
        type_hint: The type hint to inspect.

    Returns:
        True if the hint is a union with ``None`` among its members,
        False for any other hint (including unions without ``None``).
    """
    # Local import keeps this function self-contained; the module is cached
    # by the interpreter after the first call.
    import types

    origin = get_origin(type_hint)
    # ``X | None`` reports ``types.UnionType`` as its origin on 3.10-3.13,
    # not ``typing.Union``; the attribute does not exist before 3.10.
    pep604_union = getattr(types, "UnionType", None)
    is_union = origin is Union or (pep604_union is not None and origin is pep604_union)
    if not is_union:
        return False
    return type(None) in get_args(type_hint)
def unwrap_optional(type_hint: Type) -> Type:
    """Strip the ``None`` member from an Optional type hint.

    Args:
        type_hint: The type hint to unwrap.

    Returns:
        The first non-``None`` member of the union when ``type_hint`` is
        optional; otherwise ``type_hint`` unchanged. For a union with several
        non-``None`` members only the first one is returned.
    """
    if not is_optional(type_hint):
        return type_hint

    none_type = type(None)
    for member in get_args(type_hint):
        if member is not none_type:
            return member
    # No non-None member found: hand the hint back untouched.
    return type_hint
def infer_literal_type(literal_type: Type) -> Type:
    """Infer the common base type of the values in a ``Literal`` type.

    Args:
        literal_type: A ``Literal[...]`` hint. Any other hint is returned
            unchanged.

    Returns:
        The single Python type shared by all literal values, or ``float``
        when the values are a mix of ``int`` and ``float``.

    Raises:
        SchemaInferenceError: If the Literal has no values, or its values
            have mixed types that cannot be unified.
    """
    if get_origin(literal_type) is not Literal:
        return literal_type

    values = get_args(literal_type)
    if not values:
        raise SchemaInferenceError(f"Empty Literal type: {literal_type}")

    value_types = {type(value) for value in values}
    if len(value_types) == 1:
        # Homogeneous literal (all str, all bool, all int, ...).
        return next(iter(value_types))

    # More than one distinct type at this point, so a subset of {int, float}
    # can only be exactly {int, float}: widen to float. bool is deliberately
    # not treated as numeric here (its type is neither int nor float).
    if value_types <= {int, float}:
        return float

    raise SchemaInferenceError(
        f"Cannot infer unified type from Literal with mixed types: {literal_type}"
    )
def python_type_to_pandas_dtype(python_type: Type) -> Any:
    """Map a Python type hint to a pandas dtype name.

    Optional hints are unwrapped and Literal hints are reduced to their base
    type before the lookup.

    Args:
        python_type: The Python type to convert.

    Returns:
        A dtype string: 'object', 'int64', 'float64', 'bool' or
        'datetime64[ns]'. Lists, dicts, dataclasses, objects exposing
        ``model_fields`` and objects exposing ``__annotations__`` all map
        to 'object'.

    Raises:
        UnsupportedTypeError: If the type cannot be converted.
        SchemaInferenceError: If a list hint has no element type, or a
            Literal hint cannot be reduced to a single base type.
    """
    resolved = unwrap_optional(python_type) if is_optional(python_type) else python_type
    if get_origin(resolved) is Literal:
        resolved = infer_literal_type(resolved)

    scalar_dtypes = {
        str: "object",
        int: "int64",
        float: "float64",
        bool: "bool",
        bytes: "object",
        bytearray: "object",
        date: "datetime64[ns]",
        datetime: "datetime64[ns]",
        Decimal: "object",
    }
    if resolved in scalar_dtypes:
        return scalar_dtypes[resolved]

    container = get_origin(resolved)
    if container in (list, List):
        # A list hint without an element type is rejected rather than guessed.
        if not get_args(resolved):
            raise SchemaInferenceError(f"Cannot infer array element type from {resolved}")
        return "object"  # list columns are stored as Python objects
    if container in (dict, Dict):
        return "object"  # dict columns are stored as Python objects

    # Nested structs (dataclass, model_fields holder, annotated class).
    looks_like_struct = (
        is_dataclass(resolved)
        or hasattr(resolved, "model_fields")
        or hasattr(resolved, "__annotations__")
    )
    if looks_like_struct:
        return "object"

    raise UnsupportedTypeError(f"Cannot convert type {resolved} to pandas dtype")
def _get_model_field_types(model: Type) -> Dict[str, Type]:
    """Build a field name -> type hint mapping for a model.

    Three kinds of model are handled, checked in this order: dataclasses
    (hints resolved through ``get_type_hints``, falling back to the raw field
    type), objects exposing ``model_fields`` (each entry's ``annotation`` is
    used), and objects exposing ``__annotations__`` (copied as-is).

    Args:
        model: The model type to inspect.

    Returns:
        A new dict mapping each field name to its type hint.

    Raises:
        SchemaInferenceError: If the model matches none of the three kinds.
    """
    if is_dataclass(model):
        resolved_hints = get_type_hints(model)
        field_types = {}
        for field in dataclass_fields(model):
            field_types[field.name] = resolved_hints.get(field.name, field.type)
        return field_types

    if hasattr(model, "model_fields"):
        field_types = {}
        for field_name, field_info in model.model_fields.items():
            field_types[field_name] = field_info.annotation
        return field_types

    if hasattr(model, "__annotations__"):
        return dict(model.__annotations__)

    raise SchemaInferenceError(f"Cannot infer schema from {model}")
def infer_schema(
    model: Type,
    schema: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Infer a pandas dtype dict from a model type.

    Args:
        model: The model type (dataclass, Pydantic, TypedDict).
        schema: Optional explicit dtype dict (column name -> dtype). If a
            dict is provided it is returned as-is; any non-dict value is
            ignored and the schema is inferred from ``model``.

    Returns:
        A dict mapping column names to pandas dtypes.

    Raises:
        SchemaInferenceError: If the model's fields cannot be extracted or a
            field's type cannot be converted. Per-field failures always name
            the offending field.
    """
    # isinstance already rejects None, so no separate None check is needed.
    if isinstance(schema, dict):
        return schema

    try:
        field_types = _get_model_field_types(model)
    except Exception as e:
        raise SchemaInferenceError(f"Cannot infer schema from {model}: {e}") from e

    result = {}
    for field_name, field_type in field_types.items():
        try:
            result[field_name] = python_type_to_pandas_dtype(field_type)
        except (UnsupportedTypeError, SchemaInferenceError) as e:
            # The converter can raise either error type; wrap both so the
            # message always identifies which field failed. The original
            # message is kept as a suffix and as the chained cause.
            raise SchemaInferenceError(f"Cannot infer type for field {field_name}: {e}") from e

    return result
def infer_dtypes_for_dataframe(model: Type) -> Optional[Dict[str, Any]]:
    """Infer a column name -> dtype dict for building a DataFrame from ``model``.

    This is a thin wrapper that delegates to ``infer_schema(model)``.

    Args:
        model: The model type to infer dtypes from.

    Returns:
        The result of ``infer_schema(model)``.
        NOTE(review): the ``Optional`` return annotation suggests ``None`` was
        meant to signal that pandas is unavailable, but this function performs
        no pandas check and never returns ``None`` itself.
    """
    dtypes = infer_schema(model)
    return dtypes
def _is_struct_like(type_hint: Type) -> bool:
    """Tell whether a type hint denotes a nested struct.

    A hint counts as struct-like when it is a dataclass, a list whose first
    type argument is a dataclass or exposes ``model_fields``, an object that
    exposes ``model_fields``, or an object that exposes ``__annotations__``
    and is not a list/dict generic.

    Args:
        type_hint: The type hint to classify.

    Returns:
        True if the hint is struct-like, False otherwise.
    """
    if is_dataclass(type_hint):
        return True

    container = get_origin(type_hint)
    if container in (list, List):
        element_args = get_args(type_hint)
        if element_args:
            element = element_args[0]
            if is_dataclass(element) or hasattr(element, "model_fields"):
                return True

    if hasattr(type_hint, "model_fields"):
        return True

    # Annotated classes count, but list/dict generics never do.
    return hasattr(type_hint, "__annotations__") and container not in (list, List, dict, Dict)
def has_nested_structs(model: Type) -> bool:
    """Report whether any field of the model is a nested struct or a list of structs.

    Args:
        model: The model type to inspect.

    Returns:
        True if at least one field (after unwrapping Optional) is struct-like
        or is a list whose first type argument is struct-like. False if no
        field qualifies, or if the model's fields cannot be extracted.
    """
    try:
        field_types = _get_model_field_types(model)
    except Exception:
        # Best effort: a model that cannot be inspected is treated as flat.
        return False

    def _field_is_nested(annotation: Type) -> bool:
        inner = unwrap_optional(annotation)
        if _is_struct_like(inner):
            return True
        if get_origin(inner) not in (list, List):
            return False
        element_args = get_args(inner)
        return bool(element_args) and _is_struct_like(element_args[0])

    return any(_field_is_nested(annotation) for annotation in field_types.values())
def infer_pyarrow_schema(model: Type) -> Optional[Any]:
    """Infer a PyArrow Schema from the model type.

    Requires the optional dependency: pip install polypandas[pyarrow]

    Args:
        model: The model type to inspect.

    Returns:
        A ``pyarrow.Schema`` with one field per model field, or None when
        PyArrow is not installed, the model's fields cannot be extracted, or
        any field type (at any nesting depth) has no PyArrow mapping.

    Raises:
        SchemaInferenceError: Propagated from ``infer_literal_type`` when a
            Literal hint cannot be reduced to a single base type.
    """
    try:
        import pyarrow as pa
    except ImportError:
        return None

    # Scalar Python type -> zero-argument factory for the PyArrow type.
    scalar_factories = (
        (str, pa.string),
        (int, pa.int64),
        (float, pa.float64),
        (bool, pa.bool_),
        (bytes, pa.binary),
        (bytearray, pa.binary),
        (date, pa.date32),
        (datetime, lambda: pa.timestamp("us")),
        (Decimal, lambda: pa.decimal128(38, 9)),
    )

    def build_struct(members: List[Any]) -> Any:
        """Build a struct from (name, type hint, nullable) triples; None if any member is unmappable."""
        struct_fields = []
        for name, hint, is_nullable in members:
            member_type = python_type_to_pa(hint)
            if member_type is None:
                return None
            struct_fields.append(pa.field(name, member_type, nullable=is_nullable))
        return pa.struct(struct_fields)

    def python_type_to_pa(python_type: Type) -> Any:
        """Convert one type hint to a PyArrow type, or None if it has no mapping."""
        if is_optional(python_type):
            python_type = unwrap_optional(python_type)
        origin = get_origin(python_type)
        if origin is Literal:
            # A Literal collapses to a plain scalar type with no type arguments.
            python_type = infer_literal_type(python_type)
            origin = None
        args = get_args(python_type) if origin else ()

        for scalar, factory in scalar_factories:
            if python_type is scalar:
                return factory()

        if origin in (list, List) and args:
            element_type = python_type_to_pa(args[0])
            # Propagate "unmappable" instead of handing None to pa.list_.
            return None if element_type is None else pa.list_(element_type)

        if origin in (dict, Dict) and len(args) >= 2:
            key_type = python_type_to_pa(args[0])
            value_type = python_type_to_pa(args[1])
            if key_type is None or value_type is None:
                return None
            return pa.map_(key_type, value_type)

        if is_dataclass(python_type):
            hints = get_type_hints(python_type)
            members = []
            for f in dataclass_fields(python_type):
                hint = hints.get(f.name, f.type)
                members.append((f.name, hint, is_optional(hint)))
            return build_struct(members)

        if hasattr(python_type, "model_fields"):
            members = []
            for name, info in python_type.model_fields.items():
                hint = info.annotation
                # A field with a default is nullable even if its hint is not Optional.
                members.append((name, hint, not info.is_required() or is_optional(hint)))
            return build_struct(members)

        if hasattr(python_type, "__annotations__"):
            # Without __required_keys__ every key is treated as not required,
            # i.e. nullable.
            required: set = getattr(python_type, "__required_keys__", set())
            members = []
            for name, hint in python_type.__annotations__.items():
                members.append((name, hint, name not in required or is_optional(hint)))
            return build_struct(members)

        return None

    try:
        field_types = _get_model_field_types(model)
    except Exception:
        # Best effort: a model that cannot be inspected yields no schema.
        return None

    pa_fields = []
    for field_name, field_type in field_types.items():
        pa_type = python_type_to_pa(field_type)
        if pa_type is None:
            return None
        pa_fields.append(pa.field(field_name, pa_type, nullable=is_optional(field_type)))

    return pa.schema(pa_fields)