Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot.errors import ParseError, SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dataframe.sql.types import StructType
 14    from sqlglot.dialects.dialect import DialectType
 15
 16    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 17
# Names of the `exp.Table` args that identify a table, ordered from the most
# specific part (the table name, stored under "this") to the least specific.
TABLE_ARGS = ("this", "db", "catalog")

# Type of the per-table value held by an `AbstractMappingSchema` (what `find` returns).
T = t.TypeVar("T")
 21
 22
 23class Schema(abc.ABC):
 24    """Abstract base class for database schemas"""
 25
 26    @abc.abstractmethod
 27    def add_table(
 28        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
 29    ) -> None:
 30        """
 31        Register or update a table. Some implementing classes may require column information to also be provided.
 32
 33        Args:
 34            table: table expression instance or string representing the table.
 35            column_mapping: a column mapping that describes the structure of the table.
 36        """
 37
 38    @abc.abstractmethod
 39    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
 40        """
 41        Get the column names for a table.
 42
 43        Args:
 44            table: the `Table` expression instance.
 45            only_visible: whether to include invisible columns.
 46
 47        Returns:
 48            The list of column names.
 49        """
 50
 51    @abc.abstractmethod
 52    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
 53        """
 54        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
 55
 56        Args:
 57            table: the source table.
 58            column: the target column.
 59
 60        Returns:
 61            The resulting column type.
 62        """
 63
 64    @property
 65    def supported_table_args(self) -> t.Tuple[str, ...]:
 66        """
 67        Table arguments this schema support, e.g. `("this", "db", "catalog")`
 68        """
 69        raise NotImplementedError
 70
 71    @property
 72    def empty(self) -> bool:
 73        """Returns whether or not the schema is empty."""
 74        return True
 75
 76
 77class AbstractMappingSchema(t.Generic[T]):
 78    def __init__(
 79        self,
 80        mapping: dict | None = None,
 81    ) -> None:
 82        self.mapping = mapping or {}
 83        self.mapping_trie = new_trie(
 84            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
 85        )
 86        self._supported_table_args: t.Tuple[str, ...] = tuple()
 87
 88    @property
 89    def empty(self) -> bool:
 90        return not self.mapping
 91
 92    def _depth(self) -> int:
 93        return dict_depth(self.mapping)
 94
 95    @property
 96    def supported_table_args(self) -> t.Tuple[str, ...]:
 97        if not self._supported_table_args and self.mapping:
 98            depth = self._depth()
 99
