Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dataframe.sql.types import StructType
 14    from sqlglot.dialects.dialect import DialectType
 15
 16    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 17
# Table-expression argument names, ordered from most to least specific
# qualifier: table name ("this"), then database, then catalog.
TABLE_ARGS = ("this", "db", "catalog")
 20
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    # Default dialect used when parsing string arguments (tables, columns).
    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True
122
123
class AbstractMappingSchema:
    """
    Base class that stores a nested table-path mapping and resolves (possibly
    partially qualified) tables against it using a trie of reversed paths.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie is keyed by reversed table paths (table, db, catalog), so that
        # lookups proceed from the most specific qualifier outwards. NOTE: the loop
        # variable was renamed from `t`, which shadowed the `typing as t` alias.
        self.mapping_trie = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.mapping, depth=self.depth())
        )
        # Lazily populated by `supported_table_args`.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether or not the underlying mapping is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Returns the nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments supported by this mapping, derived from its nesting depth,
        e.g. depth 2 supports `("this", "db")`.

        Raises:
            SchemaError: if the mapping is nested more than 3 levels deep.
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty identifier parts of `table`, most specific first."""
        if isinstance(table.this, exp.ReadCSV):
            # CSV sources are keyed by their file name alone.
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """
        Resolves `table` against the mapping, completing a unique partial match if possible.

        Args:
            table: the table expression to look up.
            trie: an optional trie to search; defaults to this schema's trie.
            raise_on_missing: whether to raise when the table matches ambiguously.

        Returns:
            The value stored for the table (e.g. its column mapping), or None if not found.

        Raises:
            SchemaError: if `raise_on_missing` is True and the match is ambiguous.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # Partial match: only usable if it can be completed unambiguously.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetches the value stored under the reversed `parts` path in `d` (or the mapping)."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
191
192
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of string type -> parsed exp.DataType, so each type is parsed once.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed nesting depth; 0 means "not computed yet" (or empty schema).
        self._depth = 0

        # Normalize identifiers up front so all later lookups are consistent;
        # the base class then builds the lookup trie from the normalized mapping.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        # Alternate constructor: rebuild a schema from an existing instance's state.
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        # Shallow-copies the mapping/visible dicts; kwargs override any field.
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-registered table is only updated when new column info is given.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Keep the mapping and the lookup trie in sync.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names, or an empty list if the table isn't found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Filter the schema's columns down to those marked visible for this table.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or the UNKNOWN type if it can't be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            # Types may be stored either pre-parsed or as raw strings.
            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table entry's value isn't a column mapping (dict),
                i.e. the input isn't uniformly nested.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 so the flattened paths end at the table level, not the column level.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # zip(keys, keys): nested_get expects (error-name, key) pairs; here the
            # path keys double as their own error labels.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes each of its qualifier identifiers."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalizes a single identifier, falling back to this schema's dialect/settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string can't be parsed.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
447
448
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """
    Returns the name of `identifier`, optionally normalized per the given dialect.

    Args:
        identifier: the identifier (or string to parse into one).
        dialect: the SQL dialect used for parsing and normalization rules.
        is_table: whether the identifier names a table (relevant to some dialects).
        normalize: if falsy, the raw name is returned without normalization.
    """
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # The "is_table" hint can influence dialect-specific normalization rules.
        ident.meta["is_table"] = is_table
        ident = Dialect.get_or_raise(dialect).normalize_identifier(ident)

    return ident.name
464
465
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerces `schema` into a `Schema` instance, wrapping plain dicts in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
471
472
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerces the supported column-mapping representations into a plain dict.

    Accepts None (empty mapping), a dict (returned as-is), a "name: type, ..."
    string, a Spark-style StructType, or a list of column names (types unknown).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}

    if isinstance(mapping, dict):
        return mapping

    if isinstance(mapping, str):
        columns: t.Dict = {}
        for entry in mapping.split(","):
            pieces = entry.split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns

    # Duck-typed check for a DataFrame StructType, to avoid a hard import.
    if hasattr(mapping, "simpleString"):
        return {field.name: field.dataType.simpleString() for field in mapping}

    if isinstance(mapping, list):
        # Only names are known here; types are left unspecified.
        return dict.fromkeys(name.strip() for name in mapping)

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
491
492
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flattens a nested schema into a list of key paths, each `depth` levels deep.

    Args:
        schema: the nested mapping to flatten.
        depth: how many levels of nesting to descend (paths end at this depth).
        keys: the prefix accumulated so far (used by the recursion).

    Returns:
        A list of key paths; empty if `depth` is less than 1.
    """
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for name, child in schema.items():
        if depth == 1:
            paths.append(prefix + [name])
        elif depth >= 2:
            paths.extend(flatten_schema(child, depth - 1, prefix + [name]))

    return paths
