Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dialects.dialect import DialectType
 14
 15    ColumnMapping = t.Union[t.Dict, str, t.List]
 16
 17
 18class Schema(abc.ABC):
 19    """Abstract base class for database schemas"""
 20
 21    dialect: DialectType
 22
 23    @abc.abstractmethod
 24    def add_table(
 25        self,
 26        table: exp.Table | str,
 27        column_mapping: t.Optional[ColumnMapping] = None,
 28        dialect: DialectType = None,
 29        normalize: t.Optional[bool] = None,
 30        match_depth: bool = True,
 31    ) -> None:
 32        """
 33        Register or update a table. Some implementing classes may require column information to also be provided.
 34        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 35
 36        Args:
 37            table: the `Table` expression instance or string representing the table.
 38            column_mapping: a column mapping that describes the structure of the table.
 39            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 40            normalize: whether to normalize identifiers according to the dialect of interest.
 41            match_depth: whether to enforce that the table must match the schema's depth or not.
 42        """
 43
 44    @abc.abstractmethod
 45    def column_names(
 46        self,
 47        table: exp.Table | str,
 48        only_visible: bool = False,
 49        dialect: DialectType = None,
 50        normalize: t.Optional[bool] = None,
 51    ) -> t.Sequence[str]:
 52        """
 53        Get the column names for a table.
 54
 55        Args:
 56            table: the `Table` expression instance.
 57            only_visible: whether to include invisible columns.
 58            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 59            normalize: whether to normalize identifiers according to the dialect of interest.
 60
 61        Returns:
 62            The sequence of column names.
 63        """
 64
 65    @abc.abstractmethod
 66    def get_column_type(
 67        self,
 68        table: exp.Table | str,
 69        column: exp.Column | str,
 70        dialect: DialectType = None,
 71        normalize: t.Optional[bool] = None,
 72    ) -> exp.DataType:
 73        """
 74        Get the `sqlglot.exp.DataType` type of a column in the schema.
 75
 76        Args:
 77            table: the source table.
 78            column: the target column.
 79            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 80            normalize: whether to normalize identifiers according to the dialect of interest.
 81
 82        Returns:
 83            The resulting column type.
 84        """
 85
 86    def has_column(
 87        self,
 88        table: exp.Table | str,
 89        column: exp.Column | str,
 90        dialect: DialectType = None,
 91        normalize: t.Optional[bool] = None,
 92    ) -> bool:
 93        """
 94        Returns whether `column` appears in `table`'s schema.
 95
 96        Args:
 97            table: the source table.
 98            column: the target column.
 99            dialect: the SQL dialect that will be used to parse `table` if it's a string.
100            normalize: whether to normalize identifiers according to the dialect of interest.
101
102        Returns:
103            True if the column appears in the schema, False otherwise.
104        """
105        name = column if isinstance(column, str) else column.name
106        return name in self.column_names(table, dialect=dialect, normalize=normalize)
107
108    @property
109    @abc.abstractmethod
110    def supported_table_args(self) -> t.Tuple[str, ...]:
111        """
112        Table arguments this schema support, e.g. `("this", "db", "catalog")`
113        """
114
115    @property
116    def empty(self) -> bool:
117        """Returns whether the schema is empty."""
118        return True
119
120
class AbstractMappingSchema:
    """
    Base implementation for schemas backed by a nested dict `mapping`, indexed by a trie
    over reversed table paths (table-name first) for fast prefix lookups in `find`.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # flatten_schema yields paths outermost-key first (e.g. [catalog, db, table]),
        # so each path is reversed before indexing to put the table name first.
        # NOTE: the loop variable `t` locally shadows the `typing as t` module alias.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        # Computed lazily by the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether the underlying mapping is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Returns the nesting depth of the mapping, e.g. {table: {col: type}} has depth 2."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The table path arguments supported by this mapping's shape, e.g. a depth-2
        mapping ({db: {table: ...}}) supports ("this", "db").
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                # More than catalog.db.table levels of qualification is not representable.
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty path parts of `table`, table name first."""
        if isinstance(table.this, exp.ReadCSV):
            # READ_CSV(...) pseudo-tables are keyed by the csv expression's name alone.
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """
        # Only consider as many qualifiers as the mapping's shape can support.
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is under-qualified; it only matches if exactly one completion
            # of the remaining path exists, otherwise the reference is ambiguous.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Looks up `parts` (table name first) in `d`, defaulting to this schema's mapping."""
        return nested_get(
            d or self.mapping,
            # `parts` is table-name first while the mapping nests outermost-key first,
            # hence the reversal; supported_table_args provides names for error messages.
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
196
197
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Caches string -> exp.DataType conversions done by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by depth(); 0 doubles as the "not yet computed" marker.
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new MappingSchema from another instance's state."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a copy of this schema; `kwargs` override any constructor argument."""
        return MappingSchema(
            **{  # type: ignore
                # Shallow copies: nested dicts are still shared with this instance.
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # A known table with no new column info is a no-op (updates need columns).
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Mapping nests outermost-key first, hence the reversal of the parts path.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # The return value is discarded: new_trie extends self.mapping_trie in place.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Returns the column names of `table`, optionally filtered to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Intersect the table's columns with the visibility set registered for it.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Returns the type of `column` in `table`, or the UNKNOWN type if not found."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            # Types may be stored pre-built or as strings that still need parsing.
            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether `column` appears in `table`'s schema (after normalization)."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            # zip(keys, keys) supplies (name, key) pairs where the error label is the key itself.
            columns = nested_get(schema, *zip(keys, keys))

            # Reject ragged mappings: every table must sit at the same nesting depth
            # and its leaf must be a flat {column: type} dict.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes its identifier parts per the dialect."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Only copy an already-parsed expression when normalization may mutate it.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalizes a single identifier and returns its bare name, using this schema's defaults."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Returns the schema's table-path depth, excluding the column level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            # NOTE(review): parse failures apparently surface as AttributeError from
            # DataType.build — confirm this is the intended failure mode to trap.
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
454
455
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """
    Parses `identifier` if it is a string and, unless `normalize` is falsy, normalizes
    it according to `dialect`. Returns the resulting `exp.Identifier`.
    """
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # normalize_identifier consults this flag; BigQuery has special rules for tables.
        ident.meta["is_table"] = is_table
        ident = Dialect.get_or_raise(dialect).normalize_identifier(ident)

    return ident
471
472
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Returns `schema` unchanged if it is already a `Schema`, else wraps it in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
478
479
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerces a column mapping into dict form: None becomes {}, dicts pass through,
    "name: type, ..." strings are parsed, and lists become {name: None} entries.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}

    if isinstance(mapping, dict):
        return mapping

    if isinstance(mapping, str):
        columns = {}
        for entry in (item.strip() for item in mapping.split(",")):
            pieces = entry.split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns

    if isinstance(mapping, list):
        return {column.strip(): None for column in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
495
496
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flattens a nested schema mapping into a list of key paths, one per table,
    stopping one level above the leaves (the column dicts are not descended into).
    """
    prefix = keys or []
    if depth is None:
        depth = dict_depth(schema) - 1

    paths: t.List[t.List[str]] = []
    for key, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
511
512
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal arg name for the table part; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
537
538
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (creating intermediate dicts as needed) down to the parent of the leaf.
    *parents, last = keys
    node = d
    for key in parents:
        node = node.setdefault(key, {})

    node[last] = value
    return d
class Schema(abc.ABC):
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    dialect: DialectType
 23
 24    @abc.abstractmethod
 25    def add_table(
 26        self,
 27        table: exp.Table | str,
 28        column_mapping: t.Optional[ColumnMapping] = None,
 29        dialect: DialectType = None,
 30        normalize: t.Optional[bool] = None,
 31        match_depth: bool = True,
 32    ) -> None:
 33        """
 34        Register or update a table. Some implementing classes may require column information to also be provided.
 35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 36
 37        Args:
 38            table: the `Table` expression instance or string representing the table.
 39            column_mapping: a column mapping that describes the structure of the table.
 40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 41            normalize: whether to normalize identifiers according to the dialect of interest.
 42            match_depth: whether to enforce that the table must match the schema's depth or not.
 43        """
 44
 45    @abc.abstractmethod
 46    def column_names(
 47        self,
 48        table: exp.Table | str,
 49        only_visible: bool = False,
 50        dialect: DialectType = None,
 51        normalize: t.Optional[bool] = None,
 52    ) -> t.Sequence[str]:
 53        """
 54        Get the column names for a table.
 55
 56        Args:
 57            table: the `Table` expression instance.
 58            only_visible: whether to include invisible columns.
 59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 60            normalize: whether to normalize identifiers according to the dialect of interest.
 61
 62        Returns:
 63            The sequence of column names.
 64        """
 65
 66    @abc.abstractmethod
 67    def get_column_type(
 68        self,
 69        table: exp.Table | str,
 70        column: exp.Column | str,
 71        dialect: DialectType = None,
 72        normalize: t.Optional[bool] = None,
 73    ) -> exp.DataType:
 74        """
 75        Get the `sqlglot.exp.DataType` type of a column in the schema.
 76
 77        Args:
 78            table: the source table.
 79            column: the target column.
 80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 81            normalize: whether to normalize identifiers according to the dialect of interest.
 82
 83        Returns:
 84            The resulting column type.
 85        """
 86
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)
108
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """
115
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
24    @abc.abstractmethod
25    def add_table(
26        self,
27        table: exp.Table | str,
28        column_mapping: t.Optional[ColumnMapping] = None,
29        dialect: DialectType = None,
30        normalize: t.Optional[bool] = None,
31        match_depth: bool = True,
32    ) -> None:
33        """
34        Register or update a table. Some implementing classes may require column information to also be provided.
35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
36
37        Args:
38            table: the `Table` expression instance or string representing the table.
39            column_mapping: a column mapping that describes the structure of the table.
40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
41            normalize: whether to normalize identifiers according to the dialect of interest.
42            match_depth: whether to enforce that the table must match the schema's depth or not.
43        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> Sequence[str]:
45    @abc.abstractmethod
46    def column_names(
47        self,
48        table: exp.Table | str,
49        only_visible: bool = False,
50        dialect: DialectType = None,
51        normalize: t.Optional[bool] = None,
52    ) -> t.Sequence[str]:
53        """
54        Get the column names for a table.
55
56        Args:
57            table: the `Table` expression instance.
58            only_visible: whether to include invisible columns.
59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
60            normalize: whether to normalize identifiers according to the dialect of interest.
61
62        Returns:
63            The sequence of column names.
64        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
66    @abc.abstractmethod
67    def get_column_type(
68        self,
69        table: exp.Table | str,
70        column: exp.Column | str,
71        dialect: DialectType = None,
72        normalize: t.Optional[bool] = None,
73    ) -> exp.DataType:
74        """
75        Get the `sqlglot.exp.DataType` type of a column in the schema.
76
77        Args:
78            table: the source table.
79            column: the target column.
80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
81            normalize: whether to normalize identifiers according to the dialect of interest.
82
83        Returns:
84            The resulting column type.
85        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
122class AbstractMappingSchema:
123    def __init__(
124        self,
125        mapping: t.Optional[t.Dict] = None,
126    ) -> None:
127        self.mapping = mapping or {}
128        self.mapping_trie = new_trie(
129            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
130        )
131        self._supported_table_args: t.Tuple[str, ...] = tuple()
132
133    @property
134    def empty(self) -> bool:
135        return not self.mapping
136
137    def depth(self) -> int:
138        return dict_depth(self.mapping)
139
140    @property
141    def supported_table_args(self) -> t.Tuple[str, ...]:
142        if not self._supported_table_args and self.mapping:
143            depth = self.depth()
144
145            if not depth:  # None
146                self._supported_table_args = tuple()
147            elif 1 <= depth <= 3:
148                self._supported_table_args = exp.TABLE_PARTS[:depth]
149            else:
150                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
151
152        return self._supported_table_args
153
154    def table_parts(self, table: exp.Table) -> t.List[str]:
155        if isinstance(table.this, exp.ReadCSV):
156            return [table.this.name]
157        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
158
159    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
160        """
161        Returns the schema of a given table.
162
163        Args:
164            table: the target table.
165            raise_on_missing: whether to raise in case the schema is not found.
166
167        Returns:
168            The schema of the target table.
169        """
170        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
171        value, trie = in_trie(self.mapping_trie, parts)
172
173        if value == TrieResult.FAILED:
174            return None
175
176        if value == TrieResult.PREFIX:
177            possibilities = flatten_schema(trie)
178
179            if len(possibilities) == 1:
180                parts.extend(possibilities[0])
181            else:
182                message = ", ".join(".".join(parts) for parts in possibilities)
183                if raise_on_missing:
184                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
185                return None
186
187        return self.nested_get(parts, raise_on_missing=raise_on_missing)
188
189    def nested_get(
190        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
191    ) -> t.Optional[t.Any]:
192        return nested_get(
193            d or self.mapping,
194            *zip(self.supported_table_args, reversed(parts)),
195            raise_on_missing=raise_on_missing,
196        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
123    def __init__(
124        self,
125        mapping: t.Optional[t.Dict] = None,
126    ) -> None:
127        self.mapping = mapping or {}
128        self.mapping_trie = new_trie(
129            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
130        )
131        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
133    @property
134    def empty(self) -> bool:
135        return not self.mapping
def depth(self) -> int:
137    def depth(self) -> int:
138        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
140    @property
141    def supported_table_args(self) -> t.Tuple[str, ...]:
142        if not self._supported_table_args and self.mapping:
143            depth = self.depth()
144
145            if not depth:  # None
146                self._supported_table_args = tuple()
147            elif 1 <= depth <= 3:
148                self._supported_table_args = exp.TABLE_PARTS[:depth]
149            else:
150                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
151
152        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
154    def table_parts(self, table: exp.Table) -> t.List[str]:
155        if isinstance(table.this, exp.ReadCSV):
156            return [table.this.name]
157        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True) -> Optional[Any]:
159    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
160        """
161        Returns the schema of a given table.
162
163        Args:
164            table: the target table.
165            raise_on_missing: whether to raise in case the schema is not found.
166
167        Returns:
168            The schema of the target table.
169        """
170        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
171        value, trie = in_trie(self.mapping_trie, parts)
172
173        if value == TrieResult.FAILED:
174            return None
175
176        if value == TrieResult.PREFIX:
177            possibilities = flatten_schema(trie)
178
179            if len(possibilities) == 1:
180                parts.extend(possibilities[0])
181            else:
182                message = ", ".join(".".join(parts) for parts in possibilities)
183                if raise_on_missing:
184                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
185                return None
186
187        return self.nested_get(parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
Returns:

The schema of the target table.

def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
189    def nested_get(
190        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
191    ) -> t.Optional[t.Any]:
192        return nested_get(
193            d or self.mapping,
194            *zip(self.supported_table_args, reversed(parts)),
195            raise_on_missing=raise_on_missing,
196        )
class MappingSchema(AbstractMappingSchema, Schema):
199class MappingSchema(AbstractMappingSchema, Schema):
200    """
201    Schema based on a nested mapping.
202
203    Args:
204        schema: Mapping in one of the following forms:
205            1. {table: {col: type}}
206            2. {db: {table: {col: type}}}
207            3. {catalog: {db: {table: {col: type}}}}
208            4. None - Tables will be added later
209        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
210            are assumed to be visible. The nesting should mirror that of the schema:
211            1. {table: set(*cols)}}
212            2. {db: {table: set(*cols)}}}
213            3. {catalog: {db: {table: set(*cols)}}}}
214        dialect: The dialect to be used for custom type mappings & parsing string arguments.
215        normalize: Whether to normalize identifier names according to the given dialect or not.
216    """
217
218    def __init__(
219        self,
220        schema: t.Optional[t.Dict] = None,
221        visible: t.Optional[t.Dict] = None,
222        dialect: DialectType = None,
223        normalize: bool = True,
224    ) -> None:
225        self.dialect = dialect
226        self.visible = {} if visible is None else visible
227        self.normalize = normalize
228        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
229        self._depth = 0
230        schema = {} if schema is None else schema
231
232        super().__init__(self._normalize(schema) if self.normalize else schema)
233
234    @classmethod
235    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
236        return MappingSchema(
237            schema=mapping_schema.mapping,
238            visible=mapping_schema.visible,
239            dialect=mapping_schema.dialect,
240            normalize=mapping_schema.normalize,
241        )
242
243    def copy(self, **kwargs) -> MappingSchema:
244        return MappingSchema(
245            **{  # type: ignore
246                "schema": self.mapping.copy(),
247                "visible": self.visible.copy(),
248                "dialect": self.dialect,
249                "normalize": self.normalize,
250                **kwargs,
251            }
252        )
253
254    def add_table(
255        self,
256        table: exp.Table | str,
257        column_mapping: t.Optional[ColumnMapping] = None,
258        dialect: DialectType = None,
259        normalize: t.Optional[bool] = None,
260        match_depth: bool = True,
261    ) -> None:
262        """
263        Register or update a table. Updates are only performed if a new column mapping is provided.
264        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
265
266        Args:
267            table: the `Table` expression instance or string representing the table.
268            column_mapping: a column mapping that describes the structure of the table.
269            dialect: the SQL dialect that will be used to parse `table` if it's a string.
270            normalize: whether to normalize identifiers according to the dialect of interest.
271            match_depth: whether to enforce that the table must match the schema's depth or not.
272        """
273        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
274
275        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
276            raise SchemaError(
277                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
278                f"schema's nesting level: {self.depth()}."
279            )
280
281        normalized_column_mapping = {
282            self._normalize_name(key, dialect=dialect, normalize=normalize): value
283            for key, value in ensure_column_mapping(column_mapping).items()
284        }
285
286        schema = self.find(normalized_table, raise_on_missing=False)
287        if schema and not normalized_column_mapping:
288            return
289
290        parts = self.table_parts(normalized_table)
291
292        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
293        new_trie([parts], self.mapping_trie)
294
295    def column_names(
296        self,
297        table: exp.Table | str,
298        only_visible: bool = False,
299        dialect: DialectType = None,
300        normalize: t.Optional[bool] = None,
301    ) -> t.List[str]:
302        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
303
304        schema = self.find(normalized_table)
305        if schema is None:
306            return []
307
308        if not only_visible or not self.visible:
309            return list(schema)
310
311        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
312        return [col for col in schema if col in visible]
313
314    def get_column_type(
315        self,
316        table: exp.Table | str,
317        column: exp.Column | str,
318        dialect: DialectType = None,
319        normalize: t.Optional[bool] = None,
320    ) -> exp.DataType:
321        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
322
323        normalized_column_name = self._normalize_name(
324            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
325        )
326
327        table_schema = self.find(normalized_table, raise_on_missing=False)
328        if table_schema:
329            column_type = table_schema.get(normalized_column_name)
330
331            if isinstance(column_type, exp.DataType):
332                return column_type
333            elif isinstance(column_type, str):
334                return self._to_data_type(column_type, dialect=dialect)
335
336        return exp.DataType.build("unknown")
337
338    def has_column(
339        self,
340        table: exp.Table | str,
341        column: exp.Column | str,
342        dialect: DialectType = None,
343        normalize: t.Optional[bool] = None,
344    ) -> bool:
345        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
346
347        normalized_column_name = self._normalize_name(
348            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
349        )
350
351        table_schema = self.find(normalized_table, raise_on_missing=False)
352        return normalized_column_name in table_schema if table_schema else False
353
354    def _normalize(self, schema: t.Dict) -> t.Dict:
355        """
356        Normalizes all identifiers in the schema.
357
358        Args:
359            schema: the schema to normalize.
360
361        Returns:
362            The normalized schema mapping.
363        """
364        normalized_mapping: t.Dict = {}
365        flattened_schema = flatten_schema(schema)
366        error_msg = "Table {} must match the schema's nesting level: {}."
367
368        for keys in flattened_schema:
369            columns = nested_get(schema, *zip(keys, keys))
370
371            if not isinstance(columns, dict):
372                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
373            if isinstance(first(columns.values()), dict):
374                raise SchemaError(
375                    error_msg.format(
376                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
377                    ),
378                )
379
380            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
381            for column_name, column_type in columns.items():
382                nested_set(
383                    normalized_mapping,
384                    normalized_keys + [self._normalize_name(column_name)],
385                    column_type,
386                )
387
388        return normalized_mapping
389
390    def _normalize_table(
391        self,
392        table: exp.Table | str,
393        dialect: DialectType = None,
394        normalize: t.Optional[bool] = None,
395    ) -> exp.Table:
396        dialect = dialect or self.dialect
397        normalize = self.normalize if normalize is None else normalize
398
399        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)
400
401        if normalize:
402            for arg in exp.TABLE_PARTS:
403                value = normalized_table.args.get(arg)
404                if isinstance(value, exp.Identifier):
405                    normalized_table.set(
406                        arg,
407                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
408                    )
409
410        return normalized_table
411
412    def _normalize_name(
413        self,
414        name: str | exp.Identifier,
415        dialect: DialectType = None,
416        is_table: bool = False,
417        normalize: t.Optional[bool] = None,
418    ) -> str:
419        return normalize_name(
420            name,
421            dialect=dialect or self.dialect,
422            is_table=is_table,
423            normalize=self.normalize if normalize is None else normalize,
424        ).name
425
426    def depth(self) -> int:
427        if not self.empty and not self._depth:
428            # The columns themselves are a mapping, but we don't want to include those
429            self._depth = super().depth() - 1
430        return self._depth
431
432    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
433        """
434        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
435
436        Args:
437            schema_type: the type we want to convert.
438            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
439
440        Returns:
441            The resulting expression type.
442        """
443        if schema_type not in self._type_mapping_cache:
444            dialect = dialect or self.dialect
445            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
446
447            try:
448                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
449                self._type_mapping_cache[schema_type] = expression
450            except AttributeError:
451                in_dialect = f" in dialect {dialect}" if dialect else ""
452                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
453
454        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
218    def __init__(
219        self,
220        schema: t.Optional[t.Dict] = None,
221        visible: t.Optional[t.Dict] = None,
222        dialect: DialectType = None,
223        normalize: bool = True,
224    ) -> None:
225        self.dialect = dialect
226        self.visible = {} if visible is None else visible
227        self.normalize = normalize
228        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
229        self._depth = 0
230        schema = {} if schema is None else schema
231
232        super().__init__(self._normalize(schema) if self.normalize else schema)
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
234    @classmethod
235    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
236        return MappingSchema(
237            schema=mapping_schema.mapping,
238            visible=mapping_schema.visible,
239            dialect=mapping_schema.dialect,
240            normalize=mapping_schema.normalize,
241        )
def copy(self, **kwargs) -> MappingSchema:
243    def copy(self, **kwargs) -> MappingSchema:
244        return MappingSchema(
245            **{  # type: ignore
246                "schema": self.mapping.copy(),
247                "visible": self.visible.copy(),
248                "dialect": self.dialect,
249                "normalize": self.normalize,
250                **kwargs,
251            }
252        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
254    def add_table(
255        self,
256        table: exp.Table | str,
257        column_mapping: t.Optional[ColumnMapping] = None,
258        dialect: DialectType = None,
259        normalize: t.Optional[bool] = None,
260        match_depth: bool = True,
261    ) -> None:
262        """
263        Register or update a table. Updates are only performed if a new column mapping is provided.
264        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
265
266        Args:
267            table: the `Table` expression instance or string representing the table.
268            column_mapping: a column mapping that describes the structure of the table.
269            dialect: the SQL dialect that will be used to parse `table` if it's a string.
270            normalize: whether to normalize identifiers according to the dialect of interest.
271            match_depth: whether to enforce that the table must match the schema's depth or not.
272        """
273        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
274
275        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
276            raise SchemaError(
277                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
278                f"schema's nesting level: {self.depth()}."
279            )
280
281        normalized_column_mapping = {
282            self._normalize_name(key, dialect=dialect, normalize=normalize): value
283            for key, value in ensure_column_mapping(column_mapping).items()
284        }
285
286        schema = self.find(normalized_table, raise_on_missing=False)
287        if schema and not normalized_column_mapping:
288            return
289
290        parts = self.table_parts(normalized_table)
291
292        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
293        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
295    def column_names(
296        self,
297        table: exp.Table | str,
298        only_visible: bool = False,
299        dialect: DialectType = None,
300        normalize: t.Optional[bool] = None,
301    ) -> t.List[str]:
302        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
303
304        schema = self.find(normalized_table)
305        if schema is None:
306            return []
307
308        if not only_visible or not self.visible:
309            return list(schema)
310
311        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
312        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
314    def get_column_type(
315        self,
316        table: exp.Table | str,
317        column: exp.Column | str,
318        dialect: DialectType = None,
319        normalize: t.Optional[bool] = None,
320    ) -> exp.DataType:
321        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
322
323        normalized_column_name = self._normalize_name(
324            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
325        )
326
327        table_schema = self.find(normalized_table, raise_on_missing=False)
328        if table_schema:
329            column_type = table_schema.get(normalized_column_name)
330
331            if isinstance(column_type, exp.DataType):
332                return column_type
333            elif isinstance(column_type, str):
334                return self._to_data_type(column_type, dialect=dialect)
335
336        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
338    def has_column(
339        self,
340        table: exp.Table | str,
341        column: exp.Column | str,
342        dialect: DialectType = None,
343        normalize: t.Optional[bool] = None,
344    ) -> bool:
345        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
346
347        normalized_column_name = self._normalize_name(
348            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
349        )
350
351        table_schema = self.find(normalized_table, raise_on_missing=False)
352        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
426    def depth(self) -> int:
427        if not self.empty and not self._depth:
428            # The columns themselves are a mapping, but we don't want to include those
429            self._depth = super().depth() - 1
430        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
457def normalize_name(
458    identifier: str | exp.Identifier,
459    dialect: DialectType = None,
460    is_table: bool = False,
461    normalize: t.Optional[bool] = True,
462) -> exp.Identifier:
463    if isinstance(identifier, str):
464        identifier = exp.parse_identifier(identifier, dialect=dialect)
465
466    if not normalize:
467        return identifier
468
469    # this is used for normalize_identifier, bigquery has special rules pertaining tables
470    identifier.meta["is_table"] = is_table
471    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
474def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
475    if isinstance(schema, Schema):
476        return schema
477
478    return MappingSchema(schema, **kwargs)
def ensure_column_mapping(mapping: Union[Dict, str, List, NoneType]) -> Dict:
481def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
482    if mapping is None:
483        return {}
484    elif isinstance(mapping, dict):
485        return mapping
486    elif isinstance(mapping, str):
487        col_name_type_strs = [x.strip() for x in mapping.split(",")]
488        return {
489            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
490            for name_type_str in col_name_type_strs
491        }
492    elif isinstance(mapping, list):
493        return {x.strip(): None for x in mapping}
494
495    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: Optional[int] = None, keys: Optional[List[str]] = None) -> List[List[str]]:
498def flatten_schema(
499    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
500) -> t.List[t.List[str]]:
501    tables = []
502    keys = keys or []
503    depth = dict_depth(schema) - 1 if depth is None else depth
504
505    for k, v in schema.items():
506        if depth == 1 or not isinstance(v, dict):
507            tables.append(keys + [k])
508        elif depth >= 2:
509            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
510
511    return tables
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
514def nested_get(
515    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
516) -> t.Optional[t.Any]:
517    """
518    Get a value for a nested dictionary.
519
520    Args:
521        d: the dictionary to search.
522        *path: tuples of (name, key), where:
523            `key` is the key in the dictionary to get.
524            `name` is a string to use in the error if `key` isn't found.
525
526    Returns:
527        The value or None if it doesn't exist.
528    """
529    for name, key in path:
530        d = d.get(key)  # type: ignore
531        if d is None:
532            if raise_on_missing:
533                name = "table" if name == "this" else name
534                raise ValueError(f"Unknown {name}: {key}")
535            return None
536
537    return d

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where `key` is the key to look up in the dictionary and `name` is a string used in the error message if `key` isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
540def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
541    """
542    In-place set a value for a nested dictionary
543
544    Example:
545        >>> nested_set({}, ["top_key", "second_key"], "value")
546        {'top_key': {'second_key': 'value'}}
547
548        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
549        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
550
551    Args:
552        d: dictionary to update.
553        keys: the keys that makeup the path to `value`.
554        value: the value to set in the dictionary for the given key path.
555
556    Returns:
557        The (possibly) updated dictionary.
558    """
559    if not keys:
560        return d
561
562    if len(keys) == 1:
563        d[keys[0]] = value
564        return d
565
566    subd = d
567    for key in keys[:-1]:
568        if key not in subd:
569            subd = subd.setdefault(key, {})
570        else:
571            subd = subd[key]
572
573    subd[keys[-1]] = value
574    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.