100            if not depth:  # None
101                self._supported_table_args = tuple()
102            elif 1 <= depth <= 3:
103                self._supported_table_args = TABLE_ARGS[:depth]
104            else:
105                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
106
107        return self._supported_table_args
108
109    def table_parts(self, table: exp.Table) -> t.List[str]:
110        if isinstance(table.this, exp.ReadCSV):
111            return [table.this.name]
112        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
113
114    def find(
115        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
116    ) -> t.Optional[T]:
117        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
118        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
119
120        if value == 0:
121            return None
122        elif value == 1:
123            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
124            if len(possibilities) == 1:
125                parts.extend(possibilities[0])
126            else:
127                message = ", ".join(".".join(parts) for parts in possibilities)
128                if raise_on_missing:
129                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
130                return None
131        return self.nested_get(parts, raise_on_missing=raise_on_missing)
132
133    def nested_get(
134        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
135    ) -> t.Optional[t.Any]:
136        return nested_get(
137            d or self.mapping,
138            *zip(self.supported_table_args, reversed(parts)),
139            raise_on_missing=raise_on_missing,
140        )
141
142
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `dialect` has to be set first: `_normalize` reads it through `_normalize_name`.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        normalized_schema = self._normalize(schema or {})
        super().__init__(normalized_schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the mapping, visibility and dialect of another one."""
        source = mapping_schema
        return MappingSchema(
            schema=source.mapping,
            visible=source.visible,
            dialect=source.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from copies of this schema's mapping and visibility dicts.

        Args:
            **kwargs: constructor arguments that override the copied ones.
        """
        options = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
        }
        options.update(kwargs)
        return MappingSchema(**options)  # type: ignore

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        table_ = self._normalize_table(self._ensure_table(table))

        columns = {}
        for name, column_type in ensure_column_mapping(column_mapping).items():
            columns[self._normalize_name(name)] = column_type

        known = self.find(table_, raise_on_missing=False)
        if known and not columns:
            # The table already has columns and nothing new was supplied.
            return

        parts = self.table_parts(table_)
        nested_set(self.mapping, tuple(reversed(parts)), columns)
        new_trie([parts], self.mapping_trie)

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to restrict the result to columns listed in `self.visible`;
                it has no effect when `self.visible` is empty.

        Returns:
            The list of column names, or an empty list if `find` yields None.
        """
        table_ = self._normalize_table(self._ensure_table(table))
        schema = self.find(table_)

        if schema is None:
            return []

        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(table_), self.visible)
            return [col for col in schema if col in visible]  # type: ignore

        return list(schema)

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or as a name.

        Returns:
            The column type, or the "unknown" type when no (non-empty) table schema is found.

        Raises:
            SchemaError: if the stored column type is neither a `DataType` nor a string.
        """
        raw_name = column if isinstance(column, str) else column.this
        column_name = self._normalize_name(raw_name)
        table_ = self._normalize_table(self._ensure_table(table))

        table_schema = self.find(table_, raise_on_missing=False)
        if not table_schema:
            return exp.DataType.build("unknown")

        column_type = table_schema.get(column_name)
        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type.upper())
        raise SchemaError(f"Unknown column type '{column_type}'")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        table_paths = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized: t.Dict = {}
        for path in table_paths:
            # Each key doubles as its own label in the lookup path.
            columns = nested_get(schema, *zip(path, path))
            assert columns is not None

            normalized_path = [self._normalize_name(key) for key in path]
            for column_name, column_type in columns.items():
                target = normalized_path + [self._normalize_name(column_name)]
                nested_set(normalized, target, column_type)

        return normalized

    def _normalize_table(self, table: exp.Table) -> exp.Table:
        """Return a copy of `table` whose string/identifier `TABLE_ARGS` values are normalized."""
        result = table.copy()
        for arg in TABLE_ARGS:
            part = result.args.get(arg)
            if not isinstance(part, (str, exp.Identifier)):
                continue
            result.set(arg, self._normalize_name(part))

        return result

    def _normalize_name(self, name: str | exp.Identifier) -> str:
        """
        Normalize an identifier name: lowercased unless the identifier is quoted.

        If parsing fails, the name is returned unchanged.
        """
        try:
            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
        except ParseError:
            if isinstance(name, str):
                return name
            return name.name

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        """Depth of the mapping, not counting the innermost column level."""
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """
        Return `table` as an `exp.Table`, parsing it with the schema's dialect when needed.

        Raises:
            SchemaError: if parsing yields a falsy result.
        """
        if not isinstance(table, exp.Table):
            parsed = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
            if not parsed:
                raise SchemaError(f"Not a valid table '{table}'")
            return parsed

        return table

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Results are memoized per type string in `self._type_mapping_cache`.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type raises an AttributeError.
        """
        cache = self._type_mapping_cache
        if schema_type in cache:
            return cache[schema_type]

        try:
            expression = exp.DataType.build(schema_type, dialect=self.dialect)
        except AttributeError:
            raise SchemaError(f"Failed to convert type {schema_type}")

        cache[schema_type] = expression
        return expression
324
325
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """
    Return `schema` unchanged if it already is a `Schema`, otherwise wrap it in a `MappingSchema`.

    Args:
        schema: a `Schema` instance, or a value accepted by the `MappingSchema` constructor.
        dialect: dialect passed to `MappingSchema` when a new one is built.
    """
    if not isinstance(schema, Schema):
        return MappingSchema(schema, dialect=dialect)

    return schema
331
332
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column mapping representations into a dict.

    Args:
        mapping: one of
            - a dict, which is returned as-is;
            - a string of comma-separated `name: type` pairs, e.g. `"a: int, b: text"`;
            - an object exposing `simpleString` (it looks like a DataFrame StructType),
              iterated for fields providing `name` and `dataType.simpleString()`;
            - a list of column names, each mapped to `None`;
            - `None`, which produces an empty dict.

    Returns:
        A dict of column name to column type.

    Raises:
        ValueError: if the mapping is of an unsupported type, or if a non-blank string
            contains an entry without a `:` between the name and the type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        if not mapping.strip():
            # A blank string describes no columns at all.
            return {}

        columns = {}
        for entry in mapping.split(","):
            # Split on the first colon only, so types that contain colons stay intact.
            name, separator, column_type = entry.partition(":")
            if not separator:
                raise ValueError(f"Invalid column mapping entry: '{entry.strip()}'")
            columns[name.strip()] = column_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
350
351
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths found `depth` levels into a nested dictionary.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys each path should contain; values below 1 yield no paths.
        keys: prefix prepended to every collected path.

    Returns:
        A list of key paths, each a list of keys from the outermost level inwards.
    """
    prefix = keys or []
    flattened = []

    for key, value in schema.items():
        path = prefix + [key]
        if depth == 1:
            flattened.append(path)
        elif depth >= 2:
            flattened.extend(flatten_schema(value, depth - 1, path))

    return flattened
