Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 11from sqlglot.optimizer.simplify import simplify_parens
 12from sqlglot.schema import Schema, ensure_schema
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot._typing import E
 16
 17
 18def qualify_columns(
 19    expression: exp.Expression,
 20    schema: t.Dict | Schema,
 21    expand_alias_refs: bool = True,
 22    expand_stars: bool = True,
 23    infer_schema: t.Optional[bool] = None,
 24) -> exp.Expression:
 25    """
 26    Rewrite sqlglot AST to have fully qualified columns.
 27
 28    Example:
 29        >>> import sqlglot
 30        >>> schema = {"tbl": {"col": "INT"}}
 31        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 32        >>> qualify_columns(expression, schema).sql()
 33        'SELECT tbl.col AS col FROM tbl'
 34
 35    Args:
 36        expression: Expression to qualify.
 37        schema: Database schema.
 38        expand_alias_refs: Whether to expand references to aliases.
 39        expand_stars: Whether to expand star queries. This is a necessary step
 40            for most of the optimizer's rules to work; do not set to False unless you
 41            know what you're doing!
 42        infer_schema: Whether to infer the schema if missing.
 43
 44    Returns:
 45        The qualified expression.
 46
 47    Notes:
 48        - Currently only handles a single PIVOT or UNPIVOT operator
 49    """
 50    schema = ensure_schema(schema)
 51    infer_schema = schema.empty if infer_schema is None else infer_schema
 52    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 53
 54    for scope in traverse_scope(expression):
 55        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 56        _pop_table_column_aliases(scope.ctes)
 57        _pop_table_column_aliases(scope.derived_tables)
 58        using_column_tables = _expand_using(scope, resolver)
 59
 60        if schema.empty and expand_alias_refs:
 61            _expand_alias_refs(scope, resolver)
 62
 63        _qualify_columns(scope, resolver)
 64
 65        if not schema.empty and expand_alias_refs:
 66            _expand_alias_refs(scope, resolver)
 67
 68        if not isinstance(scope.expression, exp.UDTF):
 69            if expand_stars:
 70                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 71            qualify_outputs(scope)
 72
 73        _expand_group_by(scope)
 74        _expand_order_by(scope, resolver)
 75
 76    return expression
 77
 78
 79def validate_qualify_columns(expression: E) -> E:
 80    """Raise an `OptimizeError` if any columns aren't qualified"""
 81    all_unqualified_columns = []
 82    for scope in traverse_scope(expression):
 83        if isinstance(scope.expression, exp.Select):
 84            unqualified_columns = scope.unqualified_columns
 85
 86            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 87                column = scope.external_columns[0]
 88                for_table = f" for table: '{column.table}'" if column.table else ""
 89                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 90
 91            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 92                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 93                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 94                # this list here to ensure those in the former category will be excluded.
 95                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 96                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 97
 98            all_unqualified_columns.extend(unqualified_columns)
 99
