Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 11from sqlglot.optimizer.simplify import simplify_parens
 12from sqlglot.schema import Schema, ensure_schema
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot._typing import E
 16
 17
 18def qualify_columns(
 19    expression: exp.Expression,
 20    schema: t.Dict | Schema,
 21    expand_alias_refs: bool = True,
 22    expand_stars: bool = True,
 23    infer_schema: t.Optional[bool] = None,
 24) -> exp.Expression:
 25    """
 26    Rewrite sqlglot AST to have fully qualified columns.
 27
 28    Example:
 29        >>> import sqlglot
 30        >>> schema = {"tbl": {"col": "INT"}}
 31        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 32        >>> qualify_columns(expression, schema).sql()
 33        'SELECT tbl.col AS col FROM tbl'
 34
 35    Args:
 36        expression: Expression to qualify.
 37        schema: Database schema.
 38        expand_alias_refs: Whether to expand references to aliases.
 39        expand_stars: Whether to expand star queries. This is a necessary step
 40            for most of the optimizer's rules to work; do not set to False unless you
 41            know what you're doing!
 42        infer_schema: Whether to infer the schema if missing.
 43
 44    Returns:
 45        The qualified expression.
 46
 47    Notes:
 48        - Currently only handles a single PIVOT or UNPIVOT operator
 49    """
 50    schema = ensure_schema(schema)
 51    infer_schema = schema.empty if infer_schema is None else infer_schema
 52    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 53
 54    for scope in traverse_scope(expression):
 55        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 56        _pop_table_column_aliases(scope.ctes)
 57        _pop_table_column_aliases(scope.derived_tables)
 58        using_column_tables = _expand_using(scope, resolver)
 59
 60        if schema.empty and expand_alias_refs:
 61            _expand_alias_refs(scope, resolver)
 62
 63        _qualify_columns(scope, resolver)
 64
 65        if not schema.empty and expand_alias_refs:
 66            _expand_alias_refs(scope, resolver)
 67
 68        if not isinstance(scope.expression, exp.UDTF):
 69            if expand_stars:
 70                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 71            qualify_outputs(scope)
 72
 73        _expand_group_by(scope)
 74        _expand_order_by(scope, resolver)
 75
 76    return expression
 77
 78
 79def validate_qualify_columns(expression: E) -> E:
 80    """Raise an `OptimizeError` if any columns aren't qualified"""
 81    all_unqualified_columns = []
 82    for scope in traverse_scope(expression):
 83        if isinstance(scope.expression, exp.Select):
 84            unqualified_columns = scope.unqualified_columns
 85
 86            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 87                column = scope.external_columns[0]
 88                for_table = f" for table: '{column.table}'" if column.table else ""
 89                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 90
 91            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 92                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 93                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 94                # this list here to ensure those in the former category will be excluded.
 95                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 96                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 97
 98            all_unqualified_columns.extend(unqualified_columns)
 99