364
365
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises instead of returning None.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing (or maps to None) and `raise_on_missing` is True.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # The "this" arg is reported to the user as "table".
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current
389
390
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`; an empty path leaves `d` untouched.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys

    node = d
    for key in parents:
        # Descend into the existing level, creating it when it is missing.
        node = node.setdefault(key, {})

    node[leaf] = value
    return d
class Schema(abc.ABC):
class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or as a column name.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`.

        Raises:
            NotImplementedError: always; the base class provides no implementation.
        """
        raise NotImplementedError

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty. The base implementation always reports True."""
        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None) -> None:
    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table;
                defaults to None when no column information is supplied.
        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False) -> List[str]:
    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns.

        Returns:
            The list of column names.
        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column) -> sqlglot.expressions.DataType:
    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table, as an expression or a string.
            column: the target column, as an expression or as a column name.

        Returns:
            The resulting column type.
        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
 78class AbstractMappingSchema(t.Generic[T]):
 79    def __init__(
 80        self,
 81        mapping: dict | None = None,
 82    ) -> None:
 83        self.mapping = mapping or {}
 84        self.mapping_trie = new_trie(
 85            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
 86        )
 87        self._supported_table_args: t.Tuple[str, ...] = tuple()
 88
 89    @property
 90    def empty(self) -> bool:
 91        return not self.mapping
 92
 93    def _depth(self) -> int:
 94        return dict_depth(self.mapping)
 95
 96    @property
 97    def supported_table_args(self) -> t.Tuple[str, ...]:
 98        if not self._supported_table_args and self.mapping:
 99            depth = self._depth()