100    if all_unqualified_columns:
101        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
102
103    return expression
104
105
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Collect the columns produced by an UNPIVOT.

    The result starts with the name column, included only when `field` is an `IN`
    whose subject is a column. It continues with every column found in the
    UNPIVOT's value expressions.

    Args:
        unpivot: The UNPIVOT node.

    Returns:
        An iterator over those columns. The value-expression part is evaluated lazily.
    """
    field = unpivot.args.get("field")
    has_name_column = isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    leading = [field.this] if has_name_column else []

    trailing = (
        column for value_expr in unpivot.expressions for column in value_expr.find_all(exp.Column)
    )
    return itertools.chain(leading, trailing)
114
115
116def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
117    """
118    Remove table column aliases.
119
120    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
121    """
122    for derived_table in derived_tables:
123        table_alias = derived_table.args.get("alias")
124        if table_alias:
125            table_alias.args.pop("columns", None)
126
127
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` in `scope` into an equivalent `JOIN ... ON ...`.

    Each USING column `c` becomes the condition `<left>.c = <join>.c`. `<left>` is the
    first already-joined source (in selection order) that exposes `c`. If no such
    source is known, the most recently joined source is used instead.

    Afterwards, unqualified references to USING columns are replaced by
    `COALESCE(<source>.c, ...)`. The replacement is aliased to `c` when the reference
    is a direct SELECT projection.

    Args:
        scope: The scope whose joins are rewritten in place.
        resolver: Resolver used to look up source columns.

    Returns:
        Mapping of each USING column name to an insertion-ordered dict whose keys are
        the names of the sources that column is joined across (values are all None).

    Raises:
        OptimizeError: If a USING join has no preceding source to join against. Also
            raised if a USING column is missing from either side; that check is only
            made when both sides' columns are known.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Non-join (FROM-side) sources in selection order. Each processed USING join is
    # appended, so later joins can refer back to it.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        if not ordered:
            # Every selected source shares its name with a join (e.g. an unaliased
            # self-join), so there is no left-hand side. Fail with a clear error
            # instead of an IndexError from `ordered[-1]`.
            raise OptimizeError(
                f"Cannot automatically join {join_table}: no source table precedes it"
            )

        # Map each column to the first already-joined source that exposes it.
        columns: t.Dict[str, str] = {}
        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    columns.setdefault(column_name, source_name)

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for using_column in using:
            name = using_column.name
            table = columns.get(name)

            if not table or name not in join_columns:
                # Only raise when both sides' columns are actually known.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {name}")

            table = table or source_table
            conditions.append(
                exp.column(name, table=table).eq(exp.column(name, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(name, {})
            tables.setdefault(table, None)
            tables.setdefault(join_table, None)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
194
195
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to SELECT-list aliases within a SELECT scope, in place.

    An unqualified column whose name matches a projection alias is replaced by the
    parenthesized aliased expression; the parentheses are simplified where possible.
    Replacement happens in later projections (each projection only sees the aliases
    of the projections before it) and in the WHERE, GROUP BY, HAVING and QUALIFY
    clauses. Non-SELECT scopes are left untouched.

    Args:
        scope: The scope to rewrite; its cache is cleared afterwards.
        resolver: Used to qualify unqualified columns in HAVING and QUALIFY.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based position of its projection)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        """
        Rewrite alias references inside `node`.

        Args:
            node: Subtree to process; a falsy value is ignored.
            resolve_table: Qualify unqualified columns via the resolver when they are not
                inlinable alias references. References to literal aliases are left as-is.
            literal_index: Replace references to literal aliases with the projection's
                1-based position (used for GROUP BY).
        """
        if not node:
            return

        # Subtrees rooted at a star are pruned from the walk.
        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Inlining an aggregate alias inside another aggregate would nest aggregates.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    # Projections are processed in order, and each alias is registered only after its
    # own projection has been rewritten.
    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # The AST was modified above, so cached scope information may be stale.
    scope.clear_cache()
246
247
def _expand_group_by(scope: Scope) -> None:
    """
    Replace positional references in GROUP BY (e.g. `GROUP BY 1`) in place.

    Each integer position becomes a copy of the corresponding projection's
    expression. Projections that are literals keep their position.

    Args:
        scope: Scope whose GROUP BY clause, if any, is rewritten.
    """
    select = scope.expression
    group_clause = select.args.get("group")
    if group_clause:
        expanded = _expand_positional_references(scope, group_clause.expressions)
        group_clause.set("expressions", expanded)
        select.set("group", group_clause)
256
257
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in ORDER BY and qualify columns inside its aggregates.

    Integer positions are replaced by a column referencing the corresponding
    projection's alias. If the query also has a GROUP BY, an ORDER BY term equal to a
    projection's underlying expression is replaced by a column named after that
    projection's output name.

    Args:
        scope: Scope whose ORDER BY is rewritten in place.
        resolver: Used to qualify unqualified columns found inside aggregate functions.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        # Unqualified columns inside aggregates get their table from the resolver.
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # Map each projection's underlying expression to a column named after its output.
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
286
287
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer literals in `expressions` against the scope's SELECT list.

    Args:
        scope: Scope whose projections the (1-based) positions index into.
        expressions: Expressions to expand. Non-integer ones are returned unchanged.
        alias: Controls what a position is replaced by.
            - If True, a column referencing a copy of the projection's alias.
            - Otherwise, a copy of the aliased expression. If that expression is
              itself a literal, the position is kept.

    Returns:
        A new list holding the expanded expressions in their original order.
    """

    def expand(node: exp.Expression) -> exp.Expression:
        if not node.is_int:
            return node

        projection = _select_by_pos(scope, t.cast(exp.Literal, node))
        if alias:
            return exp.column(projection.args["alias"].copy())

        target = projection.this
        return node if isinstance(target, exp.Literal) else target.copy()

    return [expand(node) for node in expressions]
