Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15
 16def qualify_columns(
 17    expression: exp.Expression,
 18    schema: t.Dict | Schema,
 19    expand_alias_refs: bool = True,
 20    infer_schema: t.Optional[bool] = None,
 21) -> exp.Expression:
 22    """
 23    Rewrite sqlglot AST to have fully qualified columns.
 24
 25    Example:
 26        >>> import sqlglot
 27        >>> schema = {"tbl": {"col": "INT"}}
 28        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 29        >>> qualify_columns(expression, schema).sql()
 30        'SELECT tbl.col AS col FROM tbl'
 31
 32    Args:
 33        expression: Expression to qualify.
 34        schema: Database schema.
 35        expand_alias_refs: Whether or not to expand references to aliases.
 36        infer_schema: Whether or not to infer the schema if missing.
 37
 38    Returns:
 39        The qualified expression.
 40    """
 41    schema = ensure_schema(schema)
 42    infer_schema = schema.empty if infer_schema is None else infer_schema
 43    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 44
 45    for scope in traverse_scope(expression):
 46        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 47        _pop_table_column_aliases(scope.ctes)
 48        _pop_table_column_aliases(scope.derived_tables)
 49        using_column_tables = _expand_using(scope, resolver)
 50
 51        if schema.empty and expand_alias_refs:
 52            _expand_alias_refs(scope, resolver)
 53
 54        _qualify_columns(scope, resolver)
 55
 56        if not schema.empty and expand_alias_refs:
 57            _expand_alias_refs(scope, resolver)
 58
 59        if not isinstance(scope.expression, exp.UDTF):
 60            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 61            qualify_outputs(scope)
 62
 63        _expand_group_by(scope)
 64        _expand_order_by(scope, resolver)
 65
 66    return expression
 67
 68
 69def validate_qualify_columns(expression: E) -> E:
 70    """Raise an `OptimizeError` if any columns aren't qualified"""
 71    unqualified_columns = []
 72    for scope in traverse_scope(expression):
 73        if isinstance(scope.expression, exp.Select):
 74            unqualified_columns.extend(scope.unqualified_columns)
 75            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 76                column = scope.external_columns[0]
 77                raise OptimizeError(
 78                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 79                )
 80
 81    if unqualified_columns:
 82        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 83    return expression
 84
 85
 86def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 87    """
 88    Remove table column aliases.
 89
 90    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 91    """
 92    for derived_table in derived_tables:
 93        table_alias = derived_table.args.get("alias")
 94        if table_alias:
 95            table_alias.args.pop("columns", None)
 96
 97
 98def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 99    joins = list(scope.find_all(exp.Join))