506
507
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True.
    """
    current: t.Optional[t.Any] = d

    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal arg name for the table itself; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
532
533
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # setdefault both creates missing intermediate dicts and returns existing ones,
    # which also subsumes the single-key case (keys[:-1] is empty then).
    subd = d
    for key in keys[:-1]:
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
TABLE_ARGS = ('this', 'db', 'catalog')
class Schema(abc.ABC):
 22class Schema(abc.ABC):
 23    """Abstract base class for database schemas"""
 24
 25    dialect: DialectType
 26
 27    @abc.abstractmethod
 28    def add_table(
 29        self,
 30        table: exp.Table | str,
 31        column_mapping: t.Optional[ColumnMapping] = None,
 32        dialect: DialectType = None,
 33        normalize: t.Optional[bool] = None,
 34        match_depth: bool = True,
 35    ) -> None:
 36        """
 37        Register or update a table. Some implementing classes may require column information to also be provided.
 38        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 39
 40        Args:
 41            table: the `Table` expression instance or string representing the table.
 42            column_mapping: a column mapping that describes the structure of the table.
 43            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 44            normalize: whether to normalize identifiers according to the dialect of interest.
 45            match_depth: whether to enforce that the table must match the schema's depth or not.
 46        """
 47
 48    @abc.abstractmethod
 49    def column_names(
 50        self,
 51        table: exp.Table | str,
 52        only_visible: bool = False,
 53        dialect: DialectType = None,
 54        normalize: t.Optional[bool] = None,
 55    ) -> t.List[str]:
 56        """
 57        Get the column names for a table.
 58
 59        Args:
 60            table: the `Table` expression instance.
 61            only_visible: whether to include invisible columns.
 62            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 63            normalize: whether to normalize identifiers according to the dialect of interest.
 64
 65        Returns:
 66            The list of column names.
 67        """
 68
 69    @abc.abstractmethod
 70    def get_column_type(
 71        self,
 72        table: exp.Table | str,
 73        column: exp.Column | str,
 74        dialect: DialectType = None,
 75        normalize: t.Optional[bool] = None,
 76    ) -> exp.DataType:
 77        """
 78        Get the `sqlglot.exp.DataType` type of a column in the schema.
 79
 80        Args:
 81            table: the source table.
 82            column: the target column.
 83            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 84            normalize: whether to normalize identifiers according to the dialect of interest.
 85
 86        Returns:
 87            The resulting column type.
 88        """
 89
 90    def has_column(
 91        self,
 92        table: exp.Table | str,
 93        column: exp.Column | str,
 94        dialect: DialectType = None,
 95        normalize: t.Optional[bool] = None,
 96    ) -> bool:
 97        """
 98        Returns whether or not `column` appears in `table`'s schema.
 99
