Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15
 16def qualify_columns(
 17    expression: exp.Expression,
 18    schema: t.Dict | Schema,
 19    expand_alias_refs: bool = True,
 20    expand_stars: bool = True,
 21    infer_schema: t.Optional[bool] = None,
 22) -> exp.Expression:
 23    """
 24    Rewrite sqlglot AST to have fully qualified columns.
 25
 26    Example:
 27        >>> import sqlglot
 28        >>> schema = {"tbl": {"col": "INT"}}
 29        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 30        >>> qualify_columns(expression, schema).sql()
 31        'SELECT tbl.col AS col FROM tbl'
 32
 33    Args:
 34        expression: Expression to qualify.
 35        schema: Database schema.
 36        expand_alias_refs: Whether or not to expand references to aliases.
 37        expand_stars: Whether or not to expand star queries. This is a necessary step
 38            for most of the optimizer's rules to work; do not set to False unless you
 39            know what you're doing!
 40        infer_schema: Whether or not to infer the schema if missing.
 41
 42    Returns:
 43        The qualified expression.
 44    """
 45    schema = ensure_schema(schema)
 46    infer_schema = schema.empty if infer_schema is None else infer_schema
 47    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 48
 49    for scope in traverse_scope(expression):
 50        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 51        _pop_table_column_aliases(scope.ctes)
 52        _pop_table_column_aliases(scope.derived_tables)
 53        using_column_tables = _expand_using(scope, resolver)
 54
 55        if schema.empty and expand_alias_refs:
 56            _expand_alias_refs(scope, resolver)
 57
 58        _qualify_columns(scope, resolver)
 59
 60        if not schema.empty and expand_alias_refs:
 61            _expand_alias_refs(scope, resolver)
 62
 63        if not isinstance(scope.expression, exp.UDTF):
 64            if expand_stars:
 65                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 66            qualify_outputs(scope)
 67
 68        _expand_group_by(scope)
 69        _expand_order_by(scope, resolver)
 70
 71    return expression
 72
 73
 74def validate_qualify_columns(expression: E) -> E:
 75    """Raise an `OptimizeError` if any columns aren't qualified"""
 76    unqualified_columns = []
 77    for scope in traverse_scope(expression):
 78        if isinstance(scope.expression, exp.Select):
 79            unqualified_columns.extend(scope.unqualified_columns)
 80            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 81                column = scope.external_columns[0]
 82                raise OptimizeError(
 83                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 84                )
 85
 86    if unqualified_columns:
 87        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 88    return expression
 89
 90
 91def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 92    """
 93    Remove table column aliases.
 94
 95    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 96    """
 97    for derived_table in derived_tables:
 98        table_alias = derived_table.args.get("alias")
 99        if table_alias:
