Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot.errors import ParseError, SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import in_trie, new_trie
 11
# These imports are needed only by annotations, so they are skipped at runtime.
if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # Accepted ways to describe a table's columns (see `ensure_column_mapping`).
    # Defined only for type checkers: annotations that mention it are never evaluated at
    # runtime because of the `from __future__ import annotations` line at the top of the module.
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# `exp.Table` argument names for the table ("this"), db and catalog parts, in that order.
TABLE_ARGS = ("this", "db", "catalog")

# Type parameter of `AbstractMappingSchema`; it is the value type returned by its `find` method.
T = t.TypeVar("T")
 21
 22
class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The resulting column type.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """
        Returns whether or not the schema is empty.

        This base implementation always reports `True`.
        """
        return True
 91
 92
 93class AbstractMappingSchema(t.Generic[T]):
 94    def __init__(
 95        self,
 96        mapping: t.Optional[t.Dict] = None,
 97    ) -> None:
 98        self.mapping = mapping or {}
 99        self.mapping_trie = new_trie(
100            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
101        )
102        self._supported_table_args: t.Tuple[str, ...] = tuple()
103
104    @property
105    def empty(self) -> bool:
106        return not self.mapping
107
108    def _depth(self) -> int:
109        return dict_depth(self.mapping)
110
111    @property
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        if not self._supported_table_args and self.mapping:
114            depth = self._depth()
115
116            if not depth:  # None
117                self._supported_table_args = tuple()
118            elif 1 <= depth <= 3:
119                self._supported_table_args = TABLE_ARGS[:depth]
120            else:
121                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
122
123        return self._supported_table_args
124
125    def table_parts(self, table: exp.Table) -> t.List[str]:
126        if isinstance(table.this, exp.ReadCSV):
127            return [table.this.name]
128        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
129
130    def find(
131        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
132    ) -> t.Optional[T]:
133        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
134        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
135
136        if value == 0:
137            return None
138
139        if value == 1:
140            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
141
142            if len(possibilities) == 1:
143                parts.extend(possibilities[0])
144            else:
145                message = ", ".join(".".join(parts) for parts in possibilities)
146                if raise_on_missing:
147                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
148                return None
149
150        return self.nested_get(parts, raise_on_missing=raise_on_missing)
151
152    def nested_get(
153        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
154    ) -> t.Optional[t.Any]:
155        return nested_get(
156            d or self.mapping,
157            *zip(self.supported_table_args, reversed(parts)),
158            raise_on_missing=raise_on_missing,
159        )
160
161
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `_normalize` reads `self.dialect`, so the attributes are set before it is called.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        normalized_schema = self._normalize(schema or {})
        super().__init__(normalized_schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the mapping, visibility and dialect of another one."""
        return MappingSchema(
            mapping_schema.mapping,
            mapping_schema.visible,
            mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from shallow copies of this one's mapping and visibility.

        Args:
            **kwargs: constructor arguments that replace the copied ones.
        """
        init_args = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
        }
        init_args.update(kwargs)
        return MappingSchema(**init_args)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """
        normalized_table = self._resolve_table(table, dialect)

        normalized_column_mapping = {}
        for name, column_type in ensure_column_mapping(column_mapping).items():
            normalized_column_mapping[self._normalize_name(name, dialect=dialect)] = column_type

        existing = self.find(normalized_table, raise_on_missing=False)
        if existing and not normalized_column_mapping:
            # The table is already known and there is nothing new to record.
            return

        parts = self.table_parts(normalized_table)

        # The mapping is keyed in the reverse order of `parts`; the trie uses `parts` as is.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `self.visible`. It has
                no effect when `self.visible` is empty.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names, or an empty list if `find` returns `None` for the table.
        """
        normalized_table = self._resolve_table(table, dialect)

        table_schema = self.find(normalized_table)
        if table_schema is None:
            return []

        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
            return [col for col in table_schema if col in visible]

        return list(table_schema)

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column; a plain string is also handled.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The resulting column type, or the "unknown" type when no non-empty entry is found
            for the table.

        Raises:
            SchemaError: if the table is found but the value stored for the column is neither
                a `DataType` nor a string (this includes a column that is not present).
        """
        normalized_table = self._resolve_table(table, dialect)

        raw_column = column if isinstance(column, str) else column.this
        normalized_column_name = self._normalize_name(raw_column, dialect=dialect)

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if not table_schema:
            return exp.DataType.build("unknown")

        column_type = table_schema.get(normalized_column_name)
        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type.upper(), dialect=dialect)

        raise SchemaError(f"Unknown column type '{column_type}'")

    def _resolve_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """Turn `table` into an `exp.Table` if needed and return a normalized copy of it."""
        ensured_table = self._ensure_table(table, dialect=dialect)
        return self._normalize_table(ensured_table, dialect=dialect)

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        table_paths = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for path in table_paths:
            columns = nested_get(schema, *zip(path, path))
            assert columns is not None

            normalized_path = [self._normalize_name(key, dialect=self.dialect) for key in path]
            for column_name, column_type in columns.items():
                normalized_column = self._normalize_name(column_name, dialect=self.dialect)
                nested_set(
                    normalized_mapping,
                    normalized_path + [normalized_column],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
        """Return a copy of `table` whose string or identifier `TABLE_ARGS` are normalized."""
        normalized_table = table.copy()

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if not isinstance(value, (str, exp.Identifier)):
                continue
            normalized_table.set(arg, self._normalize_name(value, dialect=dialect))

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
        """
        Normalize an identifier name.

        The name of a quoted identifier is returned unchanged, any other name is lowercased.
        If parsing raises `ParseError`, the original name is returned as is.
        """
        dialect = dialect or self.dialect

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            if isinstance(name, str):
                return name
            return name.name

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """
        Return `table` itself if it is an `exp.Table`, otherwise parse it into one.

        Raises:
            SchemaError: if parsing produces a falsy result.
        """
        if isinstance(table, exp.Table):
            return table

        dialect = dialect or self.dialect
        parsed_table = sqlglot.parse_one(table, read=dialect, into=exp.Table)

        if parsed_table:
            return parsed_table

        in_dialect = f" in dialect {dialect}" if dialect else ""
        raise SchemaError(f"Failed to parse table '{table}'{in_dialect}.")

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Results are stored in `self._type_mapping_cache`, keyed by `schema_type`.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type raises `AttributeError`.
        """
        if schema_type in self._type_mapping_cache:
            return self._type_mapping_cache[schema_type]

        dialect = dialect or self.dialect

        try:
            expression = exp.DataType.build(schema_type, dialect=dialect)
        except AttributeError:
            in_dialect = f" in dialect {dialect}" if dialect else ""
            raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
        else:
            self._type_mapping_cache[schema_type] = expression

        return self._type_mapping_cache[schema_type]
373
374
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """
    Return `schema` unchanged if it is already a `Schema`, otherwise wrap it.

    Args:
        schema: a `Schema` instance, or a value accepted by the `MappingSchema` constructor.
        dialect: the dialect handed to `MappingSchema` when a new instance is created.

    Returns:
        The given `Schema`, or a new `MappingSchema` built from `schema`.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, dialect=dialect)
380
381
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert any supported column description into a `{column name: column type}` dictionary.

    Args:
        mapping: one of
            - `None`, which gives an empty dictionary;
            - a dictionary, which is returned as is (not copied);
            - a string of comma-separated `name: type` entries;
            - an object with a `simpleString` attribute (looks like a DataFrame StructType),
              whose fields supply the names and `dataType.simpleString()` the types;
            - a list of column names, which are all mapped to `None`.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
        IndexError: if an entry of a string mapping contains no ":" separator.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        column_mapping = {}
        for name_type_str in mapping.split(","):
            # Split on the first ":" only, so that a type which itself contains a colon
            # (e.g. `struct<x:int>`) is kept whole instead of being cut at its own colon.
            name_and_type = name_type_str.strip().split(":", 1)
            column_mapping[name_and_type[0].strip()] = name_and_type[1].strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
400
401
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths of a nested dictionary down to a given depth.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys each path should contain. Nothing is collected for
            a depth below 1.
        keys: a prefix placed in front of every path that is produced.

    Returns:
        One list of keys per path, in dictionary iteration order.
    """
    prefix = keys or []
    flattened = []

    for name, child in schema.items():
        path = prefix + [name]
        if depth == 1:
            flattened.append(path)
        elif depth >= 2:
            flattened.extend(flatten_schema(child, depth - 1, path))

    return flattened
415
416
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises an error instead of giving `None`.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing (or maps to `None`) and `raise_on_missing` is true.
    """
    current = d

    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # The "this" level is reported to the user as "table".
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current
441
442
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys

    # Walk down to the dictionary that holds the last key, creating missing levels on the way.
    node = d
    for key in parents:
        node = node.setdefault(key, {})

    node[leaf] = value
    return d
class Schema(abc.ABC):
24class Schema(abc.ABC):
25    """Abstract base class for database schemas"""
26
27    @abc.abstractmethod
28    def add_table(
29        self,
30        table: exp.Table | str,
31        column_mapping: t.Optional[ColumnMapping] = None,
32        dialect: DialectType = None,
33    ) -> None:
34        """
35        Register or update a table. Some implementing classes may require column information to also be provided.
36
37        Args:
38            table: the `Table` expression instance or string representing the table.
39            column_mapping: a column mapping that describes the structure of the table.
40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
41        """
42
43    @abc.abstractmethod
44    def column_names(
45        self,
46        table: exp.Table | str,
47        only_visible: bool = False,
48        dialect: DialectType = None,
49    ) -> t.List[str]:
50        """
51        Get the column names for a table.
52
53        Args:
54            table: the `Table` expression instance.
55            only_visible: whether to include invisible columns.
56            dialect: the SQL dialect that will be used to parse `table` if it's a string.
57
58        Returns:
59            The list of column names.
60        """
61
62    @abc.abstractmethod
63    def get_column_type(
64        self,
65        table: exp.Table | str,
66        column: exp.Column,
67        dialect: DialectType = None,
68    ) -> exp.DataType:
69        """
70        Get the `sqlglot.exp.DataType` type of a column in the schema.
71
72        Args:
73            table: the source table.
74            column: the target column.
75            dialect: the SQL dialect that will be used to parse `table` if it's a string.
76
77        Returns:
78            The resulting column type.
79        """
80
81    @property
82    @abc.abstractmethod
83    def supported_table_args(self) -> t.Tuple[str, ...]:
84        """
85        Table arguments this schema support, e.g. `("this", "db", "catalog")`
86        """
87
88    @property
89    def empty(self) -> bool:
90        """Returns whether or not the schema is empty."""
91        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
27    @abc.abstractmethod
28    def add_table(
29        self,
30        table: exp.Table | str,
31        column_mapping: t.Optional[ColumnMapping] = None,
32        dialect: DialectType = None,
33    ) -> None:
34        """
35        Register or update a table. Some implementing classes may require column information to also be provided.
36
37        Args:
38            table: the `Table` expression instance or string representing the table.
39            column_mapping: a column mapping that describes the structure of the table.
40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
41        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
43    @abc.abstractmethod
44    def column_names(
45        self,
46        table: exp.Table | str,
47        only_visible: bool = False,
48        dialect: DialectType = None,
49    ) -> t.List[str]:
50        """
51        Get the column names for a table.
52
53        Args:
54            table: the `Table` expression instance.
55            only_visible: whether to include invisible columns.
56            dialect: the SQL dialect that will be used to parse `table` if it's a string.
57
58        Returns:
59            The list of column names.
60        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
62    @abc.abstractmethod
63    def get_column_type(
64        self,
65        table: exp.Table | str,
66        column: exp.Column,
67        dialect: DialectType = None,
68    ) -> exp.DataType:
69        """
70        Get the `sqlglot.exp.DataType` type of a column in the schema.
71
72        Args:
73            table: the source table.
74            column: the target column.
75            dialect: the SQL dialect that will be used to parse `table` if it's a string.
76
77        Returns:
78            The resulting column type.
79        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
 94class AbstractMappingSchema(t.Generic[T]):
 95    def __init__(
 96        self,
 97        mapping: t.Optional[t.Dict] = None,
 98    ) -> None:
 99        self.mapping = mapping or {}
100        self.mapping_trie = new_trie(
101            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
102        )
103        self._supported_table_args: t.Tuple[str, ...] = tuple()
104
105    @property
106    def empty(self) -> bool:
107        return not self.mapping
108
109    def _depth(self) -> int:
110        return dict_depth(self.mapping)
111
112    @property
113    def supported_table_args(self) -> t.Tuple[str, ...]:
114        if not self._supported_table_args and self.mapping:
115            depth = self._depth()
116
117            if not depth:  # None
118                self._supported_table_args = tuple()
119            elif 1 <= depth <= 3:
120                self._supported_table_args = TABLE_ARGS[:depth]
121            else:
122                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
123
124        return self._supported_table_args
125
126    def table_parts(self, table: exp.Table) -> t.List[str]:
127        if isinstance(table.this, exp.ReadCSV):
128            return [table.this.name]
129        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
130
131    def find(
132        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
133    ) -> t.Optional[T]:
134        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
135        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
136
137        if value == 0:
138            return None
139
140        if value == 1:
141            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
142
143            if len(possibilities) == 1:
144                parts.extend(possibilities[0])
145            else:
146                message = ", ".join(".".join(parts) for parts in possibilities)
147                if raise_on_missing:
148                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
149                return None
150
151        return self.nested_get(parts, raise_on_missing=raise_on_missing)
152
153    def nested_get(
154        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
155    ) -> t.Optional[t.Any]:
156        return nested_get(
157            d or self.mapping,
158            *zip(self.supported_table_args, reversed(parts)),
159            raise_on_missing=raise_on_missing,
160        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
            # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

AbstractMappingSchema(mapping: Optional[Dict] = None)
 95    def __init__(
 96        self,
 97        mapping: t.Optional[t.Dict] = None,
 98    ) -> None:
 99        self.mapping = mapping or {}
100        self.mapping_trie = new_trie(
101            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
102        )
103        self._supported_table_args: t.Tuple[str, ...] = tuple()
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
126    def table_parts(self, table: exp.Table) -> t.List[str]:
127        if isinstance(table.this, exp.ReadCSV):
128            return [table.this.name]
129        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
131    def find(
132        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
133    ) -> t.Optional[T]:
134        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
135        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
136
137        if value == 0:
138            return None
139
140        if value == 1:
141            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
142
143            if len(possibilities) == 1:
144                parts.extend(possibilities[0])
145            else:
146                message = ", ".join(".".join(parts) for parts in possibilities)
147                if raise_on_missing:
148                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
149                return None
150
151        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
153    def nested_get(
154        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
155    ) -> t.Optional[t.Any]:
156        return nested_get(
157            d or self.mapping,
158            *zip(self.supported_table_args, reversed(parts)),
159            raise_on_missing=raise_on_missing,
160        )
163class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
164    """
165    Schema based on a nested mapping.
166
167    Args:
168        schema: Mapping in one of the following forms:
169            1. {table: {col: type}}
170            2. {db: {table: {col: type}}}
171            3. {catalog: {db: {table: {col: type}}}}
172            4. None - Tables will be added later
173        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
174            are assumed to be visible. The nesting should mirror that of the schema:
175            1. {table: set(*cols)}
176            2. {db: {table: set(*cols)}}
177            3. {catalog: {db: {table: set(*cols)}}}
178        dialect: The dialect to be used for custom type mappings & parsing string arguments.
179    """
180
181    def __init__(
182        self,
183        schema: t.Optional[t.Dict] = None,
184        visible: t.Optional[t.Dict] = None,
185        dialect: DialectType = None,
186    ) -> None:
187        self.dialect = dialect
188        self.visible = visible or {}
189        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
190
191        super().__init__(self._normalize(schema or {}))
192
193    @classmethod
194    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
195        return MappingSchema(
196            schema=mapping_schema.mapping,
197            visible=mapping_schema.visible,
198            dialect=mapping_schema.dialect,
199        )
200
201    def copy(self, **kwargs) -> MappingSchema:
202        return MappingSchema(
203            **{  # type: ignore
204                "schema": self.mapping.copy(),
205                "visible": self.visible.copy(),
206                "dialect": self.dialect,
207                **kwargs,
208            }
209        )
210
211    def add_table(
212        self,
213        table: exp.Table | str,
214        column_mapping: t.Optional[ColumnMapping] = None,
215        dialect: DialectType = None,
216    ) -> None:
217        """
218        Register or update a table. Updates are only performed if a new column mapping is provided.
219
220        Args:
221            table: the `Table` expression instance or string representing the table.
222            column_mapping: a column mapping that describes the structure of the table.
223            dialect: the SQL dialect that will be used to parse `table` if it's a string.
224        """
225        normalized_table = self._normalize_table(
226            self._ensure_table(table, dialect=dialect), dialect=dialect
227        )
228        normalized_column_mapping = {
229            self._normalize_name(key, dialect=dialect): value
230            for key, value in ensure_column_mapping(column_mapping).items()
231        }
232
233        schema = self.find(normalized_table, raise_on_missing=False)
234        if schema and not normalized_column_mapping:
235            return
236
237        parts = self.table_parts(normalized_table)
238
239        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
240        new_trie([parts], self.mapping_trie)
241
242    def column_names(
243        self,
244        table: exp.Table | str,
245        only_visible: bool = False,
246        dialect: DialectType = None,
247    ) -> t.List[str]:
248        normalized_table = self._normalize_table(
249            self._ensure_table(table, dialect=dialect), dialect=dialect
250        )
251
252        schema = self.find(normalized_table)
253        if schema is None:
254            return []
255
256        if not only_visible or not self.visible:
257            return list(schema)
258
259        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
260        return [col for col in schema if col in visible]
261
262    def get_column_type(
263        self,
264        table: exp.Table | str,
265        column: exp.Column,
266        dialect: DialectType = None,
267    ) -> exp.DataType:
268        normalized_table = self._normalize_table(
269            self._ensure_table(table, dialect=dialect), dialect=dialect
270        )
271        normalized_column_name = self._normalize_name(
272            column if isinstance(column, str) else column.this, dialect=dialect
273        )
274
275        table_schema = self.find(normalized_table, raise_on_missing=False)
276        if table_schema:
277            column_type = table_schema.get(normalized_column_name)
278
279            if isinstance(column_type, exp.DataType):
280                return column_type
281            elif isinstance(column_type, str):
282                return self._to_data_type(column_type.upper(), dialect=dialect)
283
284            raise SchemaError(f"Unknown column type '{column_type}'")
285
286        return exp.DataType.build("unknown")
287
288    def _normalize(self, schema: t.Dict) -> t.Dict:
289        """
290        Converts all identifiers in the schema into lowercase, unless they're quoted.
291
292        Args:
293            schema: the schema to normalize.
294
295        Returns:
296            The normalized schema mapping.
297        """
298        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
299
300        normalized_mapping: t.Dict = {}
301        for keys in flattened_schema:
302            columns = nested_get(schema, *zip(keys, keys))
303            assert columns is not None
304
305            normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
306            for column_name, column_type in columns.items():
307                nested_set(
308                    normalized_mapping,
309                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
310                    column_type,
311                )
312
313        return normalized_mapping
314
315    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
316        normalized_table = table.copy()
317
318        for arg in TABLE_ARGS:
319            value = normalized_table.args.get(arg)
320            if isinstance(value, (str, exp.Identifier)):
321                normalized_table.set(arg, self._normalize_name(value, dialect=dialect))
322
323        return normalized_table
324
325    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
326        dialect = dialect or self.dialect
327
328        try:
329            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
330        except ParseError:
331            return name if isinstance(name, str) else name.name
332
333        return identifier.name if identifier.quoted else identifier.name.lower()
334
335    def _depth(self) -> int:
336        # The columns themselves are a mapping, but we don't want to include those
337        return super()._depth() - 1
338
339    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
340        if isinstance(table, exp.Table):
341            return table
342
343        dialect = dialect or self.dialect
344        parsed_table = sqlglot.parse_one(table, read=dialect, into=exp.Table)
345
346        if not parsed_table:
347            in_dialect = f" in dialect {dialect}" if dialect else ""
348            raise SchemaError(f"Failed to parse table '{table}'{in_dialect}.")
349
350        return parsed_table
351
352    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
353        """
354        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
355
356        Args:
357            schema_type: the type we want to convert.
358            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
359
360        Returns:
361            The resulting expression type.
362        """
363        if schema_type not in self._type_mapping_cache:
364            dialect = dialect or self.dialect
365
366            try:
367                expression = exp.DataType.build(schema_type, dialect=dialect)
368                self._type_mapping_cache[schema_type] = expression
369            except AttributeError:
370                in_dialect = f" in dialect {dialect}" if dialect else ""
371                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
372
373        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None)
181    def __init__(
182        self,
183        schema: t.Optional[t.Dict] = None,
184        visible: t.Optional[t.Dict] = None,
185        dialect: DialectType = None,
186    ) -> None:
187        self.dialect = dialect
188        self.visible = visible or {}
189        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
190
191        super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
193    @classmethod
194    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
195        return MappingSchema(
196            schema=mapping_schema.mapping,
197            visible=mapping_schema.visible,
198            dialect=mapping_schema.dialect,
199        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
201    def copy(self, **kwargs) -> MappingSchema:
202        return MappingSchema(
203            **{  # type: ignore
204                "schema": self.mapping.copy(),
205                "visible": self.visible.copy(),
206                "dialect": self.dialect,
207                **kwargs,
208            }
209        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
211    def add_table(
212        self,
213        table: exp.Table | str,
214        column_mapping: t.Optional[ColumnMapping] = None,
215        dialect: DialectType = None,
216    ) -> None:
217        """
218        Register or update a table. Updates are only performed if a new column mapping is provided.
219
220        Args:
221            table: the `Table` expression instance or string representing the table.
222            column_mapping: a column mapping that describes the structure of the table.
223            dialect: the SQL dialect that will be used to parse `table` if it's a string.
224        """
225        normalized_table = self._normalize_table(
226            self._ensure_table(table, dialect=dialect), dialect=dialect
227        )
228        normalized_column_mapping = {
229            self._normalize_name(key, dialect=dialect): value
230            for key, value in ensure_column_mapping(column_mapping).items()
231        }
232
233        schema = self.find(normalized_table, raise_on_missing=False)
234        if schema and not normalized_column_mapping:
235            return
236
237        parts = self.table_parts(normalized_table)
238
239        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
240        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
242    def column_names(
243        self,
244        table: exp.Table | str,
245        only_visible: bool = False,
246        dialect: DialectType = None,
247    ) -> t.List[str]:
248        normalized_table = self._normalize_table(
249            self._ensure_table(table, dialect=dialect), dialect=dialect
250        )
251
252        schema = self.find(normalized_table)
253        if schema is None:
254            return []
255
256        if not only_visible or not self.visible:
257            return list(schema)
258
259        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
260        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
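  • table: the Table expression instance or a string representing the table.
  • only_visible: whether to return only the visible columns, i.e. exclude invisible ones. Has no effect if no visibility mapping was provided.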
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
262    def get_column_type(
263        self,
264        table: exp.Table | str,
265        column: exp.Column,
266        dialect: DialectType = None,
267    ) -> exp.DataType:
268        normalized_table = self._normalize_table(
269            self._ensure_table(table, dialect=dialect), dialect=dialect
270        )
271        normalized_column_name = self._normalize_name(
272            column if isinstance(column, str) else column.this, dialect=dialect
273        )
274
275        table_schema = self.find(normalized_table, raise_on_missing=False)
276        if table_schema:
277            column_type = table_schema.get(normalized_column_name)
278
279            if isinstance(column_type, exp.DataType):
280                return column_type
281            elif isinstance(column_type, str):
282                return self._to_data_type(column_type.upper(), dialect=dialect)
283
284            raise SchemaError(f"Unknown column type '{column_type}'")
285
286        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.

def ensure_schema( schema: Any, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.schema.Schema:
376def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
377    if isinstance(schema, Schema):
378        return schema
379
380    return MappingSchema(schema, dialect=dialect)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
383def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
384    if mapping is None:
385        return {}
386    elif isinstance(mapping, dict):
387        return mapping
388    elif isinstance(mapping, str):
389        col_name_type_strs = [x.strip() for x in mapping.split(",")]
390        return {
391            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
392            for name_type_str in col_name_type_strs
393        }
394    # Check if mapping looks like a DataFrame StructType
395    elif hasattr(mapping, "simpleString"):
396        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
397    elif isinstance(mapping, list):
398        return {x.strip(): None for x in mapping}
399
400    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
403def flatten_schema(
404    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
405) -> t.List[t.List[str]]:
406    tables = []
407    keys = keys or []
408
409    for k, v in schema.items():
410        if depth >= 2:
411            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
412        elif depth == 1:
413            tables.append(keys + [k])
414
415    return tables
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
418def nested_get(
419    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
420) -> t.Optional[t.Any]:
421    """
422    Get a value for a nested dictionary.
423
424    Args:
425        d: the dictionary to search.
426        *path: tuples of (name, key), where:
427            `key` is the key in the dictionary to get.
428            `name` is a string to use in the error if `key` isn't found.
429
430    Returns:
431        The value or None if it doesn't exist.
432    """
433    for name, key in path:
434        d = d.get(key)  # type: ignore
435        if d is None:
436            if raise_on_missing:
437                name = "table" if name == "this" else name
438                raise ValueError(f"Unknown {name}: {key}")
439            return None
440
441    return d

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

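The value, or None if it doesn't exist and raise_on_missing is False. If raise_on_missing is True (the default), a ValueError is raised for the first missing key instead.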

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
444def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
445    """
446    In-place set a value for a nested dictionary
447
448    Example:
449        >>> nested_set({}, ["top_key", "second_key"], "value")
450        {'top_key': {'second_key': 'value'}}
451
452        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
453        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
454
455    Args:
456        d: dictionary to update.
457        keys: the keys that make up the path to `value`.
458        value: the value to set in the dictionary for the given key path.
459
460    Returns:
461        The (possibly) updated dictionary.
462    """
463    if not keys:
464        return d
465
466    if len(keys) == 1:
467        d[keys[0]] = value
468        return d
469
470    subd = d
471    for key in keys[:-1]:
472        if key not in subd:
473            subd = subd.setdefault(key, {})
474        else:
475            subd = subd[key]
476
477    subd[keys[-1]] = value
478    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.