100        Args:
101            table: the source table.
102            column: the target column.
103            dialect: the SQL dialect that will be used to parse `table` if it's a string.
104            normalize: whether to normalize identifiers according to the dialect of interest.
105
106        Returns:
107            True if the column appears in the schema, False otherwise.
108        """
109        name = column if isinstance(column, str) else column.name
110        return name in self.column_names(table, dialect=dialect, normalize=normalize)
111
112    @property
113    @abc.abstractmethod
114    def supported_table_args(self) -> t.Tuple[str, ...]:
115        """
116        Table arguments this schema support, e.g. `("this", "db", "catalog")`
117        """
118
119    @property
120    def empty(self) -> bool:
121        """Returns whether or not the schema is empty."""
122        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
27    @abc.abstractmethod
28    def add_table(
29        self,
30        table: exp.Table | str,
31        column_mapping: t.Optional[ColumnMapping] = None,
32        dialect: DialectType = None,
33        normalize: t.Optional[bool] = None,
34        match_depth: bool = True,
35    ) -> None:
36        """
37        Register or update a table. Some implementing classes may require column information to also be provided.
38        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
39
40        Args:
41            table: the `Table` expression instance or string representing the table.
42            column_mapping: a column mapping that describes the structure of the table.
43            dialect: the SQL dialect that will be used to parse `table` if it's a string.
44            normalize: whether to normalize identifiers according to the dialect of interest.
45            match_depth: whether to enforce that the table must match the schema's depth or not.
46        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
48    @abc.abstractmethod
49    def column_names(
50        self,
51        table: exp.Table | str,
52        only_visible: bool = False,
53        dialect: DialectType = None,
54        normalize: t.Optional[bool] = None,
55    ) -> t.List[str]:
56        """
57        Get the column names for a table.
58
59        Args:
60            table: the `Table` expression instance.
61            only_visible: whether to include invisible columns.
62            dialect: the SQL dialect that will be used to parse `table` if it's a string.
63            normalize: whether to normalize identifiers according to the dialect of interest.
64
65        Returns:
66            The list of column names.
67        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
69    @abc.abstractmethod
70    def get_column_type(
71        self,
72        table: exp.Table | str,
73        column: exp.Column | str,
74        dialect: DialectType = None,
75        normalize: t.Optional[bool] = None,
76    ) -> exp.DataType:
77        """
78        Get the `sqlglot.exp.DataType` type of a column in the schema.
79
80        Args:
81            table: the source table.
82            column: the target column.
83            dialect: the SQL dialect that will be used to parse `table` if it's a string.
84            normalize: whether to normalize identifiers according to the dialect of interest.
85
86        Returns:
87            The resulting column type.
88        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 90    def has_column(
 91        self,
 92        table: exp.Table | str,
 93        column: exp.Column | str,
 94        dialect: DialectType = None,
 95        normalize: t.Optional[bool] = None,
 96    ) -> bool:
 97        """
 98        Returns whether or not `column` appears in `table`'s schema.
 99
100        Args:
101            table: the source table.
102            column: the target column.
103            dialect: the SQL dialect that will be used to parse `table` if it's a string.
104            normalize: whether to normalize identifiers according to the dialect of interest.
105
106        Returns:
107            True if the column appears in the schema, False otherwise.
108        """
109        name = column if isinstance(column, str) else column.name
110        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema:
class AbstractMappingSchema:
    """Maps table paths to values via a nested mapping, with a trie for prefix lookups."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table arguments this schema supports, e.g. ("this", "db", "catalog")."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the table's textual path components, most specific first."""
        if isinstance(table.this, exp.ReadCSV):
            # READ_CSV(...) sources are keyed by the csv name alone
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Look up the value stored for `table`; None (or SchemaError) when missing/ambiguous."""
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        result, node = in_trie(self.mapping_trie if trie is None else trie, parts)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # The path only prefixes entries; it's unambiguous iff exactly one completion exists.
            completions = flatten_schema(node, depth=dict_depth(node) - 1)

            if len(completions) != 1:
                if raise_on_missing:
                    candidates = ", ".join(".".join(c) for c in completions)
                    raise SchemaError(f"Ambiguous mapping for {table}: {candidates}.")
                return None

            parts.extend(completions[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the value at `parts` (most specific first) from `d` or the stored mapping."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
126    def __init__(
127        self,
128        mapping: t.Optional[t.Dict] = None,
129    ) -> None:
130        self.mapping = mapping or {}
131        self.mapping_trie = new_trie(
132            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
133        )
134        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
def depth(self) -> int:
140    def depth(self) -> int:
141        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
157    def table_parts(self, table: exp.Table) -> t.List[str]:
158        if isinstance(table.this, exp.ReadCSV):
159            return [table.this.name]
160        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[Any]:
162    def find(
163        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
164    ) -> t.Optional[t.Any]:
165        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
166        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
167
168        if value == TrieResult.FAILED:
169            return None
170
171        if value == TrieResult.PREFIX:
172            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
173
174            if len(possibilities) == 1:
175                parts.extend(possibilities[0])
176            else:
177                message = ", ".join(".".join(parts) for parts in possibilities)
178                if raise_on_missing:
179                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
180                return None
181
182        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
184    def nested_get(
185        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
186    ) -> t.Optional[t.Any]:
187        return nested_get(
188            d or self.mapping,
189            *zip(self.supported_table_args, reversed(parts)),
190            raise_on_missing=raise_on_missing,
191        )
class MappingSchema(AbstractMappingSchema, Schema):
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of raw type strings -> parsed exp.DataType, populated by _to_data_type
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by depth(); 0 means "not yet computed"
        self._depth = 0

        # Normalize every identifier in the schema before handing it to the base class
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` with the same contents and settings as `mapping_schema`."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a copy of this schema; `kwargs` override the constructor arguments.

        NOTE: `dict.copy` is shallow, so nested levels of the schema/visible
        mappings are shared with the original instance.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # No-op when the table is already registered and no new columns were provided
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Parts are reversed so the mapping is keyed outermost-first (catalog -> db -> table);
        # the trie is updated in place to include the new path.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to restrict the result to visible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names, or an empty list if the table isn't found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Filter through the parallel `visible` mapping, preserving schema order
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `exp.DataType` of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type; the UNKNOWN type if the table or column isn't found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            # Types may be stored pre-built or as raw strings that still need parsing
            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 excludes the innermost {column: type} level from the flattened key paths
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # nested_get expects (name, key) pairs; reuse each key as its own error label
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` into an `exp.Table` copy and normalize each of its part identifiers."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, defaulting to this schema's dialect/normalize settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        """Nesting depth of the schema, excluding the innermost {column: type} level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        # Parsed types are memoized per raw string, shared across all lookups
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                # NOTE(review): only AttributeError is mapped to SchemaError here; other
                # parse failures propagate unchanged — confirm this is intended.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}}
    2. {db: {table: set(*cols)}}}
    3. {catalog: {db: {table: set(*cols)}}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
213    def __init__(
214        self,
215        schema: t.Optional[t.Dict] = None,
216        visible: t.Optional[t.Dict] = None,
217        dialect: DialectType = None,
218        normalize: bool = True,
219    ) -> None:
220        self.dialect = dialect
221        self.visible = visible or {}
222        self.normalize = normalize
223        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
224        self._depth = 0
225
226        super().__init__(self._normalize(schema or {}))
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
228    @classmethod
229    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
230        return MappingSchema(
231            schema=mapping_schema.mapping,
232            visible=mapping_schema.visible,
233            dialect=mapping_schema.dialect,
234            normalize=mapping_schema.normalize,
235        )
def copy(self, **kwargs) -> MappingSchema:
237    def copy(self, **kwargs) -> MappingSchema:
238        return MappingSchema(
239            **{  # type: ignore
240                "schema": self.mapping.copy(),
241                "visible": self.visible.copy(),
242                "dialect": self.dialect,
243                "normalize": self.normalize,
244                **kwargs,
245            }
246        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
248    def add_table(
249        self,
250        table: exp.Table | str,
251        column_mapping: t.Optional[ColumnMapping] = None,
252        dialect: DialectType = None,
253        normalize: t.Optional[bool] = None,
254        match_depth: bool = True,
255    ) -> None:
256        """
257        Register or update a table. Updates are only performed if a new column mapping is provided.
258        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
259
260        Args:
261            table: the `Table` expression instance or string representing the table.
262            column_mapping: a column mapping that describes the structure of the table.
263            dialect: the SQL dialect that will be used to parse `table` if it's a string.
264            normalize: whether to normalize identifiers according to the dialect of interest.
265            match_depth: whether to enforce that the table must match the schema's depth or not.
266        """
267        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
268
269        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
270            raise SchemaError(
271                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
272                f"schema's nesting level: {self.depth()}."
273            )
274
275        normalized_column_mapping = {
276            self._normalize_name(key, dialect=dialect, normalize=normalize): value
277            for key, value in ensure_column_mapping(column_mapping).items()
278        }
279
280        schema = self.find(normalized_table, raise_on_missing=False)
281        if schema and not normalized_column_mapping:
282            return
283
284        parts = self.table_parts(normalized_table)
285
286        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
287        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
289    def column_names(
290        self,
291        table: exp.Table | str,
292        only_visible: bool = False,
293        dialect: DialectType = None,
294        normalize: t.Optional[bool] = None,
295    ) -> t.List[str]:
296        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
297
298        schema = self.find(normalized_table)
299        if schema is None:
300            return []
301
302        if not only_visible or not self.visible:
303            return list(schema)
304
305        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
306        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
308    def get_column_type(
309        self,
310        table: exp.Table | str,
311        column: exp.Column | str,
312        dialect: DialectType = None,
313        normalize: t.Optional[bool] = None,
314    ) -> exp.DataType:
315        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
316
317        normalized_column_name = self._normalize_name(
318            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
319        )
320
321        table_schema = self.find(normalized_table, raise_on_missing=False)
322        if table_schema:
323            column_type = table_schema.get(normalized_column_name)
324
325            if isinstance(column_type, exp.DataType):
326                return column_type
327            elif isinstance(column_type, str):
328                return self._to_data_type(column_type, dialect=dialect)
329
330        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
332    def has_column(
333        self,
334        table: exp.Table | str,
335        column: exp.Column | str,
336        dialect: DialectType = None,
337        normalize: t.Optional[bool] = None,
338    ) -> bool:
339        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
340
341        normalized_column_name = self._normalize_name(
342            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
343        )
344
345        table_schema = self.find(normalized_table, raise_on_missing=False)
346        return normalized_column_name in table_schema if table_schema else False

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
419    def depth(self) -> int:
420        if not self.empty and not self._depth:
421            # The columns themselves are a mapping, but we don't want to include those
422            self._depth = super().depth() - 1
423        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> str:
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """Return `identifier`'s name, normalized per `dialect` unless `normalize` is falsy."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident.name

    # This can be useful for normalize_identifier
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident).name
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema`: pass instances through, wrap anything else in a MappingSchema."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce any supported column-mapping representation into a {column: type} dict.

    Accepts:
        - None: returns an empty mapping.
        - dict: returned unchanged.
        - str: comma separated "name: type" pairs, e.g. "a: INT, b: TEXT". Commas
          nested inside (), [] or <> belong to the type, so "price: DECIMAL(10, 2)"
          parses correctly (the original split on every comma and mangled such types).
          Only the first ":" separates name from type, so types containing ":" survive.
        - pyspark StructType (detected via its `simpleString` attribute).
        - list of column names: each column's type is set to None.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry is
            missing the ":" separator.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        for col_def in _split_col_defs(mapping):
            name, sep, data_type = col_def.partition(":")
            if not sep:
                raise ValueError(f"Invalid column definition: {col_def.strip()!r}")
            column_mapping[name.strip()] = data_type.strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def _split_col_defs(mapping: str) -> t.List[str]:
    """Split "name: type" definitions on commas that aren't nested inside (), [] or <>."""
    col_defs = []
    depth = 0
    start = 0
    for i, char in enumerate(mapping):
        if char in "([<":
            depth += 1
        elif char in ")]>":
            depth -= 1
        elif char == "," and depth == 0:
            col_defs.append(mapping[start:i])
            start = i + 1
    col_defs.append(mapping[start:])
    return col_defs
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Collect every key path of length `depth` through the nested `schema` mapping."""
    prefix = keys or []

    if depth == 1:
        return [prefix + [key] for key in schema]

    paths: t.List[t.List[str]] = []
    if depth >= 2:
        for key, subschema in schema.items():
            paths.extend(flatten_schema(subschema, depth - 1, prefix + [key]))
    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal arg name for the table itself; report it as "table"
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (creating empty dicts as needed) down to the parent of the final key
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.