100    names = {join.alias_or_name for join in joins}
101    ordered = [key for key in scope.selected_sources if key not in names]
102
103    # Mapping of automatically joined column names to an ordered set of source names (dict).
104    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
105
106    for join in joins:
107        using = join.args.get("using")
108
109        if not using:
110            continue
111
112        join_table = join.alias_or_name
113
114        columns = {}
115
116        for source_name in scope.selected_sources:
117            if source_name in ordered:
118                for column_name in resolver.get_source_columns(source_name):
119                    if column_name not in columns:
120                        columns[column_name] = source_name
121
122        source_table = ordered[-1]
123        ordered.append(join_table)
124        join_columns = resolver.get_source_columns(join_table)
125        conditions = []
126
127        for identifier in using:
128            identifier = identifier.name
129            table = columns.get(identifier)
130
131            if not table or identifier not in join_columns:
132                if (columns and "*" not in columns) and join_columns:
133                    raise OptimizeError(f"Cannot automatically join: {identifier}")
134
135            table = table or source_table
136            conditions.append(
137                exp.condition(
138                    exp.EQ(
139                        this=exp.column(identifier, table=table),
140                        expression=exp.column(identifier, table=join_table),
141                    )
142                )
143            )
144
145            # Set all values in the dict to None, because we only care about the key ordering
146            tables = column_tables.setdefault(identifier, {})
147            if table not in tables:
148                tables[table] = None
149            if join_table not in tables:
150                tables[join_table] = None
151
152        join.args.pop("using")
153        join.set("on", exp.and_(*conditions, copy=False))
154
155    if column_tables:
156        for column in scope.columns:
157            if not column.table and column.name in column_tables:
158                tables = column_tables[column.name]
159                coalesce = [exp.column(column.name, table=table) for table in tables]
160                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
161
162                # Ensure selects keep their output name
163                if isinstance(column.parent, exp.Select):
164                    replacement = alias(replacement, alias=column.name, copy=False)
165
166                scope.replace(column, replacement)
167
168    return column_tables
169
170
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to projection aliases inside a SELECT scope.

    Within the projections only aliases defined by earlier projections are visible; in
    WHERE, GROUP BY, HAVING and QUALIFY all projection aliases are. A matching unqualified
    column is replaced by the parenthesized aliased expression (parentheses are simplified
    when possible), unless inlining would put an aggregate inside another aggregate.

    Special cases:
    - GROUP BY: a reference to a literal-valued alias becomes the 1-based projection position.
    - HAVING / QUALIFY: literal-valued aliases are not inlined, and unqualified columns
      that are not (inlinable) alias references are qualified with their source table.

    Non-SELECT scopes are left untouched. The scope's cache is cleared afterwards.
    """
    expression = scope.expression
    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, position = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Inlining an aggregate alias under another aggregate would nest aggregates.
                double_agg = alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)
            else:
                double_agg = False

            if table and (not alias_expr or double_agg):
                column.set("table", table)
                continue

            if column.table or not alias_expr or double_agg:
                continue

            if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                # Literal aliases are not inlined here; in GROUP BY they become positions.
                if literal_index:
                    column.replace(exp.Literal.number(position))
                continue

            parenthesized = column.replace(exp.paren(alias_expr))
            simplified = simplify_parens(parenthesized)
            if simplified is not parenthesized:
                parenthesized.replace(simplified)

    for position, projection in enumerate(expression.selects, start=1):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, position)

    clause_options = (
        ("where", {}),
        ("group", {"literal_index": True}),
        ("having", {"resolve_table": True}),
        ("qualify", {"resolve_table": True}),
    )
    for clause, options in clause_options:
        replace_columns(expression.args.get(clause), **options)

    scope.clear_cache()
220
221
def _expand_group_by(scope: Scope) -> None:
    """
    Replace positional GROUP BY references (e.g. `GROUP BY 1`) with the projected expressions.

    Positions that point at a literal projection are kept as positions. Scopes without a
    GROUP BY are left untouched.
    """
    expression = scope.expression
    group = expression.args.get("group")
    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        expression.set("group", group)
230
231
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand the ORDER BY clause of a scope.

    Positional references become columns named after the referenced projection's alias,
    and unqualified columns inside aggregate calls are qualified with their source table.
    When the query also has a GROUP BY, ORDER BY terms that repeat a projection's
    expression are replaced by a reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, (o.this for o in ordereds), alias=True)

    for ordered, new_expression in zip(ordereds, expanded):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if not scope.expression.args.get("group"):
        return

    output_refs = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

    for ordered in ordereds:
        term = ordered.this
        if term.is_int:
            replacement = exp.to_identifier(_select_by_pos(scope, term).alias)
        else:
            replacement = output_refs.get(term, term)
        term.replace(replacement)
260
261
262def _expand_positional_references(
263    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
264) -> t.List[exp.Expression]:
265    new_nodes: t.List[exp.Expression] = []
266    for node in expressions:
267        if node.is_int:
268            select = _select_by_pos(scope, t.cast(exp.Literal, node))
269
270            if alias:
271                new_nodes.append(exp.column(select.args["alias"].copy()))
272            else:
273                select = select.this
274
275                if isinstance(select, exp.Literal):
276                    new_nodes.append(node)
277                else:
278                    new_nodes.append(select.copy())
279        else:
280            new_nodes.append(node)
281
282    return new_nodes
283
284
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by the 1-based position literal `node`.

    Args:
        scope: The scope whose SELECT projections are indexed.
        node: An integer literal, e.g. the `2` in `ORDER BY 2`.

    Returns:
        The referenced projection, which must already be an `exp.Alias`.

    Raises:
        OptimizeError: If the position is outside `1..len(projections)`.
    """
    position = int(node.this)
    selects = scope.expression.selects

    # Positions are 1-based. Without this check, 0 or a negative number would silently
    # index from the end of the list (e.g. `GROUP BY 0` would resolve to the last column).
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
290
291
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    - Unqualified columns are qualified with the pivot's alias when the scope is pivoted
      and the column is outside the PIVOT; otherwise with the table found by the resolver.
    - Qualified columns whose table is a source are validated against that source's
      known columns.
    - Qualified columns whose table is not a source (and is not an outer reference of a
      correlated subquery) are treated as struct field accesses and rebuilt as `exp.Dot`
      chains rooted at a qualified column.
    - Finally, unqualified columns inside PIVOT clauses are qualified when resolvable.

    Raises:
        OptimizeError: If a column is not among its source's known columns.
    """
    for column in scope.columns:
        table_name = column.table
        column_name = column.name

        if not table_name:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Outside the PIVOT, columns refer to its output, so use the pivot's alias.
                # Columns under the PIVOT are resolved against the pivoted source instead.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            resolved = resolver.get_table(column_name)
            # resolved can be a '' because bigquery unnest has no table alias
            if resolved:
                column.set("table", resolved)
            continue

        if table_name in scope.sources:
            known_columns = resolver.get_source_columns(table_name)
            if known_columns and column_name not in known_columns and "*" not in known_columns:
                raise OptimizeError(f"Unknown column: {column_name}")
            continue

        # An outer table referenced from a correlated subquery is left untouched.
        if scope.parent and table_name in scope.parent.sources and scope.is_correlated_subquery:
            continue

        # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
        root, *parts = column.parts

        if root.name in scope.sources:
            # struct is already qualified, but we still need to change the AST representation
            struct_table = root
            root, *parts = parts
        else:
            struct_table = resolver.get_table(root.name)

        if struct_table:
            column.replace(exp.Dot.build([exp.column(root, table=struct_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if column.table or column.name not in resolver.all_columns:
                continue

            pivot_table = resolver.get_table(column.name)
            if pivot_table:
                column.set("table", pivot_table)
341
342
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    Both `*` and `<table>.*` projections are replaced by the columns of the matching
    sources. Each star's own EXCEPT / REPLACE modifiers are honored. JOIN ... USING
    columns are emitted once, as a COALESCE over the joined tables. A single, non-UNPIVOT
    PIVOT is handled by selecting its implicit and output columns through the pivot alias.
    Dialect pseudocolumns are skipped.

    If any starred source yields no known columns, or exposes `*`, the projections are
    left untouched.

    Args:
        scope: The scope whose projections are rewritten in place.
        resolver: Resolver used to look up source columns.
        using_column_tables: Output of `_expand_using` (USING column -> joined source names).
        pseudocolumns: Upper-cased pseudocolumn names to exclude from the expansion.

    Raises:
        OptimizeError: If a star refers to a table that is not a source of the scope.
    """
    new_selections = []
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.expression.selects:
        # EXCEPT / REPLACE only apply to the star they are attached to, so they are
        # collected afresh for every projection. A mapping shared across projections would
        # let e.g. the EXCEPT of `* EXCEPT (a)` also apply to a later plain `*`.
        except_columns: t.Dict[int, t.Set[str]] = {}
        replace_columns: t.Dict[int, t.Dict[str, str]] = {}

        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # This source's columns are unknown, so no star can be expanded safely.
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                implicit_columns = [col for col in columns if col not in pivot_columns]
                new_selections.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in implicit_columns + pivot_output_columns
                    if name not in columns_to_exclude
                )
                continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns appear once, as a COALESCE over every joined table.
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    coalesce_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=source) for source in coalesce_tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
432
433
434def _add_except_columns(
435    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
436) -> None:
437    except_ = expression.args.get("except")
438
439    if not except_:
440        return
441
442    columns = {e.name for e in except_}
443
444    for table in tables:
445        except_columns[id(table)] = columns
446
447
448def _add_replace_columns(
449    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
450) -> None:
451    replace = expression.args.get("replace")
452
453    if not replace:
454        return
455
456    columns = {e.this.name: e.alias for e in replace}
457
458    for table in tables:
459        replace_columns[id(table)] = columns
460
461
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Projections that are not aliases (and not stars) are aliased with their output name,
    or `_col_<i>` when they have none. Unnamed subquery projections get a `_col_<i>`
    table alias. When the scope has an outer column list, its names override the
    projection aliases positionally. Surplus outer names, beyond the number of
    projections, are ignored.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. If no
            scope can be built from the expression, nothing happens.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if selection is None:
            # More outer column names than projections: zip_longest pads with None, and
            # there is no projection left to alias (this used to raise AttributeError).
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
489
490
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The expression is transformed in place (`copy=False`) with the dialect's
    `quote_identifier`, which also receives `identify`.

    Returns:
        The transformed expression.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