100    if all_unqualified_columns:
101        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
102
103    return expression
104
105
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Return the columns introduced by an UNPIVOT.

    The name column (the column of the `field` IN expression, if it is a plain column) comes
    first, followed lazily by every column found in the UNPIVOT's value expressions.
    """
    field = unpivot.args.get("field")
    has_name_column = isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    name_columns = [field.this] if has_name_column else []

    value_columns = (
        column for value_expr in unpivot.expressions for column in value_expr.find_all(exp.Column)
    )
    return itertools.chain(name_columns, value_columns)
114
115
116def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
117    """
118    Remove table column aliases.
119
120    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
121    """
122    for derived_table in derived_tables:
123        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
124            continue
125        table_alias = derived_table.args.get("alias")
126        if table_alias:
127            table_alias.args.pop("columns", None)
128
129
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` clauses of `scope` into explicit `ON` equality conditions.

    Each USING column is paired with the first source (in `scope.selected_sources` order) that
    precedes the join and exposes that column; if none is found, the most recently added
    preceding source is used instead. Afterwards, every unqualified reference to a USING column
    in the scope is replaced by a COALESCE over all sources that were joined on it (aliased to
    the column name when it is a direct projection).

    Args:
        scope: The scope whose joins are rewritten (mutated in place).
        resolver: Used to look up the column names exposed by each source.

    Returns:
        Mapping of USING column name to a dict whose keys are the source names joined on that
        column, in insertion order (the values are always None).

    Raises:
        OptimizeError: If a USING column is missing from one side of the join while the
            columns of both sides are known.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources not introduced by a JOIN (presumably the FROM source). Each processed join is
    # appended below, so `ordered` always holds the sources to the left of the current join.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # Column name -> first left-hand source (in selected_sources order) that exposes it.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        # Fallback left-hand source when a USING column can't be located on the left.
        # NOTE(review): raises IndexError if `ordered` is empty — confirm that can't happen.
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when both sides' column lists are known (left side non-empty and
                # without "*", right side non-empty); otherwise fall back to `source_table`.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
196
197
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to projection aliases within a SELECT scope, in place.

    An unqualified column whose name matches the alias of an *earlier* projection is replaced
    by that projection's expression, wrapped in parentheses that are dropped again when
    `simplify_parens` returns a different node. This is applied to the projections themselves
    and to the WHERE, GROUP BY, HAVING and QUALIFY clauses, with these clause-specific tweaks:

    - GROUP BY: a reference to the alias of a literal projection becomes that projection's
      1-based position.
    - HAVING / QUALIFY: references to aliases of literals are left untouched, and unqualified
      columns that are not expanded get a table qualifier from `resolver` when one is found.

    A reference is not expanded when the aliased expression contains an aggregate and the
    reference itself is inside an aggregate (the `double_agg` case).

    Non-SELECT scopes are ignored. The scope's cache is cleared at the end because nodes were
    replaced.

    Args:
        scope: The scope to rewrite.
        resolver: Used to find the source table of unqualified columns in HAVING / QUALIFY.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        # resolve_table: qualify non-expanded unqualified columns via `resolver` (HAVING/QUALIFY).
        # literal_index: replace references to literal aliases by their position (GROUP BY).
        if not node:
            return

        # Star nodes are pruned, so nothing under them is visited.
        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when both the aliased expression and this reference involve an aggregate.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    # GROUP BY: use the position; HAVING/QUALIFY: leave the reference as-is.
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    # Each projection is rewritten before its own alias is registered, so it can only refer
    # to aliases of projections that come before it.
    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Nodes were replaced above, so cached scope data may be stale.
    scope.clear_cache()
248
249
250def _expand_group_by(scope: Scope) -> None:
251    expression = scope.expression
252    group = expression.args.get("group")
253    if not group:
254        return
255
256    group.set("expressions", _expand_positional_references(scope, group.expressions))
257    expression.set("group", group)
258
259
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Normalize the ORDER BY clause of `scope.expression`, in place.

    - Positional references (e.g. `ORDER BY 1`) become column references named after the alias
      of the referenced projection.
    - Unqualified columns inside aggregate functions are qualified via `resolver`.
    - When the query also has a GROUP BY, each ORDER BY expression that matches (as a dict key)
      a projection's underlying expression is replaced by a column named after that
      projection's output name.

    Args:
        scope: The scope whose ORDER BY is rewritten.
        resolver: Used to find the table of unqualified columns inside aggregates.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # projection expression -> column referring to the projection's output name
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            # NOTE(review): the loop above already replaced every integer literal with a column
            # (alias=True), so the `is_int` branch looks unreachable here — confirm.
            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
288
289
290def _expand_positional_references(
291    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
292) -> t.List[exp.Expression]:
293    new_nodes: t.List[exp.Expression] = []
294    for node in expressions:
295        if node.is_int:
296            select = _select_by_pos(scope, t.cast(exp.Literal, node))
297
298            if alias:
299                new_nodes.append(exp.column(select.args["alias"].copy()))
300            else:
301                select = select.this
302
303                if isinstance(select, exp.Literal):
304                    new_nodes.append(node)
305                else:
306                    new_nodes.append(select.copy())
307        else:
308            new_nodes.append(node)
309
310    return new_nodes
311
312
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. the `2` in `GROUP BY 2`).

    Args:
        scope: The scope whose SELECT list is indexed.
        node: An integer literal holding the 1-based position.

    Returns:
        The referenced projection, which must be an `exp.Alias` (enforced via `assert_is`).

    Raises:
        OptimizeError: If the position is outside `1..len(selects)`.
    """
    position = int(node.this)
    selects = scope.expression.selects

    # Check the range explicitly: relying on IndexError alone let 0 (and negative values)
    # wrap around and silently select projections from the end of the list.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