309
310
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. `ORDER BY 2`).

    Args:
        scope: Scope whose SELECT list is indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The aliased projection at that position.

    Raises:
        OptimizeError: If the position is outside `1..len(selects)`.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # Positions are 1-based. Reject 0 and negative values explicitly: Python would
    # otherwise silently index from the end of the list (e.g. `ORDER BY 0` would
    # resolve to the last projection).
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
316
317
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    The scope's columns are handled as follows:
    - Qualified columns are checked against their source's known columns.
    - Unqualified columns get a table from the first pivot's alias or from the resolver.
    - Columns whose qualifier is not a usable source are treated as struct field
      accesses and rebuilt as `Dot` chains.

    Finally, unqualified columns inside the scope's pivots are qualified.

    Args:
        scope: Scope whose columns are rewritten in place.
        resolver: Resolver used to look up source columns and tables.

    Raises:
        OptimizeError: If a qualified column is missing from its source's known columns.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # Only validate when the source's columns are known and not a wildcard.
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                # (such columns skip this branch and fall through to the resolver below).
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # The qualifier is neither a source of this scope nor a parent source
            # referenced from a correlated subquery, so it is handled as a struct access.
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    # Unqualified columns inside pivots are qualified only if some source exposes them.
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
367
368
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    Both `*` and `<table>.*` projections are replaced by explicit columns, honoring
    their EXCEPT and REPLACE modifiers. Other rules:
    - Columns joined with USING are emitted once, as a COALESCE over the joined sources.
    - Names whose upper-cased form is in `pseudocolumns` are skipped.
    - When the scope has a pivot, columns are emitted as `<pivot alias>.<name>`.
      They are either the source's columns minus those consumed by the pivot, plus
      the pivot's output columns, or the pivot's alias column names.

    If a source's columns are unknown (empty even after falling back to the scope's
    outer column list, or containing "*"), the function returns early. In that case
    the projection list is left unchanged.

    Args:
        scope: Scope whose SELECT list is rewritten in place.
        resolver: Resolver used to look up source columns.
        using_column_tables: Mapping produced by `_expand_using`. It maps a USING
            column name to an ordered dict of the source names it is joined across.
        pseudocolumns: Dialect pseudocolumn names, compared against `name.upper()`
            (so presumably upper-case).

    Raises:
        OptimizeError: If a star refers to a table that is not a source of the scope.
    """

    new_selections = []
    # Star modifiers, keyed by id() of the table-name objects passed to the helpers below.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    # USING columns that have already been emitted as a COALESCE.
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    # For a pivot without explicit alias column names, work out the source columns it
    # consumes (pivot_exclude_columns) and the columns it produces (pivot_output_columns).
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Unknown columns: give up without touching the projection list.
            if not columns or "*" in columns:
                return

            # Same object as the one recorded by _add_except_columns/_add_replace_columns.
            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    # NOTE: this rebinds the outer `tables`, but the enclosing for-loop keeps
                    # iterating its original sequence. The comprehension's `table` is local
                    # to the comprehension.
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
471
472
473def _add_except_columns(
474    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
475) -> None:
476    except_ = expression.args.get("except")
477
478    if not except_:
479        return
480
481    columns = {e.name for e in except_}
482
483    for table in tables:
484        except_columns[id(table)] = columns
485
486
487def _add_replace_columns(
488    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
489) -> None:
490    replace = expression.args.get("replace")
491
492    if not replace:
493        return
494
495    columns = {e.this.name: e.alias for e in replace}
496
497    for table in tables:
498        replace_columns[id(table)] = columns
499
500
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Projections are handled positionally (index `i`):
    - A subquery without an output name gets the table alias `_col_<i>`.
    - Any other non-star, non-alias projection is wrapped in an alias. The alias is
      its output name, or `_col_<i>` when it has none.
    - If the scope has an outer column list, its names override the aliases.

    For SELECT expressions, the projection list is then replaced with the result.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built.
            If building yields no `Scope`, nothing is done.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    aliased = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    for i, (projection, outer_name) in enumerate(pairs):
        # More outer column names than projections: nothing left to alias.
        if projection is None:
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{i}",
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased)
533
534
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The expression is transformed in place (`copy=False`) using the dialect's
    `quote_identifier` callback.

    Args:
        expression: Expression whose identifiers are processed.
        dialect: Dialect whose quoting rules apply.
        identify: Forwarded to the dialect's `quote_identifier`. NOTE(review):
            presumably forces quoting of every identifier when True; confirm against
            `Dialect.quote_identifier`.

    Returns:
        The transformed expression.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
540
541
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Projections beyond the number of alias columns are kept unchanged.

    Args:
        expression: Expression to pushdown. CTE bodies are modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence. Keep projections that have no matching
        # alias column instead of silently dropping them from the CTE body.
        new_expressions.extend(projections[len(column_names):])
        cte.this.set("expressions", new_expressions)

    return expression
572
573
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose columns are resolved.
            schema: Schema used to look up the columns of table sources.
            infer_schema: Whether `get_table` may fall back to the single source whose
                columns are unknown.
        """
        self.scope = scope
        self.schema = schema
        # Lazily populated caches (None until first use).
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the columns that belong to exactly one
        source. Failing that, if schema inference is enabled, the source whose columns
        are unknown (empty or "*") is used, provided there is exactly one such source.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node carrying the alias the source was selected under.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the top of the tree without finding the alias.
        # Fall back to the plain source name instead of dereferencing None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in the scope.
            only_visible: Forwarded to the schema lookup for table sources.

        Returns:
            The source's column names. Any column aliases declared on the source
            replace the corresponding names positionally.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) the columns of every selected and lateral source, by name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name, for columns exposed by exactly one source.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in an earlier source is ambiguous and is dropped for good.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify. It is rewritten in place and also returned.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. When None, inference is
            enabled only if `schema` is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Drop column lists such as `AS foo(a, b)` from CTEs and derived tables.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        # Rewrite JOIN ... USING into JOIN ... ON, remembering the joined columns so
        # that star expansion can coalesce them later.
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before column qualification when the schema
        # is empty, and after it otherwise.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing. If None (the default), inference is enabled only when the schema is empty.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 80def validate_qualify_columns(expression: E) -> E:
 81    """Raise an `OptimizeError` if any columns aren't qualified"""
 82    all_unqualified_columns = []
 83    for scope in traverse_scope(expression):
 84        if isinstance(scope.expression, exp.Select):
 85            unqualified_columns = scope.unqualified_columns
 86
 87            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 88                column = scope.external_columns[0]
 89                for_table = f" for table: '{column.table}'" if column.table else ""
 90                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 91
 92            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 93                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 94                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 95                # this list here to ensure those in the former category will be excluded.
 96                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 97                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 98
 99            all_unqualified_columns.extend(unqualified_columns)