100
101            if not depth:  # None
102                self._supported_table_args = tuple()
103            elif 1 <= depth <= 3:
104                self._supported_table_args = TABLE_ARGS[:depth]
105            else:
106                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
107
108        return self._supported_table_args
109
110    def table_parts(self, table: exp.Table) -> t.List[str]:
111        if isinstance(table.this, exp.ReadCSV):
112            return [table.this.name]
113        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
114
115    def find(
116        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
117    ) -> t.Optional[T]:
118        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
119        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
120
121        if value == 0:
122            return None
123        elif value == 1:
124            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
125            if len(possibilities) == 1:
126                parts.extend(possibilities[0])
127            else:
128                message = ", ".join(".".join(parts) for parts in possibilities)
129                if raise_on_missing:
130                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
131                return None
132        return self.nested_get(parts, raise_on_missing=raise_on_missing)
133
134    def nested_get(
135        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
136    ) -> t.Optional[t.Any]:
137        return nested_get(
138            d or self.mapping,
139            *zip(self.supported_table_args, reversed(parts)),
140            raise_on_missing=raise_on_missing,
141        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

AbstractMappingSchema(mapping: dict | None = None)
79    def __init__(
80        self,
81        mapping: dict | None = None,
82    ) -> None:
83        self.mapping = mapping or {}
84        self.mapping_trie = new_trie(
85            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
86        )
87        self._supported_table_args: t.Tuple[str, ...] = tuple()
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
110    def table_parts(self, table: exp.Table) -> t.List[str]:
111        if isinstance(table.this, exp.ReadCSV):
112            return [table.this.name]
113        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
115    def find(
116        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
117    ) -> t.Optional[T]:
118        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
119        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
120
121        if value == 0:
122            return None
123        elif value == 1:
124            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
125            if len(possibilities) == 1:
126                parts.extend(possibilities[0])
127            else:
128                message = ", ".join(".".join(parts) for parts in possibilities)
129                if raise_on_missing:
130                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
131                return None
132        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
134    def nested_get(
135        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
136    ) -> t.Optional[t.Any]:
137        return nested_get(
138            d or self.mapping,
139            *zip(self.supported_table_args, reversed(parts)),
140            raise_on_missing=raise_on_missing,
141        )
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `dialect` has to be set first: `_normalize` reads it through `_normalize_name`.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        normalized_schema = self._normalize(schema or {})
        super().__init__(normalized_schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the mapping, visibility and dialect of another one."""
        source = mapping_schema
        return MappingSchema(
            schema=source.mapping,
            visible=source.visible,
            dialect=source.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from copies of this schema's mapping and visibility dicts.

        Args:
            **kwargs: constructor arguments that override the copied ones.
        """
        options = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
        }
        options.update(kwargs)
        return MappingSchema(**options)  # type: ignore

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        table_ = self._normalize_table(self._ensure_table(table))

        columns = {}
        for name, column_type in ensure_column_mapping(column_mapping).items():
            columns[self._normalize_name(name)] = column_type

        known = self.find(table_, raise_on_missing=False)
        if known and not columns:
            # The table already has columns and nothing new was supplied.
            return

        parts = self.table_parts(table_)
        nested_set(self.mapping, tuple(reversed(parts)), columns)
        new_trie([parts], self.mapping_trie)

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to restrict the result to columns listed in `self.visible`;
                it has no effect when `self.visible` is empty.

        Returns:
            The list of column names, or an empty list if `find` yields None.
        """
        table_ = self._normalize_table(self._ensure_table(table))
        schema = self.find(table_)

        if schema is None:
            return []

        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(table_), self.visible)
            return [col for col in schema if col in visible]  # type: ignore

        return list(schema)

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or as a name.

        Returns:
            The column type, or the "unknown" type when no (non-empty) table schema is found.

        Raises:
            SchemaError: if the stored column type is neither a `DataType` nor a string.
        """
        raw_name = column if isinstance(column, str) else column.this
        column_name = self._normalize_name(raw_name)
        table_ = self._normalize_table(self._ensure_table(table))

        table_schema = self.find(table_, raise_on_missing=False)
        if not table_schema:
            return exp.DataType.build("unknown")

        column_type = table_schema.get(column_name)
        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type.upper())
        raise SchemaError(f"Unknown column type '{column_type}'")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        table_paths = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized: t.Dict = {}
        for path in table_paths:
            # Each key doubles as its own label in the lookup path.
            columns = nested_get(schema, *zip(path, path))
            assert columns is not None

            normalized_path = [self._normalize_name(key) for key in path]
            for column_name, column_type in columns.items():
                target = normalized_path + [self._normalize_name(column_name)]
                nested_set(normalized, target, column_type)

        return normalized

    def _normalize_table(self, table: exp.Table) -> exp.Table:
        """Return a copy of `table` whose string/identifier `TABLE_ARGS` values are normalized."""
        result = table.copy()
        for arg in TABLE_ARGS:
            part = result.args.get(arg)
            if not isinstance(part, (str, exp.Identifier)):
                continue
            result.set(arg, self._normalize_name(part))

        return result

    def _normalize_name(self, name: str | exp.Identifier) -> str:
        """
        Normalize an identifier name: lowercased unless the identifier is quoted.

        If parsing fails, the name is returned unchanged.
        """
        try:
            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
        except ParseError:
            if isinstance(name, str):
                return name
            return name.name

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        """Depth of the mapping, not counting the innermost column level."""
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """
        Return `table` as an `exp.Table`, parsing it with the schema's dialect when needed.

        Raises:
            SchemaError: if parsing yields a falsy result.
        """
        if not isinstance(table, exp.Table):
            parsed = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
            if not parsed:
                raise SchemaError(f"Not a valid table '{table}'")
            return parsed

        return table

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Results are memoized per type string in `self._type_mapping_cache`.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type raises an AttributeError.
        """
        cache = self._type_mapping_cache
        if schema_type in cache:
            return cache[schema_type]

        try:
            expression = exp.DataType.build(schema_type, dialect=self.dialect)
        except AttributeError:
            raise SchemaError(f"Failed to convert type {schema_type}")

        cache[schema_type] = expression
        return expression

Schema based on a nested mapping.

Arguments:
  • schema (dict): Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect (str): The dialect to be used for custom type mappings.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None)
162    def __init__(
163        self,
164        schema: t.Optional[t.Dict] = None,
165        visible: t.Optional[t.Dict] = None,
166        dialect: DialectType = None,
167    ) -> None:
168        self.dialect = dialect
169        self.visible = visible or {}
170        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
171        super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
173    @classmethod
174    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
175        return MappingSchema(
176            schema=mapping_schema.mapping,
177            visible=mapping_schema.visible,
178            dialect=mapping_schema.dialect,
179        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
181    def copy(self, **kwargs) -> MappingSchema:
182        return MappingSchema(
183            **{  # type: ignore
184                "schema": self.mapping.copy(),
185                "visible": self.visible.copy(),
186                "dialect": self.dialect,
187                **kwargs,
188            }
189        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None) -> None:
191    def add_table(
192        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
193    ) -> None:
194        """
195        Register or update a table. Updates are only performed if a new column mapping is provided.
196
197        Args:
198            table: the `Table` expression instance or string representing the table.
199            column_mapping: a column mapping that describes the structure of the table.
200        """
201        normalized_table = self._normalize_table(self._ensure_table(table))
202        normalized_column_mapping = {
203            self._normalize_name(key): value
204            for key, value in ensure_column_mapping(column_mapping).items()
205        }
206
207        schema = self.find(normalized_table, raise_on_missing=False)
208        if schema and not normalized_column_mapping:
209            return
210
211        parts = self.table_parts(normalized_table)
212
213        nested_set(
214            self.mapping,
215            tuple(reversed(parts)),
216            normalized_column_mapping,
217        )
218        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False) -> List[str]:
220    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
221        table_ = self._normalize_table(self._ensure_table(table))
222        schema = self.find(table_)
223
224        if schema is None:
225            return []
226
227        if not only_visible or not self.visible:
228            return list(schema)
229
230        visible = self.nested_get(self.table_parts(table_), self.visible)
231        return [col for col in schema if col in visible]  # type: ignore

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns (i.e. exclude invisible ones).
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str) -> sqlglot.expressions.DataType:
233    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
234        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
235        table_ = self._normalize_table(self._ensure_table(table))
236
237        table_schema = self.find(table_, raise_on_missing=False)
238        if table_schema:
239            column_type = table_schema.get(column_name)
240
241            if isinstance(column_type, exp.DataType):
242                return column_type
243            elif isinstance(column_type, str):
244                return self._to_data_type(column_type.upper())
245            raise SchemaError(f"Unknown column type '{column_type}'")
246
247        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
Returns:

The resulting column type.

def ensure_schema( schema: Any, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.schema.Schema:
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """
    Coerce the given value into a `Schema`.

    Args:
        schema: an existing `Schema` instance, or a value to wrap in a `MappingSchema`.
        dialect: forwarded to `MappingSchema` when a new schema has to be built.

    Returns:
        `schema` itself if it is already a `Schema`, otherwise a new `MappingSchema` built from it.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, dialect=dialect)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column mapping representations into a dictionary.

    Args:
        mapping: one of
            - a dict, which is returned as-is;
            - a string of comma-separated `name: type` pairs, e.g. `"a: int, b: text"`.
              Entries are separated on every comma, so types that themselves contain
              commas cannot be expressed in this form;
            - an object exposing `simpleString` (e.g. a DataFrame `StructType`), whose
              fields provide the column names and types;
            - a list of column names, each of which is mapped to `None`;
            - `None`, which yields an empty dict.

    Returns:
        A dictionary mapping column names to their types.

    Raises:
        ValueError: if `mapping` is of an unsupported kind, or if an entry of a string
            mapping has no `:` separating the name from the type.
    """
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns = {}
        for entry in mapping.split(","):
            # Split on the first colon only, so that types which themselves contain
            # colons (e.g. "struct<x:int>") are kept intact instead of being truncated.
            name, separator, column_type = entry.partition(":")
            if not separator:
                raise ValueError(f"Invalid column mapping entry: '{entry.strip()}'")
            columns[name.strip()] = column_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    if mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths found `depth` levels into a nested mapping.

    Args:
        schema: the nested mapping to walk.
        depth: how many levels of keys each returned path consists of. Values below 1
            produce no paths.
        keys: keys to prepend to every returned path (used when recursing).

    Returns:
        One list of keys per path, in the iteration order of the mapping.
    """
    prefix = keys or []
    entries = schema.items()

    if depth >= 2:
        # Descend one level at a time, extending the prefix with the current key.
        nested = (flatten_schema(child, depth - 1, prefix + [name]) for name, child in entries)
        return [path for group in nested for path in group]

    if depth == 1:
        return [prefix + [name] for name, _ in entries]

    return []
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Walk a nested dictionary along a path of keys and return what is found at the end.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise, rather than return None, when a key along
            the path is missing or maps to None.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key isn't found and `raise_on_missing` is set.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # The "this" part of a table reference is reported as "table" in the error.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    Set a value in a nested dictionary, in place, creating intermediate dictionaries as needed.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`. If empty, `d` is left untouched.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys

    node = d
    for key in parents:
        # `setdefault` only inserts a new dictionary when the key is missing.
        node = node.setdefault(key, {})

    node[leaf] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.