318
319
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    For every column in `scope`:

    - If it is qualified with a source of this scope, check that the column exists in that
      source (only when the source's columns are known and don't include "*").
    - If it is unqualified, qualify it with the PIVOT alias when the scope has pivots and the
      column is not under the PIVOT itself; otherwise with the table `resolver` finds, if any.
    - If its qualifier is not a source of this scope (and it is not a correlated reference to a
      source of the parent scope), treat it as struct field access: its parts are reinterpreted
      as `<table>.<column>.<field>...` and the column is rebuilt as a `Dot` chain.

    Finally, unqualified columns under each PIVOT whose name is a column of some source in
    scope are qualified via `resolver`.

    Args:
        scope: The scope whose columns are qualified in place.
        resolver: Used to look up source columns and resolve a column's table.

    Raises:
        OptimizeError: If a qualified column is not among its source's known columns.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns under the Pivot expression need to be qualified using the name of the
                # pivoted source instead of the pivot's alias (done by the pivot loop below), so
                # only columns outside the Pivot get the pivot's alias here.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                # The first part is the struct column itself; find the table that owns it.
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
369
370
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections, in place.

    `*` expands over every selected source of the scope, `tbl.*` over `tbl` only. EXCEPT and
    REPLACE modifiers of a star apply to that star alone. Columns joined with USING are
    emitted once, as a COALESCE over the joined sources. When the scope has a PIVOT/UNPIVOT,
    the pivot's output columns are emitted (qualified with the pivot alias) instead.

    If any star's source columns are unknown (empty or containing "*"), the projection list is
    left untouched.

    Args:
        scope: The scope whose SELECT list is expanded.
        resolver: Used to look up the columns of each source.
        using_column_tables: Output of `_expand_using`: USING column -> ordered source names.
        pseudocolumns: Upper-cased dialect pseudocolumn names, which are never expanded.

    Raises:
        OptimizeError: If a star refers to a table that is not a source of the scope.
    """
    new_selections = []
    # Per-star modifier lookups, keyed by id() of the table-name strings of the current star.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        # EXCEPT / REPLACE belong to a single star. The maps are keyed by id() of table-name
        # strings, which are shared by every `*` (same selected_sources keys) and may be
        # interned or reused across stars, so stale entries must never survive to the next one
        # (e.g. `SELECT * EXCEPT (a), * FROM t` used to drop `a` from the second star too).
        except_columns.clear()
        replace_columns.clear()

        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star and not isinstance(expression, exp.Dot):
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Can't expand this star reliably: keep the original projections untouched.
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns are emitted once, merged across all joined sources.
                    coalesced_columns.add(name)
                    coalesce_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=source) for source in coalesce_tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
473
474
475def _add_except_columns(
476    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
477) -> None:
478    except_ = expression.args.get("except")
479
480    if not except_:
481        return
482
483    columns = {e.name for e in except_}
484
485    for table in tables:
486        except_columns[id(table)] = columns
487
488
489def _add_replace_columns(
490    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
491) -> None:
492    replace = expression.args.get("replace")
493
494    if not replace:
495        return
496
497    columns = {e.this.name: e.alias for e in replace}
498
499    for table in tables:
500        replace_columns[id(table)] = columns
501
502
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Unaliased, non-star projections get an alias equal to their output name, or `_col_<i>`
    when they have none; unnamed subqueries get a `_col_<i>` table alias. If the scope has an
    outer column list (e.g. from `cte(a, b)`), its names override the projection aliases
    positionally. Only a SELECT's projection list is written back.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. Nothing
            happens if no scope can be built.
    """
    if not isinstance(scope_or_expression, exp.Expression):
        scope = scope_or_expression
    else:
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return

    aliased_selects = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    for index, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            # The outer column list is longer than the projection list; nothing left to alias.
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{index}")))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            fallback_name = projection.output_name or f"_col_{index}"
            projection = alias(projection, alias=fallback_name, copy=False)

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_selects.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased_selects)
535
536
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Runs the dialect's `quote_identifier` over `expression` via `transform`, without copying.

    Args:
        expression: The expression whose identifiers are processed.
        dialect: The dialect whose quoting rules apply.
        identify: Forwarded to `quote_identifier`.

    Returns:
        The result of the transform.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
