
sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
 19def qualify_columns(
 20    expression: exp.Expression,
 21    schema: t.Dict | Schema,
 22    expand_alias_refs: bool = True,
 23    expand_stars: bool = True,
 24    infer_schema: t.Optional[bool] = None,
 25) -> exp.Expression:
 26    """
 27    Rewrite sqlglot AST to have fully qualified columns.
 28
 29    Example:
 30        >>> import sqlglot
 31        >>> schema = {"tbl": {"col": "INT"}}
 32        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 33        >>> qualify_columns(expression, schema).sql()
 34        'SELECT tbl.col AS col FROM tbl'
 35
 36    Args:
 37        expression: Expression to qualify.
 38        schema: Database schema.
 39        expand_alias_refs: Whether to expand references to aliases.
 40        expand_stars: Whether to expand star queries. This is a necessary step
 41            for most of the optimizer's rules to work; do not set to False unless you
 42            know what you're doing!
 43        infer_schema: Whether to infer the schema if missing.
 44
 45    Returns:
 46        The qualified expression.
 47
 48    Notes:
 49        - Currently only handles a single PIVOT or UNPIVOT operator
 50    """
 51    schema = ensure_schema(schema)
 52    annotator = TypeAnnotator(schema)
 53    infer_schema = schema.empty if infer_schema is None else infer_schema
 54    dialect = Dialect.get_or_raise(schema.dialect)
 55    pseudocolumns = dialect.PSEUDOCOLUMNS
 56
 57    for scope in traverse_scope(expression):
 58        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 59        _pop_table_column_aliases(scope.ctes)
 60        _pop_table_column_aliases(scope.derived_tables)
 61        using_column_tables = _expand_using(scope, resolver)
 62
 63        if schema.empty and expand_alias_refs:
 64            _expand_alias_refs(scope, resolver)
 65
 66        _convert_columns_to_dots(scope, resolver)
 67        _qualify_columns(scope, resolver)
 68
 69        if not schema.empty and expand_alias_refs:
 70            _expand_alias_refs(scope, resolver)
 71
 72        if not isinstance(scope.expression, exp.UDTF):
 73            if expand_stars:
 74                _expand_stars(
 75                    scope,
 76                    resolver,
 77                    using_column_tables,
 78                    pseudocolumns,
 79                    annotator,
 80                )
 81            qualify_outputs(scope)
 82
 83        _expand_group_by(scope)
 84        _expand_order_by(scope, resolver)
 85
 86        if dialect == "bigquery":
 87            annotator.annotate_scope(scope)
 88
 89    return expression
 90
 91
 92def validate_qualify_columns(expression: E) -> E:
 93    """Raise an `OptimizeError` if any columns aren't qualified"""
 94    all_unqualified_columns = []
 95    for scope in traverse_scope(expression):
 96        if isinstance(scope.expression, exp.Select):
 97            unqualified_columns = scope.unqualified_columns
 98
 99            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
100                column = scope.external_columns[0]
101                for_table = f" for table: '{column.table}'" if column.table else ""
102                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
103
104            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
105                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
106                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
107                # this list here to ensure those in the former category will be excluded.
108                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
109                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
110
111            all_unqualified_columns.extend(unqualified_columns)
112
113    if all_unqualified_columns:
114        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
115
116    return expression
117
118
119def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
120    name_column = []
121    field = unpivot.args.get("field")
122    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
123        name_column.append(field.this)
124
125    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
126    return itertools.chain(name_column, value_columns)
127
128
129def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
130    """
131    Remove table column aliases.
132
133    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
134    """
135    for derived_table in derived_tables:
136        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
137            continue
138        table_alias = derived_table.args.get("alias")
139        if table_alias:
140            table_alias.args.pop("columns", None)
141
142
143def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
144    joins = list(scope.find_all(exp.Join))
145    names = {join.alias_or_name for join in joins}
146    ordered = [key for key in scope.selected_sources if key not in names]
147
148    # Mapping of automatically joined column names to an ordered set of source names (dict).
149    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
150
151    for join in joins:
152        using = join.args.get("using")
153
154        if not using:
155            continue
156
157        join_table = join.alias_or_name
158
159        columns = {}
160
161        for source_name in scope.selected_sources:
162            if source_name in ordered:
163                for column_name in resolver.get_source_columns(source_name):
164                    if column_name not in columns:
165                        columns[column_name] = source_name
166
167        source_table = ordered[-1]
168        ordered.append(join_table)
169        join_columns = resolver.get_source_columns(join_table)
170        conditions = []
171
172        for identifier in using:
173            identifier = identifier.name
174            table = columns.get(identifier)
175
176            if not table or identifier not in join_columns:
177                if (columns and "*" not in columns) and join_columns:
178                    raise OptimizeError(f"Cannot automatically join: {identifier}")
179
180            table = table or source_table
181            conditions.append(
182                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
183            )
184
185            # Set all values in the dict to None, because we only care about the key ordering
186            tables = column_tables.setdefault(identifier, {})
187            if table not in tables:
188                tables[table] = None
189            if join_table not in tables:
190                tables[join_table] = None
191
192        join.args.pop("using")
193        join.set("on", exp.and_(*conditions, copy=False))
194
195    if column_tables:
196        for column in scope.columns:
197            if not column.table and column.name in column_tables:
198                tables = column_tables[column.name]
199                coalesce = [exp.column(column.name, table=table) for table in tables]
200                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
201
202                # Ensure selects keep their output name
203                if isinstance(column.parent, exp.Select):
204                    replacement = alias(replacement, alias=column.name, copy=False)
205
206                scope.replace(column, replacement)
207
208    return column_tables
209
210
211def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
212    expression = scope.expression
213
214    if not isinstance(expression, exp.Select):
215        return
216
217    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
218
219    def replace_columns(
220        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
221    ) -> None:
222        if not node:
223            return
224
225        for column in walk_in_scope(node, prune=lambda node: node.is_star):
226            if not isinstance(column, exp.Column):
227                continue
228
229            table = resolver.get_table(column.name) if resolve_table and not column.table else None
230            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
231            double_agg = (
232                (
233                    alias_expr.find(exp.AggFunc)
234                    and (
235                        column.find_ancestor(exp.AggFunc)
236                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
237                    )
238                )
239                if alias_expr
240                else False
241            )
242
243            if table and (not alias_expr or double_agg):
244                column.set("table", table)
245            elif not column.table and alias_expr and not double_agg:
246                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
247                    if literal_index:
248                        column.replace(exp.Literal.number(i))
249                else:
250                    column = column.replace(exp.paren(alias_expr))
251                    simplified = simplify_parens(column)
252                    if simplified is not column:
253                        column.replace(simplified)
254
255    for i, projection in enumerate(scope.expression.selects):
256        replace_columns(projection)
257
258        if isinstance(projection, exp.Alias):
259            alias_to_expression[projection.alias] = (projection.this, i + 1)
260
261    replace_columns(expression.args.get("where"))
262    replace_columns(expression.args.get("group"), literal_index=True)
263    replace_columns(expression.args.get("having"), resolve_table=True)
264    replace_columns(expression.args.get("qualify"), resolve_table=True)
265
266    scope.clear_cache()
267
268
269def _expand_group_by(scope: Scope) -> None:
270    expression = scope.expression
271    group = expression.args.get("group")
272    if not group:
273        return
274
275    group.set("expressions", _expand_positional_references(scope, group.expressions))
276    expression.set("group", group)
277
278
279def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
280    order = scope.expression.args.get("order")
281    if not order:
282        return
283
284    ordereds = order.expressions
285    for ordered, new_expression in zip(
286        ordereds,
287        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
288    ):
289        for agg in ordered.find_all(exp.AggFunc):
290            for col in agg.find_all(exp.Column):
291                if not col.table:
292                    col.set("table", resolver.get_table(col.name))
293
294        ordered.set("this", new_expression)
295
296    if scope.expression.args.get("group"):
297        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
298
299        for ordered in ordereds:
300            ordered = ordered.this
301
302            ordered.replace(
303                exp.to_identifier(_select_by_pos(scope, ordered).alias)
304                if ordered.is_int
305                else selects.get(ordered, ordered)
306            )
307
308
309def _expand_positional_references(
310    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
311) -> t.List[exp.Expression]:
312    new_nodes: t.List[exp.Expression] = []
313    for node in expressions:
314        if node.is_int:
315            select = _select_by_pos(scope, t.cast(exp.Literal, node))
316
317            if alias:
318                new_nodes.append(exp.column(select.args["alias"].copy()))
319            else:
320                select = select.this
321
322                if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
323                    new_nodes.append(node)
324                else:
325                    new_nodes.append(select.copy())
326        else:
327            new_nodes.append(node)
328
329    return new_nodes
330
331
332def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
333    try:
334        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
335    except IndexError:
336        raise OptimizeError(f"Unknown output column: {node.name}")
337
338
339def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
340    """
341    Converts `Column` instances that represent struct field lookup into chained `Dots`.
342
343    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
344    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
345    """
346    converted = False
347    for column in itertools.chain(scope.columns, scope.stars):
348        if isinstance(column, exp.Dot):
349            continue
350
351        column_table: t.Optional[str | exp.Identifier] = column.table
352        if (
353            column_table
354            and column_table not in scope.sources
355            and (
356                not scope.parent
357                or column_table not in scope.parent.sources
358                or not scope.is_correlated_subquery
359            )
360        ):
361            root, *parts = column.parts
362
363            if root.name in scope.sources:
364                # The struct is already qualified, but we still need to change the AST
365                column_table = root
366                root, *parts = parts
367            else:
368                column_table = resolver.get_table(root.name)
369
370            if column_table:
371                converted = True
372                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
373
374    if converted:
375        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
376        # a `for column in scope.columns` iteration, even though they shouldn't be
377        scope.clear_cache()
378
379
380def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
381    """Disambiguate columns, ensuring each column specifies a source"""
382    for column in scope.columns:
383        column_table = column.table
384        column_name = column.name
385
386        if column_table and column_table in scope.sources:
387            source_columns = resolver.get_source_columns(column_table)
388            if source_columns and column_name not in source_columns and "*" not in source_columns:
389                raise OptimizeError(f"Unknown column: {column_name}")
390
391        if not column_table:
392            if scope.pivots and not column.find_ancestor(exp.Pivot):
393                # If the column is under the Pivot expression, we need to qualify it
394                # using the name of the pivoted source instead of the pivot's alias
395                column.set("table", exp.to_identifier(scope.pivots[0].alias))
396                continue
397
398            # column_table can be a '' because bigquery unnest has no table alias
399            column_table = resolver.get_table(column_name)
400            if column_table:
401                column.set("table", column_table)
402
403    for pivot in scope.pivots:
404        for column in pivot.find_all(exp.Column):
405            if not column.table and column.name in resolver.all_columns:
406                column_table = resolver.get_table(column.name)
407                if column_table:
408                    column.set("table", column_table)
409
410
411def _expand_struct_stars(
412    expression: exp.Dot,
413) -> t.List[exp.Alias]:
414    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
415
416    dot_column = t.cast(exp.Column, expression.find(exp.Column))
417    if not dot_column.is_type(exp.DataType.Type.STRUCT):
418        return []
419
420    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
421    dot_column = dot_column.copy()
422    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
423
424    # First part is the table name and last part is the star so they can be dropped
425    dot_parts = expression.parts[1:-1]
426
427    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
428    for part in dot_parts[1:]:
429        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
430            # Unable to expand star unless all fields are named
431            if not isinstance(field.this, exp.Identifier):
432                return []
433
434            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
435                starting_struct = field
436                break
437        else:
438            # There is no matching field in the struct
439            return []
440
441    taken_names = set()
442    new_selections = []
443
444    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
445        name = field.name
446
447        # Ambiguous or anonymous fields can't be expanded
448        if name in taken_names or not isinstance(field.this, exp.Identifier):
449            return []
450
451        taken_names.add(name)
452
453        this = field.this.copy()
454        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
455        new_column = exp.column(
456            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
457        )
458        new_selections.append(alias(new_column, this, copy=False))
459
460    return new_selections
461
462
463def _expand_stars(
464    scope: Scope,
465    resolver: Resolver,
466    using_column_tables: t.Dict[str, t.Any],
467    pseudocolumns: t.Set[str],
468    annotator: TypeAnnotator,
469) -> None:
470    """Expand stars to lists of column selections"""
471
472    new_selections = []
473    except_columns: t.Dict[int, t.Set[str]] = {}
474    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
475    coalesced_columns = set()
476    dialect = resolver.schema.dialect
477
478    pivot_output_columns = None
479    pivot_exclude_columns = None
480
481    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
482    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
483        if pivot.unpivot:
484            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
485
486            field = pivot.args.get("field")
487            if isinstance(field, exp.In):
488                pivot_exclude_columns = {
489                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
490                }
491        else:
492            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
493
494            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
495            if not pivot_output_columns:
496                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
497
498    is_bigquery = dialect == "bigquery"
499    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
500        # Found struct expansion, annotate scope ahead of time
501        annotator.annotate_scope(scope)
502
503    for expression in scope.expression.selects:
504        tables = []
505        if isinstance(expression, exp.Star):
506            tables.extend(scope.selected_sources)
507            _add_except_columns(expression, tables, except_columns)
508            _add_replace_columns(expression, tables, replace_columns)
509        elif expression.is_star:
510            if not isinstance(expression, exp.Dot):
511                tables.append(expression.table)
512                _add_except_columns(expression.this, tables, except_columns)
513                _add_replace_columns(expression.this, tables, replace_columns)
514            elif is_bigquery:
515                struct_fields = _expand_struct_stars(expression)
516                if struct_fields:
517                    new_selections.extend(struct_fields)
518                    continue
519
520        if not tables:
521            new_selections.append(expression)
522            continue
523
524        for table in tables:
525            if table not in scope.sources:
526                raise OptimizeError(f"Unknown table: {table}")
527
528            columns = resolver.get_source_columns(table, only_visible=True)
529            columns = columns or scope.outer_columns
530
531            if pseudocolumns:
532                columns = [name for name in columns if name.upper() not in pseudocolumns]
533
534            if not columns or "*" in columns:
535                return
536
537            table_id = id(table)
538            columns_to_exclude = except_columns.get(table_id) or set()
539
540            if pivot:
541                if pivot_output_columns and pivot_exclude_columns:
542                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
543                    pivot_columns.extend(pivot_output_columns)
544                else:
545                    pivot_columns = pivot.alias_column_names
546
547                if pivot_columns:
548                    new_selections.extend(
549                        alias(exp.column(name, table=pivot.alias), name, copy=False)
550                        for name in pivot_columns
551                        if name not in columns_to_exclude
552                    )
553                    continue
554
555            for name in columns:
556                if name in columns_to_exclude or name in coalesced_columns:
557                    continue
558                if name in using_column_tables and table in using_column_tables[name]:
559                    coalesced_columns.add(name)
560                    tables = using_column_tables[name]
561                    coalesce = [exp.column(name, table=table) for table in tables]
562
563                    new_selections.append(
564                        alias(
565                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
566                            alias=name,
567                            copy=False,
568                        )
569                    )
570                else:
571                    alias_ = replace_columns.get(table_id, {}).get(name, name)
572                    column = exp.column(name, table=table)
573                    new_selections.append(
574                        alias(column, alias_, copy=False) if alias_ != name else column
575                    )
576
577    # Ensures we don't overwrite the initial selections with an empty list
578    if new_selections and isinstance(scope.expression, exp.Select):
579        scope.expression.set("expressions", new_selections)
580
581
582def _add_except_columns(
583    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
584) -> None:
585    except_ = expression.args.get("except")
586
587    if not except_:
588        return
589
590    columns = {e.name for e in except_}
591
592    for table in tables:
593        except_columns[id(table)] = columns
594
595
596def _add_replace_columns(
597    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
598) -> None:
599    replace = expression.args.get("replace")
600
601    if not replace:
602        return
603
604    columns = {e.this.name: e.alias for e in replace}
605
606    for table in tables:
607        replace_columns[id(table)] = columns
608
609
610def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
611    """Ensure all output columns are aliased"""
612    if isinstance(scope_or_expression, exp.Expression):
613        scope = build_scope(scope_or_expression)
614        if not isinstance(scope, Scope):
615            return
616    else:
617        scope = scope_or_expression
618
619    new_selections = []
620    for i, (selection, aliased_column) in enumerate(
621        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
622    ):
623        if selection is None:
624            break
625
626        if isinstance(selection, exp.Subquery):
627            if not selection.output_name:
628                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
629        elif not isinstance(selection, exp.Alias) and not selection.is_star:
630            selection = alias(
631                selection,
632                alias=selection.output_name or f"_col_{i}",
633                copy=False,
634            )
635        if aliased_column:
636            selection.set("alias", exp.to_identifier(aliased_column))
637
638        new_selections.append(selection)
639
640    if isinstance(scope.expression, exp.Select):
641        scope.expression.set("expressions", new_selections)
642
643
644def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
645    """Makes sure all identifiers that need to be quoted are quoted."""
646    return expression.transform(
647        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
648    )  # type: ignore
649
650
651def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
652    """
653    Pushes down the CTE alias columns into the projection,
654
655    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
656
657    Example:
658        >>> import sqlglot
659        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
660        >>> pushdown_cte_alias_columns(expression).sql()
661        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
662
663    Args:
664        expression: Expression to pushdown.
665
666    Returns:
667        The expression with the CTE aliases pushed down into the projection.
668    """
669    for cte in expression.find_all(exp.CTE):
670        if cte.alias_column_names:
671            new_expressions = []
672            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
673                if isinstance(projection, exp.Alias):
674                    projection.set("alias", _alias)
675                else:
676                    projection = alias(projection, alias=_alias)
677                new_expressions.append(projection)
678            cte.this.set("expressions", new_expressions)
679
680    return expression
681
682
683class Resolver:
684    """
685    Helper for resolving columns.
686
687    This is a class so we can lazily load some things and easily share them across functions.
688    """
689
690    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
691        self.scope = scope
692        self.schema = schema
693        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
694        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
695        self._all_columns: t.Optional[t.Set[str]] = None
696        self._infer_schema = infer_schema
697
698    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
699        """
700        Get the table for a column name.
701
702        Args:
703            column_name: The column name to find the table for.
704        Returns:
705            The table name if it can be found/inferred.
706        """
707        if self._unambiguous_columns is None:
708            self._unambiguous_columns = self._get_unambiguous_columns(
709                self._get_all_source_columns()
710            )
711
712        table_name = self._unambiguous_columns.get(column_name)
713
714        if not table_name and self._infer_schema:
715            sources_without_schema = tuple(
716                source
717                for source, columns in self._get_all_source_columns().items()
718                if not columns or "*" in columns
719            )
720            if len(sources_without_schema) == 1:
721                table_name = sources_without_schema[0]
722
723        if table_name not in self.scope.selected_sources:
724            return exp.to_identifier(table_name)
725
726        node, _ = self.scope.selected_sources.get(table_name)
727
728        if isinstance(node, exp.Query):
729            while node and node.alias != table_name:
730                node = node.parent
731
732        node_alias = node.args.get("alias")
733        if node_alias:
734            return exp.to_identifier(node_alias.this)
735
736        return exp.to_identifier(table_name)
737
738    @property
739    def all_columns(self) -> t.Set[str]:
740        """All available columns of all sources in this scope"""
741        if self._all_columns is None:
742            self._all_columns = {
743                column for columns in self._get_all_source_columns().values() for column in columns
744            }
745        return self._all_columns
746
747    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
748        """Resolve the source columns for a given source `name`."""
749        if name not in self.scope.sources:
750            raise OptimizeError(f"Unknown table: {name}")
751
752        source = self.scope.sources[name]
753
754        if isinstance(source, exp.Table):
755            columns = self.schema.column_names(source, only_visible)
756        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
757            columns = source.expression.named_selects
758
759            # in bigquery, unnest structs are automatically scoped as tables, so you can
760            # directly select a struct field in a query.
761            # this handles the case where the unnest is statically defined.
762            if self.schema.dialect == "bigquery":
763                if source.expression.is_type(exp.DataType.Type.STRUCT):
764                    for k in source.expression.type.expressions:  # type: ignore
765                        columns.append(k.name)
766        else:
767            columns = source.expression.named_selects
768
769        node, _ = self.scope.selected_sources.get(name) or (None, None)
770        if isinstance(node, Scope):
771            column_aliases = node.expression.alias_column_names
772        elif isinstance(node, exp.Expression):
773            column_aliases = node.alias_column_names
774        else:
775            column_aliases = []
776
777        if column_aliases:
778            # If the source's columns are aliased, their aliases shadow the corresponding column names.
779            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
780            return [
781                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
782            ]
783        return columns
784
785    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
786        if self._source_columns is None:
787            self._source_columns = {
788                source_name: self.get_source_columns(source_name)
789                for source_name, source in itertools.chain(
790                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
791                )
792            }
793        return self._source_columns
794
795    def _get_unambiguous_columns(
796        self, source_columns: t.Dict[str, t.Sequence[str]]
797    ) -> t.Mapping[str, str]:
798        """
799        Find all the unambiguous columns in sources.
800
801        Args:
802            source_columns: Mapping of names to source columns.
803
804        Returns:
805            Mapping of column name to source name.
806        """
807        if not source_columns:
808            return {}
809
810        source_columns_pairs = list(source_columns.items())
811
812        first_table, first_columns = source_columns_pairs[0]
813
814        if len(source_columns_pairs) == 1:
815            # Performance optimization - avoid copying first_columns if there is only one table.
816            return SingleValuedMapping(first_columns, first_table)
817
818        unambiguous_columns = {col: first_table for col in first_columns}
819        all_columns = set(unambiguous_columns)
820
821        for table, columns in source_columns_pairs[1:]:
822            unique = set(columns)
823            ambiguous = all_columns.intersection(unique)
824            all_columns.update(columns)
825
826            for column in ambiguous:
827                unambiguous_columns.pop(column, None)
828            for column in unique.difference(ambiguous):
829                unambiguous_columns[column] = table
830
831        return unambiguous_columns
def qualify_columns(expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'

Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.

Returns:
  The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator.
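
Beyond the doctest above, a minimal sketch of star expansion when a schema is supplied. The table and column names are made up for illustration, and the commented output is indicative rather than exact:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Hypothetical two-column schema, for illustration only.
schema = {"tbl": {"a": "INT", "b": "INT"}}

# With expand_stars=True (the default), the star is replaced by the schema
# columns, each qualified with its table and aliased to its own name.
expression = sqlglot.parse_one("SELECT * FROM tbl")
print(qualify_columns(expression, schema).sql())
# Expected to print something like: SELECT tbl.a AS a, tbl.b AS b FROM tbl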
def validate_qualify_columns(expression: ~E) -> ~E:

Raise an OptimizeError if any columns aren't qualified
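
A minimal sketch of both outcomes (the schema and query are made up; the exact error message lists the offending columns):

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

schema = {"tbl": {"col": "INT"}}

# A qualified expression passes validation and is returned unchanged.
qualified = qualify_columns(sqlglot.parse_one("SELECT col FROM tbl"), schema)
validate_qualify_columns(qualified)

# An expression that was never qualified raises OptimizeError.
try:
    validate_qualify_columns(sqlglot.parse_one("SELECT col FROM tbl"))
except OptimizeError as error:
    print(error)  # the unqualified/ambiguous columns are listed in the message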

def qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:

Ensure all output columns are aliased
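
A minimal sketch with a made-up query; qualify_outputs mutates the expression in place and returns None, and the commented output is indicative:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expression = sqlglot.parse_one("SELECT x.a, 1 + 1 FROM x")
qualify_outputs(expression)
print(expression.sql())
# Expected to print something like: SELECT x.a AS a, 1 + 1 AS _col_1 FROM x
# Named selections keep their output name; anonymous ones get a _col_<i> alias.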

def quote_identifiers(expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:

Makes sure all identifiers that need to be quoted are quoted.
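
A minimal sketch; the dialect choice here is arbitrary, and with identify=True (the default) every identifier gets quoted:

import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expression = sqlglot.parse_one("SELECT a FROM tbl")
print(quote_identifiers(expression, dialect="duckdb").sql(dialect="duckdb"))
# Expected to print something like: SELECT "a" FROM "tbl"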

def pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake, where the CTE alias columns can be referenced in the HAVING clause.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

Arguments:
  • expression: Expression to push down.

Returns:
  The expression with the CTE aliases pushed down into the projection.
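
As a complement to the doctest above, a minimal sketch of the guard behavior: CTEs that declare no alias columns are left untouched, so the function is a no-op on queries like this one.

import sqlglot
from sqlglot.optimizer.qualify_columns import pushdown_cte_alias_columns

# No alias column list on the CTE, so there is nothing to push down.
expression = sqlglot.parse_one("WITH y AS (SELECT 1 AS c) SELECT c FROM y")
print(pushdown_cte_alias_columns(expression).sql())
# WITH y AS (SELECT 1 AS c) SELECT c FROM y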

class Resolver:

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)

Instance variables:
  • scope
  • schema

def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.

Returns:
  The table name if it can be found/inferred.
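
Resolver is an internal helper, but a minimal sketch of how get_table resolves unambiguous columns (the schema and query are made up; commented results are indicative):

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}})
scope = build_scope(sqlglot.parse_one("SELECT a, b, c FROM t1 CROSS JOIN t2"))
resolver = Resolver(scope, schema)

print(resolver.get_table("a"))  # t1 -- only t1 exposes "a"
print(resolver.get_table("b"))  # None -- ambiguous, both sources expose "b"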

all_columns: Set[str]

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:

Resolve the source columns for a given source name.
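
A minimal sketch (the schema and queries are made up; commented results are indicative):

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

# Table sources are resolved through the schema.
scope = build_scope(sqlglot.parse_one("SELECT * FROM tbl"))
resolver = Resolver(scope, ensure_schema({"tbl": {"a": "INT", "b": "INT"}}))
print(resolver.get_source_columns("tbl"))  # ['a', 'b']

# Derived tables fall back to their named selections.
scope = build_scope(sqlglot.parse_one("SELECT a FROM (SELECT 1 AS a, 2 AS b) AS s"))
resolver = Resolver(scope, ensure_schema({}))
print(resolver.get_source_columns("s"))  # ['a', 'b']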