496
497
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier (not a raw string) so names that need quoting stay quoted.
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter list; keep any projections beyond the alias list
        # instead of silently dropping them.
        new_expressions.extend(projections[len(new_expressions):])
        cte.this.set("expressions", new_expressions)

    return expression
528
529
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the unambiguous columns of the scope's
        selected and lateral sources. Failing that, and if schema inference is enabled, it
        is attributed to the only source whose columns are unknown (no columns, or `*`),
        provided exactly one such source exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries this source's alias.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run out of ancestors without finding the alias.
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source of the scope.
            only_visible: Forwarded to the schema lookup for table sources.

        Returns:
            The source's column names, with column aliases (if any) replacing them positionally.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Columns of every selected and lateral source, keyed by source name (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source exposes it, and exposes it exactly
        once. Names repeated within a source, or shared by several sources, are ambiguous.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        unambiguous_columns: t.Dict[str, str] = {}
        # Every name seen so far, including names duplicated within one source, so that a
        # later source exposing any of them makes them ambiguous.
        seen_columns: t.Set[str] = set()

        for table, columns in source_columns.items():
            unique = self._find_unique_columns(columns)
            clashes = seen_columns.intersection(columns)

            for column in clashes:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(seen_columns):
                unambiguous_columns[column] = table

            seen_columns.update(columns)

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope yielded by `traverse_scope` is processed in turn: table column aliases of
    CTEs and derived tables are dropped, JOIN ... USING is rewritten, alias references are
    expanded (before column qualification when the schema is empty, after it otherwise),
    columns are qualified, stars are expanded and outputs aliased (skipped for UDTF
    scopes), and finally GROUP BY and ORDER BY are expanded.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        # By default, only infer source columns when the schema is empty.
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    # Without a schema, alias references are expanded before qualification; with one,
    # they are expanded afterwards.
    expand_refs_before = expand_alias_refs and schema.empty
    expand_refs_after = expand_alias_refs and not schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if expand_refs_before:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_refs_after:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

def validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are checked. An external column in a scope that is neither a
    correlated subquery nor pivoted fails immediately; columns still lacking a table are
    collected across all scopes and reported together as ambiguous.

    Args:
        expression: The (already qualified) expression to validate.

    Returns:
        The same expression, unchanged.

    Raises:
        OptimizeError: If a column could not be resolved or is ambiguous.
    """
    unqualified_columns = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified_columns.extend(scope.unqualified_columns)

        if scope.external_columns and not (scope.is_correlated_subquery or scope.pivots):
            column = scope.external_columns[0]
            table_hint = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Projections that are not aliases (and not stars) are aliased with their output name,
    or `_col_<i>` when they have none. Unnamed subquery projections get a `_col_<i>`
    table alias. When the scope has an outer column list, its names override the
    projection aliases positionally. Surplus outer names, beyond the number of
    projections, are ignored.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. If no
            scope can be built from the expression, nothing happens.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if selection is None:
            # More outer column names than projections: zip_longest pads with None, and
            # there is no projection left to alias (this used to raise AttributeError).
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to process; it is transformed without copying
            (`copy=False`).
        dialect: The dialect whose `quote_identifier` rule is applied.
        identify: Forwarded unchanged to the dialect's `quote_identifier`.

    Returns:
        The result of transforming `expression`.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        new_expressions = []
        for i, projection in enumerate(cte.this.expressions):
            # Only the first len(column_aliases) projections are renamed. Any projections
            # beyond the declared aliases are kept as they are, rather than being dropped
            # from the CTE's output.
            if i < len(column_aliases):
                column_alias = column_aliases[i]
                if isinstance(projection, exp.Alias):
                    # Store an Identifier (as qualify_outputs does) instead of a bare string.
                    projection.set("alias", exp.to_identifier(column_alias))
                else:
                    projection = alias(projection, alias=column_alias)
            new_expressions.append(projection)
        cte.this.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches, filled lazily by _get_all_source_columns, get_table and all_columns.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        # When True, get_table may attribute an otherwise unresolved column to the single
        # source whose columns are unknown.
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the columns that occur exactly once across all
        sources of the scope. If that fails and schema inference is enabled, the column is
        attributed to the only source whose columns are unknown (no columns, or a "*"),
        provided there is exactly one such source.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        # Not a selected source of this scope (e.g. no table was found): return the name as-is.
        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        # For query nodes, climb to the ancestor whose alias is the source name so that its
        # alias node can be read below.
        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in this scope.
            only_visible: Forwarded to `Schema.column_names` when the source is a table.

        Returns:
            The source's column names. Any column aliases declared on the source replace
            the corresponding names.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (computed once, cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it occurs exactly once across all sources: once in a
        single source and in no other source.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column seen so far, including ones duplicated within a single source.
        # Otherwise a column repeated in one source and present once in another would be
        # mistaken for unambiguous.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Any column of this source already seen elsewhere is ambiguous, even when it is
            # duplicated within this source (and therefore absent from `unique`).
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    # The scope whose columns are being resolved and the schema used to look up table columns.
    self.scope = scope
    self.schema = schema
    # Whether get_table may fall back to the single source whose columns are unknown.
    self._infer_schema = infer_schema
    # Lazily populated caches (None until first computed).
    self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
    self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
    self._all_columns: t.Optional[t.Set[str]] = None
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        all_source_columns = self._get_all_source_columns()
        self._unambiguous_columns = self._get_unambiguous_columns(all_source_columns)

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # Fall back to the one source whose columns are unknown, if there is exactly one.
        candidates = [
            source_name
            for source_name, source_columns in self._get_all_source_columns().items()
            if not source_columns or "*" in source_columns
        ]
        if len(candidates) == 1:
            table_name = candidates[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _source = self.scope.selected_sources.get(table_name)

    # Climb from a query node to the ancestor carrying the matching alias.
    if isinstance(node, exp.Subqueryable):
        while node and node.alias != table_name:
            node = node.parent

    node_alias = node.args.get("alias")
    return exp.to_identifier(node_alias.this if node_alias else table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """Every column name exposed by any source in this scope (computed once, then cached)."""
    if self._all_columns is None:
        merged = set()
        for source_columns in self._get_all_source_columns().values():
            merged.update(source_columns)
        self._all_columns = merged
    return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: Name of a source in this scope.
        only_visible: Forwarded to `Schema.column_names` when the source is a table.

    Returns:
        The source's column names, where aliases declared on the source replace the
        corresponding names.

    Raises:
        OptimizeError: If `name` is not a source of this scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")
    source = self.scope.sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        columns = source.expression.named_selects

    node, _source = self.scope.selected_sources.get(name) or (None, None)
    column_aliases = []
    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names

    # Declared aliases shadow the underlying column names, position by position.
    return [
        column_alias or column_name
        for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
    ]

Resolve the source columns for a given source name.