100
101    if all_unqualified_columns:
102        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
103
104    return expression

Raise an OptimizeError if any columns aren't qualified.

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
502def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
503    """Ensure all output columns are aliased"""
504    if isinstance(scope_or_expression, exp.Expression):
505        scope = build_scope(scope_or_expression)
506        if not isinstance(scope, Scope):
507            return
508    else:
509        scope = scope_or_expression
510
511    new_selections = []
512    for i, (selection, aliased_column) in enumerate(
513        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
514    ):
515        if selection is None:
516            break
517
518        if isinstance(selection, exp.Subquery):
519            if not selection.output_name:
520                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
521        elif not isinstance(selection, exp.Alias) and not selection.is_star:
522            selection = alias(
523                selection,
524                alias=selection.output_name or f"_col_{i}",
525                copy=False,
526            )
527        if aliased_column:
528            selection.set("alias", exp.to_identifier(aliased_column))
529
530        new_selections.append(selection)
531
532    if isinstance(scope.expression, exp.Select):
533        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased.

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
536def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
537    """Makes sure all identifiers that need to be quoted are quoted."""
538    return expression.transform(
539        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
540    )

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
543def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
544    """
545    Pushes down the CTE alias columns into the projection,
546
547    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
548
549    Example:
550        >>> import sqlglot
551        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
552        >>> pushdown_cte_alias_columns(expression).sql()
553        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
554
555    Args:
556        expression: Expression to pushdown.
557
558    Returns:
559        The expression with the CTE aliases pushed down into the projection.
560    """
561    for cte in expression.find_all(exp.CTE):
562        if cte.alias_column_names:
563            new_expressions = []
564            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
565                if isinstance(projection, exp.Alias):
566                    projection.set("alias", _alias)
567                else:
568                    projection = alias(projection, alias=_alias)
569                new_expressions.append(projection)
570            cte.this.set("expressions", new_expressions)
571
572    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
575class Resolver:
576    """
577    Helper for resolving columns.
578
579    This is a class so we can lazily load some things and easily share them across functions.
580    """
581
582    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
583        self.scope = scope
584        self.schema = schema
585        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
586        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
587        self._all_columns: t.Optional[t.Set[str]] = None
588        self._infer_schema = infer_schema
589
590    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
591        """
592        Get the table for a column name.
593
594        Args:
595            column_name: The column name to find the table for.
596        Returns:
597            The table name if it can be found/inferred.
598        """
599        if self._unambiguous_columns is None:
600            self._unambiguous_columns = self._get_unambiguous_columns(
601                self._get_all_source_columns()
602            )
603
604        table_name = self._unambiguous_columns.get(column_name)
605
606        if not table_name and self._infer_schema:
607            sources_without_schema = tuple(
608                source
609                for source, columns in self._get_all_source_columns().items()
610                if not columns or "*" in columns
611            )
612            if len(sources_without_schema) == 1:
613                table_name = sources_without_schema[0]
614
615        if table_name not in self.scope.selected_sources:
616            return exp.to_identifier(table_name)
617
618        node, _ = self.scope.selected_sources.get(table_name)
619
620        if isinstance(node, exp.Query):
621            while node and node.alias != table_name:
622                node = node.parent
623
624        node_alias = node.args.get("alias")
625        if node_alias:
626            return exp.to_identifier(node_alias.this)
627
628        return exp.to_identifier(table_name)
629
630    @property
631    def all_columns(self) -> t.Set[str]:
632        """All available columns of all sources in this scope"""
633        if self._all_columns is None:
634            self._all_columns = {
635                column for columns in self._get_all_source_columns().values() for column in columns
636            }
637        return self._all_columns
638
639    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
640        """Resolve the source columns for a given source `name`."""
641        if name not in self.scope.sources:
642            raise OptimizeError(f"Unknown table: {name}")
643
644        source = self.scope.sources[name]
645
646        if isinstance(source, exp.Table):
647            columns = self.schema.column_names(source, only_visible)
648        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
649            columns = source.expression.alias_column_names
650        else:
651            columns = source.expression.named_selects
652
653        node, _ = self.scope.selected_sources.get(name) or (None, None)
654        if isinstance(node, Scope):
655            column_aliases = node.expression.alias_column_names
656        elif isinstance(node, exp.Expression):
657            column_aliases = node.alias_column_names
658        else:
659            column_aliases = []
660
661        if column_aliases:
662            # If the source's columns are aliased, their aliases shadow the corresponding column names.
663            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
664            return [
665                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
666            ]
667        return columns
668
669    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
670        if self._source_columns is None:
671            self._source_columns = {
672                source_name: self.get_source_columns(source_name)
673                for source_name, source in itertools.chain(
674                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
675                )
676            }
677        return self._source_columns
678
679    def _get_unambiguous_columns(
680        self, source_columns: t.Dict[str, t.Sequence[str]]
681    ) -> t.Mapping[str, str]:
682        """
683        Find all the unambiguous columns in sources.
684
685        Args:
686            source_columns: Mapping of names to source columns.
687
688        Returns:
689            Mapping of column name to source name.
690        """
691        if not source_columns:
692            return {}
693
694        source_columns_pairs = list(source_columns.items())
695
696        first_table, first_columns = source_columns_pairs[0]
697
698        if len(source_columns_pairs) == 1:
699            # Performance optimization - avoid copying first_columns if there is only one table.
700            return SingleValuedMapping(first_columns, first_table)
701
702        unambiguous_columns = {col: first_table for col in first_columns}
703        all_columns = set(unambiguous_columns)
704
705        for table, columns in source_columns_pairs[1:]:
706            unique = set(columns)
707            ambiguous = all_columns.intersection(unique)
708            all_columns.update(columns)
709
710            for column in ambiguous:
711                unambiguous_columns.pop(column, None)
712            for column in unique.difference(ambiguous):
713                unambiguous_columns[column] = table
714
715        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
582    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
583        self.scope = scope
584        self.schema = schema
585        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
586        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
587        self._all_columns: t.Optional[t.Set[str]] = None
588        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
590    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
591        """
592        Get the table for a column name.
593
594        Args:
595            column_name: The column name to find the table for.
596        Returns:
597            The table name if it can be found/inferred.
598        """
599        if self._unambiguous_columns is None:
600            self._unambiguous_columns = self._get_unambiguous_columns(
601                self._get_all_source_columns()
602            )
603
604        table_name = self._unambiguous_columns.get(column_name)
605
606        if not table_name and self._infer_schema:
607            sources_without_schema = tuple(
608                source
609                for source, columns in self._get_all_source_columns().items()
610                if not columns or "*" in columns
611            )
612            if len(sources_without_schema) == 1:
613                table_name = sources_without_schema[0]
614
615        if table_name not in self.scope.selected_sources:
616            return exp.to_identifier(table_name)
617
618        node, _ = self.scope.selected_sources.get(table_name)
619
620        if isinstance(node, exp.Query):
621            while node and node.alias != table_name:
622                node = node.parent
623
624        node_alias = node.args.get("alias")
625        if node_alias:
626            return exp.to_identifier(node_alias.this)
627
628        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table identifier, if the table can be found or inferred.

all_columns: Set[str]
630    @property
631    def all_columns(self) -> t.Set[str]:
632        """All available columns of all sources in this scope"""
633        if self._all_columns is None:
634            self._all_columns = {
635                column for columns in self._get_all_source_columns().values() for column in columns
636            }
637        return self._all_columns

All available columns of all sources in this scope.

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
639    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
640        """Resolve the source columns for a given source `name`."""
641        if name not in self.scope.sources:
642            raise OptimizeError(f"Unknown table: {name}")
643
644        source = self.scope.sources[name]
645
646        if isinstance(source, exp.Table):
647            columns = self.schema.column_names(source, only_visible)
648        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
649            columns = source.expression.alias_column_names
650        else:
651            columns = source.expression.named_selects
652
653        node, _ = self.scope.selected_sources.get(name) or (None, None)
654        if isinstance(node, Scope):
655            column_aliases = node.expression.alias_column_names
656        elif isinstance(node, exp.Expression):
657            column_aliases = node.alias_column_names
658        else:
659            column_aliases = []
660
661        if column_aliases:
662            # If the source's columns are aliased, their aliases shadow the corresponding column names.
663            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
664            return [
665                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
666            ]
667        return columns

Resolve the source columns for a given source name.