100            table_alias.args.pop("columns", None)
101
102
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` in `scope` into an equivalent `JOIN ... ON ...`.

    For each USING column, an equality condition is built between the already-joined
    source that first provides that column (in selection order) and the joined table.
    Afterwards, unqualified references to USING columns in the scope are replaced with a
    COALESCE over all the tables that share the column (aliased with the column name
    when the reference is a direct projection of a SELECT).

    Args:
        scope: The scope whose joins are rewritten in place.
        resolver: Used to look up the column names of each source.

    Returns:
        Mapping of each USING column name to an insertion-ordered dict whose keys are the
        names of the tables sharing that column (the values are always None).

    Raises:
        OptimizeError: if a USING column cannot be found on both sides of a join while the
            column lists of both sides are known and no already-joined source exposes "*".
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not joins (the FROM sources), in selection order. Each join's table
    # is appended once it has been processed, so later USING clauses can see its columns.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # Column name -> first already-joined source (in selection order) providing it.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        # Left-hand table used when a USING column's owner cannot be determined: the most
        # recently joined (or last FROM) source.
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when the column lists are known on both sides; otherwise the
                # column may live in a source whose schema is unknown.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Unqualified references to USING columns become COALESCE(t1.col, t2.col, ...).
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
174
175
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to SELECT projection aliases, in place.

    Applies only when the scope's expression is a SELECT. References to an alias are
    replaced by the (parenthesized, then simplified) aliased expression inside later
    projections and in WHERE, GROUP BY, HAVING and QUALIFY. Afterwards the scope's
    cache is cleared because the AST has changed.

    Args:
        scope: The scope to rewrite.
        resolver: Used (for HAVING and QUALIFY) to find the source table of unqualified columns.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of its projection).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        """
        Expand alias references among the columns of `node` that belong to the current scope.

        Args:
            node: The clause to rewrite; nothing happens when it is falsy.
            resolve_table: If True, an unqualified column that does not refer to an alias (or
                whose expansion would nest aggregates) is qualified with the table the resolver
                finds for it, and references to literal aliases are left untouched.
            literal_index: If True, references to literal aliases are replaced with the
                1-based position of the aliased projection.
        """
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Expanding an aggregate alias inside another aggregate would nest aggregates,
            # so such references are not expanded.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    # Literal aliases become a positional index when literal_index is set;
                    # otherwise (resolve_table only) the reference is kept as is.
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    # Inside a projection, only aliases of earlier projections are expanded: each alias is
    # registered after its own projection has been rewritten.
    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)
    # The AST was modified above, so cached scope information must be recomputed.
    scope.clear_cache()
225
226
def _expand_group_by(scope: Scope) -> None:
    """
    Replace positional GROUP BY references (e.g. GROUP BY 1) with the projected expressions.

    References to literal projections are kept as positions. Does nothing when the
    scope's expression has no GROUP BY.
    """
    query = scope.expression
    group = query.args.get("group")
    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        query.set("group", group)
235
236
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Normalize the ORDER BY clause of the scope's query, in place.

    Positional references (e.g. ORDER BY 1) are rewritten to reference the alias of the
    matching projection, and unqualified columns inside aggregates are qualified. When the
    query also has a GROUP BY, ordering expressions that equal a projection's expression
    are replaced with a column reference to that projection's output name.
    """
    query = scope.expression
    order = query.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, (o.this for o in ordereds), alias=True)

    for ordered, new_expression in zip(ordereds, expanded):
        unqualified_agg_columns = (
            col
            for agg in ordered.find_all(exp.AggFunc)
            for col in agg.find_all(exp.Column)
            if not col.table
        )
        for col in unqualified_agg_columns:
            col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if not query.args.get("group"):
        return

    output_references = {s.this: exp.column(s.alias_or_name) for s in query.selects}

    for ordered in ordereds:
        node = ordered.this
        if node.is_int:
            replacement = exp.to_identifier(_select_by_pos(scope, node).alias)
        else:
            replacement = output_references.get(node, node)
        node.replace(replacement)
265
266
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer positional references against the scope's projections.

    Args:
        scope: Scope whose query's projections are referenced.
        expressions: The expressions to resolve; non-integer ones are returned as is.
        alias: If True, a positional reference becomes a column named after the referenced
            projection's alias; otherwise it becomes a copy of the aliased expression
            (unless that expression is a literal, in which case the position is kept).

    Returns:
        A new list with the resolved expressions, in input order.
    """

    def _resolve(node: exp.Expression) -> exp.Expression:
        if not node.is_int:
            return node

        select = _select_by_pos(scope, t.cast(exp.Literal, node))
        if alias:
            return exp.column(select.args["alias"].copy())

        projected = select.this
        return node if isinstance(projected, exp.Literal) else projected.copy()

    return [_resolve(node) for node in expressions]
288
289
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. the `2` in ORDER BY 2).

    Args:
        scope: Scope whose query's projections are indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The referenced projection, which must be an `exp.Alias`.

    Raises:
        OptimizeError: if the position is outside 1..len(projections).
    """
    position = int(node.this)
    selects = scope.expression.selects

    # Python indexing would silently accept 0 (mapping it to -1, the last projection), so
    # validate the range explicitly instead of relying on IndexError.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
295
296
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    For every column of `scope`:
      * a column qualified with a known source is checked against that source's columns;
      * an unqualified column is qualified with the table the resolver finds for it (or, in a
        pivoted scope and outside the PIVOT itself, with the pivot's alias);
      * a qualifier that is not a source (and not a correlated reference to a parent source)
        is treated as struct field access and rebuilt as nested `exp.Dot` nodes.
    Finally, unqualified columns inside PIVOT clauses are qualified when their name is a
    known source column.

    Raises:
        OptimizeError: if a qualified column is not among its source's known columns.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # Skip validation when the source's columns are unknown or include a star.
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the PIVOT refer to its output, so they are qualified with the
                # pivot's alias. Columns under the PIVOT refer to the pivoted source instead and
                # are handled by the loop over `scope.pivots` at the end of this function.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
346
347
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    `SELECT *` expands over every selected source and `SELECT t.*` over `t` only, honoring
    EXCEPT and REPLACE modifiers, skipping the dialect's pseudocolumns, emitting a single
    COALESCE for columns shared through USING joins, and projecting through the scope's
    first (non-unpivot) PIVOT when there is one. If any starred source has no known
    columns (or exposes "*"), the function returns early and leaves the projections as is.

    Args:
        scope: Scope whose query projections are rewritten.
        resolver: Used to look up each source's visible columns.
        using_column_tables: Output of `_expand_using`: USING column -> ordered tables.
        pseudocolumns: Upper-cased pseudocolumn names that must not be expanded.

    Raises:
        OptimizeError: if a starred table is not a source of the scope.
    """

    new_selections = []
    # NOTE(review): both maps are keyed by id() of the table-name strings collected in
    # `tables`; lookups rely on the same string objects being iterated below.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    # USING columns already emitted as COALESCE, so they appear only once.
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        # Source columns consumed by the PIVOT; these are not projected implicitly.
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if columns and "*" not in columns:
                table_id = id(table)
                columns_to_exclude = except_columns.get(table_id) or set()

                if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                    implicit_columns = [col for col in columns if col not in pivot_columns]
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                        if name not in columns_to_exclude
                    )
                    continue

                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        # NOTE: this rebinds `tables`, but the enclosing `for table in tables`
                        # loop already holds its own iterator, so its iteration is unaffected.
                        tables = using_column_tables[name]
                        coalesce = [exp.column(name, table=table) for table in tables]

                        new_selections.append(
                            alias(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                                alias=name,
                                copy=False,
                            )
                        )
                    elif name not in columns_to_exclude:
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table=table)
                        new_selections.append(
                            alias(column, alias_, copy=False) if alias_ != name else column
                        )
            else:
                # Columns are unknown for this source: abandon the whole expansion.
                return

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
437
438
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record the column names of a star's EXCEPT modifier for each of `tables`.

    Entries are keyed by `id(table)` and all share one set object. Nothing is recorded
    when `expression` has no (or an empty) EXCEPT modifier.
    """
    excluded_nodes = expression.args.get("except")
    if excluded_nodes:
        excluded = {node.name for node in excluded_nodes}
        except_columns.update((id(table), excluded) for table in tables)
451
452
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record a star's REPLACE modifier (column name -> replacement alias) for each of `tables`.

    Entries are keyed by `id(table)` and all share one dict object. Nothing is recorded
    when `expression` has no (or an empty) REPLACE modifier.
    """
    replacements = expression.args.get("replace")
    if replacements:
        renamed = {node.this.name: node.alias for node in replacements}
        replace_columns.update((id(table), renamed) for table in tables)
465
466
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Unaliased, non-star projections get an alias (their output name, or `_col_<i>`);
    subqueries without an output name get a `_col_<i>` table alias. When the scope has an
    outer column list, its names override the projections' aliases. An expression
    argument is first turned into a scope; nothing happens if that yields no Scope.
    """
    if isinstance(scope_or_expression, exp.Expression):
        built = build_scope(scope_or_expression)
        if not isinstance(built, Scope):
            return
        scope = built
    else:
        scope = scope_or_expression

    def _aliased(index, selection, outer_name):
        fallback_name = f"_col_{index}"
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(selection, exp.Alias) or selection.is_star):
            selection = alias(selection, alias=selection.output_name or fallback_name)
        if outer_name:
            selection.set("alias", exp.to_identifier(outer_name))
        return selection

    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    scope.expression.set(
        "expressions",
        [_aliased(i, selection, outer_name) for i, (selection, outer_name) in enumerate(pairs)],
    )
494
495
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Quote, in place, the identifiers of `expression` that need quoting in `dialect`.

    `identify` is forwarded to the dialect's `quote_identifier` for every node visited.
    """
    quoter = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoter, identify=identify, copy=False)
501
502
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Projections beyond the number of alias columns are kept unchanged.

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(column_aliases, projections):
            if isinstance(projection, exp.Alias):
                # `alias_column_names` yields plain strings; store an Identifier so the alias
                # arg stays an expression (later rules call e.g. `.copy()` on it).
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence: keep projections that have no alias column
        # instead of silently dropping them from the query.
        new_expressions.extend(projections[len(column_aliases):])
        cte.this.set("expressions", new_expressions)

    return expression
533
534
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use by the methods below.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to the only source whose columns are unknown (empty or "*"), if
            # there is exactly one such source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias naming this source.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map each selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources, i.e. it
        is neither repeated within one source nor present in more than one source.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Every column seen so far, including ones duplicated within a single source, so
        # that later sources can detect clashes with them.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Any column of this source that was already seen is ambiguous, whether or not
            # it is unique within this source.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Each scope of `expression` is processed in turn: USING joins are turned into ON
    conditions, alias references are (optionally) expanded, columns are qualified with
    their source, stars are (optionally) expanded, outputs are aliased, and GROUP BY /
    ORDER BY references are normalized. The expression is modified in place and returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When None, schema
            inference is enabled exactly when `schema` is empty.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    # Alias references are expanded before column qualification when there is no
    # schema, and after it when there is one.
    expand_aliases_early = expand_alias_refs and schema.empty
    expand_aliases_late = expand_alias_refs and not schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        for derived in (scope.ctes, scope.derived_tables):
            _pop_table_column_aliases(derived)
        using_column_tables = _expand_using(scope, resolver)

        if expand_aliases_early:
            _expand_alias_refs(scope, resolver)
        _qualify_columns(scope, resolver)
        if expand_aliases_late:
            _expand_alias_refs(scope, resolver)

        # Table-valued function scopes keep their projections untouched.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

def validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Check that every column in every SELECT scope of `expression` is qualified.

    Raises:
        OptimizeError: if a SELECT scope references an external column while not being a
            correlated subquery and having no pivots, or if any unqualified (ambiguous)
            columns remain.

    Returns:
        `expression`, unchanged.
    """
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified_columns.extend(scope.unqualified_columns)

        unresolvable = (
            scope.external_columns and not scope.is_correlated_subquery and not scope.pivots
        )
        if unresolvable:
            column = scope.external_columns[0]
            table_suffix = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{table_suffix}")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased.

    The scope's projection list is rewritten in place:

    * subqueries without an output name get the table alias `_col_{i}`;
    * other projections that are neither aliases nor stars are wrapped in an
      alias named after their output name, or `_col_{i}` if they have none;
    * when the scope has an outer column list, its names override the aliases
      positionally.

    Args:
        scope_or_expression: A Scope, or an Expression to build one from with
            `build_scope`. Nothing happens if no Scope can be built.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None once the projections run out. This happens when the
        # outer column list has more names than there are projections, and there is
        # nothing left to alias.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression whose identifiers are processed. It is
            transformed without copying (`copy=False`).
        dialect: The dialect whose `quote_identifier` rule is applied to every node.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
        Projections beyond the number of CTE column names are kept unchanged.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier rather than a bare string, matching what `alias()`
                # produces in the other branch, so names needing quotes get quoted.
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter list. Keep any projections that have no
        # corresponding CTE column name instead of silently dropping them.
        new_expressions.extend(projections[len(alias_names) :])
        cte.this.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches (see _get_all_source_columns, get_table and all_columns).
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the unambiguous columns of all sources. If
        that fails and schema inference is enabled, it is attributed to the only source
        whose columns are unknown (empty, or containing "*"), provided there is exactly
        one such source.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name, as an identifier, if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is selected under.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can climb past the root without finding the alias. In that case
        # fall back to the source name instead of dereferencing None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """Resolve the source columns for a given source `name`.

        Args:
            name: The source name, as registered in the scope's sources.
            only_visible: Forwarded to `Schema.column_names` for table sources.

        Returns:
            The source's column names, with any column aliases declared on the
            selected source replacing the corresponding names positionally.

        Raises:
            OptimizeError: if `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column
        # names. Distinct loop names avoid shadowing `name` and the module-level `alias`.
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.
        Repeats within one source or across sources both make it ambiguous.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        # Counting across all sources at once catches ambiguity in every combination:
        # a column repeated inside a later source that also exists in an earlier one,
        # and a column repeated inside the first source that reappears later.
        counts: t.Dict[str, int] = {}
        owners: t.Dict[str, str] = {}
        for source_name, columns in source_columns.items():
            for column in columns:
                counts[column] = counts.get(column, 0) + 1
                owners[column] = source_name

        return {column: owners[column] for column, count in counts.items() if count == 1}

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
543    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
544        self.scope = scope
545        self.schema = schema
546        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
547        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
548        self._all_columns: t.Optional[t.Set[str]] = None
549        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
551    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
552        """
553        Get the table for a column name.
554
555        Args:
556            column_name: The column name to find the table for.
557        Returns:
558            The table name if it can be found/inferred.
559        """
560        if self._unambiguous_columns is None:
561            self._unambiguous_columns = self._get_unambiguous_columns(
562                self._get_all_source_columns()
563            )
564
565        table_name = self._unambiguous_columns.get(column_name)
566
567        if not table_name and self._infer_schema:
568            sources_without_schema = tuple(
569                source
570                for source, columns in self._get_all_source_columns().items()
571                if not columns or "*" in columns
572            )
573            if len(sources_without_schema) == 1:
574                table_name = sources_without_schema[0]
575
576        if table_name not in self.scope.selected_sources:
577            return exp.to_identifier(table_name)
578
579        node, _ = self.scope.selected_sources.get(table_name)
580
581        if isinstance(node, exp.Subqueryable):
582            while node and node.alias != table_name:
583                node = node.parent
584
585        node_alias = node.args.get("alias")
586        if node_alias:
587            return exp.to_identifier(node_alias.this)
588
589        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
591    @property
592    def all_columns(self) -> t.Set[str]:
593        """All available columns of all sources in this scope"""
594        if self._all_columns is None:
595            self._all_columns = {
596                column for columns in self._get_all_source_columns().values() for column in columns
597            }
598        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
600    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
601        """Resolve the source columns for a given source `name`."""
602        if name not in self.scope.sources:
603            raise OptimizeError(f"Unknown table: {name}")
604
605        source = self.scope.sources[name]
606
607        if isinstance(source, exp.Table):
608            columns = self.schema.column_names(source, only_visible)
609        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
610            columns = source.expression.alias_column_names
611        else:
612            columns = source.expression.named_selects
613
614        node, _ = self.scope.selected_sources.get(name) or (None, None)
615        if isinstance(node, Scope):
616            column_aliases = node.expression.alias_column_names
617        elif isinstance(node, exp.Expression):
618            column_aliases = node.alias_column_names
619        else:
620            column_aliases = []
621
622        # If the source's columns are aliased, their aliases shadow the corresponding column names
623        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

Resolve the source columns for a given source name.