542
543
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
    Projections beyond the number of alias columns are kept unchanged.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(column_aliases, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier (as `alias()` does below) rather than a bare string, since
                # later steps such as `_expand_positional_references` call `.copy()` on it.
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence; keep projections that have no column alias
        # instead of silently dropping them from the CTE.
        new_expressions.extend(projections[len(new_expressions) :])
        cte.this.set("expressions", new_expressions)

    return expression
574
575
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose sources are used for resolution.
            schema: Schema used to look up the columns of physical tables.
            infer_schema: Whether `get_table` may fall back to the single source with unknown
                columns when a column name can't be resolved unambiguously.
        """
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches, filled on first use by _get_all_source_columns,
        # get_table and all_columns respectively.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is looked up among the columns that belong to exactly one source. If that
        fails and schema inference is enabled, the only source whose columns are unknown (empty
        or containing "*") is used, provided there is exactly one such source.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        # Also covers table_name being None (nothing found); presumably to_identifier(None)
        # yields None, matching the Optional return type — confirm.
        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        # For query sources, walk up to the node that carries the matching alias.
        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        # Prefer the alias node's own identifier over a freshly built one.
        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Tables are looked up in the schema, VALUES sources use their alias column names, and
        other sources use the named selects of their expression. Column aliases declared on the
        source (e.g. `AS t(a, b)`) shadow the corresponding column names positionally.

        Args:
            name: Name of a source in `self.scope.sources`.
            only_visible: Forwarded to `Schema.column_names` for table sources.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            # (`name` / `alias` below are comprehension-local and do not affect the outer names.)
            return [
                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) a mapping of every selected and lateral source to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # Columns already seen in an earlier source become ambiguous and are dropped.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope of `expression` is processed in turn: USING joins are turned into ON conditions,
    columns get an explicit source, references to projection aliases are inlined (if
    `expand_alias_refs`), stars are expanded (if `expand_stars`), outputs are aliased, and
    positional GROUP BY / ORDER BY references are resolved. The AST is modified in place.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. When left as None, inference is
            enabled exactly when `schema` is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        scope_resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Drop column lists from CTE / derived-table aliases (recursive CTEs keep theirs).
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, scope_resolver)

        # Alias references are expanded before column qualification when the schema is
        # empty, and after it otherwise.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, scope_resolver)

        _qualify_columns(scope, scope_resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, scope_resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf:
            if expand_stars:
                _expand_stars(scope, scope_resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, scope_resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 80def validate_qualify_columns(expression: E) -> E:
 81    """Raise an `OptimizeError` if any columns aren't qualified"""
 82    all_unqualified_columns = []
 83    for scope in traverse_scope(expression):
 84        if isinstance(scope.expression, exp.Select):
 85            unqualified_columns = scope.unqualified_columns
 86
 87            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 88                column = scope.external_columns[0]
 89                for_table = f" for table: '{column.table}'" if column.table else ""
 90                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 91
 92            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 93                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 94                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 95                # this list here to ensure those in the former category will be excluded.
 96                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 97                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 98
 99            all_unqualified_columns.extend(unqualified_columns)
100
101    if all_unqualified_columns:
102        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
103
104    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Each projection that is neither already aliased nor a star is wrapped in an alias named
    after its output name, or `_col_<i>` when it has none. An unnamed subquery projection
    gets a `_col_<i>` table alias instead. Names in `scope.outer_column_list`, when present,
    override the alias of the projection at the same position.

    Args:
        scope_or_expression: A `Scope`, or an expression from which a scope is built. If
            building a scope does not yield a `Scope`, nothing is done.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # More outer column names than projections: nothing left to alias.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
            # `exp.Aliases` (`expr AS (a, b)`) already names its outputs; wrapping it in
            # another alias would generate invalid SQL like `... AS (a, b) AS _col_0`.
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Make sure every identifier in `expression` that needs quoting is quoted.

    Args:
        expression: The expression to process; it is transformed without being copied.
        dialect: The dialect whose `quote_identifier` rule is applied to each node.
        identify: Forwarded unchanged to the dialect's `quote_identifier`.

    Returns:
        The transformed expression.
    """
    quote_fn = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_fn, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Only CTEs whose body is a plain SELECT are rewritten. The `expressions` of other bodies
    (e.g. the rows of a VALUES clause) are not projections and must not be aliased.
    Projections beyond the number of alias columns are kept unchanged.

    Args:
        expression: Expression to pushdown. Matching CTE bodies are modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        query = cte.this
        if not alias_names or not isinstance(query, exp.Select):
            continue

        projections = query.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier (as `alias()` does below) rather than a raw string, so
                # names that need quoting are generated correctly.
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # `zip` stops at the shorter input; keep projections that have no alias column
        # instead of silently dropping them.
        new_expressions.extend(projections[len(new_expressions) :])
        query.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches (None means "not computed yet").
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        # Whether a column may be attributed to the single source whose columns are unknown.
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is attributed to the only source exposing it. Failing that, when schema
        inference is enabled, it is attributed to the single source whose columns are
        unknown (empty, or containing "*"), if exactly one such source exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred, otherwise None. The returned
            identifier is a fresh node, so the caller may attach it to the AST.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            # Also covers "not found": to_identifier(None) yields None.
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up (e.g. through wrapping parentheses) to the node carrying the alias.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            alias_this = node_alias.this
            # Copy so the alias' own identifier is not re-parented (shared) when the caller
            # attaches the returned node to a column.
            if isinstance(alias_this, exp.Expression):
                alias_this = alias_this.copy()
            return exp.to_identifier(alias_this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The source's name in the current scope.
            only_visible: Forwarded to `Schema.column_names` when the source is a table.

        Returns:
            The source's column names, where column aliases declared on the selected source
            shadow the underlying names position by position.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map each selected and lateral source name to its columns (computed once)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # Columns already seen in an earlier source are ambiguous from now on.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
584    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
585        self.scope = scope
586        self.schema = schema
587        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
588        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
589        self._all_columns: t.Optional[t.Set[str]] = None
590        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
592    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
593        """
594        Get the table for a column name.
595
596        Args:
597            column_name: The column name to find the table for.
598        Returns:
599            The table name if it can be found/inferred.
600        """
601        if self._unambiguous_columns is None:
602            self._unambiguous_columns = self._get_unambiguous_columns(
603                self._get_all_source_columns()
604            )
605
606        table_name = self._unambiguous_columns.get(column_name)
607
608        if not table_name and self._infer_schema:
609            sources_without_schema = tuple(
610                source
611                for source, columns in self._get_all_source_columns().items()
612                if not columns or "*" in columns
613            )
614            if len(sources_without_schema) == 1:
615                table_name = sources_without_schema[0]
616
617        if table_name not in self.scope.selected_sources:
618            return exp.to_identifier(table_name)
619
620        node, _ = self.scope.selected_sources.get(table_name)
621
622        if isinstance(node, exp.Query):
623            while node and node.alias != table_name:
624                node = node.parent
625
626        node_alias = node.args.get("alias")
627        if node_alias:
628            return exp.to_identifier(node_alias.this)
629
630        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
632    @property
633    def all_columns(self) -> t.Set[str]:
634        """All available columns of all sources in this scope"""
635        if self._all_columns is None:
636            self._all_columns = {
637                column for columns in self._get_all_source_columns().values() for column in columns
638            }
639        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
641    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
642        """Resolve the source columns for a given source `name`."""
643        if name not in self.scope.sources:
644            raise OptimizeError(f"Unknown table: {name}")
645
646        source = self.scope.sources[name]
647
648        if isinstance(source, exp.Table):
649            columns = self.schema.column_names(source, only_visible)
650        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
651            columns = source.expression.alias_column_names
652        else:
653            columns = source.expression.named_selects
654
655        node, _ = self.scope.selected_sources.get(name) or (None, None)
656        if isinstance(node, Scope):
657            column_aliases = node.expression.alias_column_names
658        elif isinstance(node, exp.Expression):
659            column_aliases = node.alias_column_names
660        else:
661            column_aliases = []
662
663        if column_aliases:
664            # If the source's columns are aliased, their aliases shadow the corresponding column names.
665            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
666            return [
667                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
668            ]
669        return columns

Resolve the source columns for a given source name.