Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15
 16def qualify_columns(
 17    expression: exp.Expression,
 18    schema: t.Dict | Schema,
 19    expand_alias_refs: bool = True,
 20    infer_schema: t.Optional[bool] = None,
 21) -> exp.Expression:
 22    """
 23    Rewrite sqlglot AST to have fully qualified columns.
 24
 25    Example:
 26        >>> import sqlglot
 27        >>> schema = {"tbl": {"col": "INT"}}
 28        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 29        >>> qualify_columns(expression, schema).sql()
 30        'SELECT tbl.col AS col FROM tbl'
 31
 32    Args:
 33        expression: Expression to qualify.
 34        schema: Database schema.
 35        expand_alias_refs: Whether or not to expand references to aliases.
 36        infer_schema: Whether or not to infer the schema if missing.
 37
 38    Returns:
 39        The qualified expression.
 40    """
 41    schema = ensure_schema(schema)
 42    infer_schema = schema.empty if infer_schema is None else infer_schema
 43    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 44
 45    for scope in traverse_scope(expression):
 46        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 47        _pop_table_column_aliases(scope.ctes)
 48        _pop_table_column_aliases(scope.derived_tables)
 49        using_column_tables = _expand_using(scope, resolver)
 50
 51        if schema.empty and expand_alias_refs:
 52            _expand_alias_refs(scope, resolver)
 53
 54        _qualify_columns(scope, resolver)
 55
 56        if not schema.empty and expand_alias_refs:
 57            _expand_alias_refs(scope, resolver)
 58
 59        if not isinstance(scope.expression, exp.UDTF):
 60            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 61            _qualify_outputs(scope)
 62
 63        _expand_group_by(scope)
 64        _expand_order_by(scope, resolver)
 65
 66    return expression
 67
 68
 69def validate_qualify_columns(expression: E) -> E:
 70    """Raise an `OptimizeError` if any columns aren't qualified"""
 71    unqualified_columns = []
 72    for scope in traverse_scope(expression):
 73        if isinstance(scope.expression, exp.Select):
 74            unqualified_columns.extend(scope.unqualified_columns)
 75            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 76                column = scope.external_columns[0]
 77                raise OptimizeError(
 78                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 79                )
 80
 81    if unqualified_columns:
 82        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 83    return expression
 84
 85
 86def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 87    """
 88    Remove table column aliases.
 89
 90    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 91    """
 92    for derived_table in derived_tables:
 93        table_alias = derived_table.args.get("alias")
 94        if table_alias:
 95            table_alias.args.pop("columns", None)
 96
 97
 98def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 99    joins = list(scope.find_all(exp.Join))
100    names = {join.alias_or_name for join in joins}
101    ordered = [key for key in scope.selected_sources if key not in names]
102
103    # Mapping of automatically joined column names to an ordered set of source names (dict).
104    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
105
106    for join in joins:
107        using = join.args.get("using")
108
109        if not using:
110            continue
111
112        join_table = join.alias_or_name
113
114        columns = {}
115
116        for source_name in scope.selected_sources:
117            if source_name in ordered:
118                for column_name in resolver.get_source_columns(source_name):
119                    if column_name not in columns:
120                        columns[column_name] = source_name
121
122        source_table = ordered[-1]
123        ordered.append(join_table)
124        join_columns = resolver.get_source_columns(join_table)
125        conditions = []
126
127        for identifier in using:
128            identifier = identifier.name
129            table = columns.get(identifier)
130
131            if not table or identifier not in join_columns:
132                if (columns and "*" not in columns) and join_columns:
133                    raise OptimizeError(f"Cannot automatically join: {identifier}")
134
135            table = table or source_table
136            conditions.append(
137                exp.condition(
138                    exp.EQ(
139                        this=exp.column(identifier, table=table),
140                        expression=exp.column(identifier, table=join_table),
141                    )
142                )
143            )
144
145            # Set all values in the dict to None, because we only care about the key ordering
146            tables = column_tables.setdefault(identifier, {})
147            if table not in tables:
148                tables[table] = None
149            if join_table not in tables:
150                tables[join_table] = None
151
152        join.args.pop("using")
153        join.set("on", exp.and_(*conditions, copy=False))
154
155    if column_tables:
156        for column in scope.columns:
157            if not column.table and column.name in column_tables:
158                tables = column_tables[column.name]
159                coalesce = [exp.column(column.name, table=table) for table in tables]
160                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
161
162                # Ensure selects keep their output name
163                if isinstance(column.parent, exp.Select):
164                    replacement = alias(replacement, alias=column.name, copy=False)
165
166                scope.replace(column, replacement)
167
168    return column_tables
169
170
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to projection aliases with the aliased expression.

    Applies only when the scope's expression is a `Select`. Projections are processed
    in order, so a projection can only see aliases defined before it; WHERE, GROUP BY,
    HAVING and QUALIFY see all of them. The scope cache is cleared afterwards.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = None
            if resolve_table and not column.table:
                table = resolver.get_table(column.name)

            alias_expr, position = alias_to_expression.get(column.name, (None, 1))

            # Expanding an aggregate alias inside another aggregate would nest aggregates
            double_agg = False
            if alias_expr:
                double_agg = alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)

            if table and (not alias_expr or double_agg):
                column.set("table", table)
                continue

            if column.table or not alias_expr or double_agg:
                continue

            if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                # Literal aliases are only rewritten where a positional index is requested
                if literal_index:
                    column.replace(exp.Literal.number(position))
                continue

            wrapped = column.replace(exp.paren(alias_expr))
            simplified = simplify_parens(wrapped)
            if simplified is not wrapped:
                wrapped.replace(simplified)

    for index, projection in enumerate(expression.selects, start=1):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, index)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)
    scope.clear_cache()
220
221
def _expand_group_by(scope: Scope) -> None:
    """Expand positional references (e.g. `GROUP BY 1`) in the scope's GROUP BY, if any."""
    select = scope.expression
    group = select.args.get("group")

    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        select.set("group", group)
230
231
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in ORDER BY and qualify columns used inside aggregates.

    When the query also has a GROUP BY, ordered expressions that match a projection are
    replaced by a reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    # Computed up front, so every `this` is read before any ordered term is modified
    expanded = _expand_positional_references(scope, (o.this for o in ordereds))

    for ordered, new_expression in zip(ordereds, expanded):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if col.table:
                    continue
                col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if not scope.expression.args.get("group"):
        return

    selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

    for ordered in ordereds:
        target = ordered.this

        if target.is_int:
            replacement = exp.to_identifier(_select_by_pos(scope, target).alias)
        else:
            replacement = selects.get(target, target)

        target.replace(replacement)
260
261
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
    """
    Replace integer positions with a copy of the projection they point at.

    Nodes that are not integers are passed through untouched, as are positions whose
    projection is itself a literal. The scope cache is cleared after each substitution.
    """
    new_nodes = []

    for node in expressions:
        if not node.is_int:
            new_nodes.append(node)
            continue

        select = _select_by_pos(scope, t.cast(exp.Literal, node)).this

        if isinstance(select, exp.Literal):
            # Keep the positional reference rather than substituting a literal
            new_nodes.append(node)
            continue

        new_nodes.append(select.copy())
        scope.clear_cache()

    return new_nodes
277
278
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal.

    Args:
        scope: Scope whose expression's `selects` are indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The projection at that position, asserted to be an `exp.Alias`.

    Raises:
        OptimizeError: If the position is outside `1..len(selects)`.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # Position 0 must be rejected explicitly: `selects[0 - 1]` would otherwise
    # silently resolve to the last projection via negative indexing.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
284
285
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    Raises:
        OptimizeError: If a column is qualified with a source of this scope whose
            known columns do not include it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            known = bool(source_columns) and "*" not in source_columns
            if known and column_name not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # In a pivoted scope, a column outside the Pivot expression is qualified
                # with the first pivot's alias; columns under the Pivot are handled below
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
            continue

        if column_table in scope.sources:
            continue

        parent = scope.parent
        if parent and column_table in parent.sources and scope.is_correlated_subquery:
            # The qualifier refers to a source of the parent scope of a correlated subquery
            continue

        # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
        root, *parts = column.parts

        if root.name in scope.sources:
            # struct is already qualified, but we still need to change the AST representation
            column_table = root
            root, *parts = parts
        else:
            column_table = resolver.get_table(root.name)

        if column_table:
            column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if column.table or column.name not in resolver.all_columns:
                continue

            column_table = resolver.get_table(column.name)
            if column_table:
                column.set("table", column_table)
335
336
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    Args:
        scope: Scope whose projections are rewritten in place.
        resolver: Used to look up the visible columns of each source.
        using_column_tables: Mapping of column name to the sources joined on it; such
            columns are emitted once, as a COALESCE over those sources.
        pseudocolumns: Upper-cased column names that are never produced by a star.

    Raises:
        OptimizeError: If a star refers to a table that is not a source of the scope.

    If the columns of any starred source are unknown, the function returns without
    touching the projections.
    """
    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = {col.output_name for col in pivot.find_all(exp.Column)}

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            # Bare `*`: expands over every selected source
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            # Qualified `tbl.*`: expands over that one table
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Columns of this source are unknown, so leave the projections as they are
                return

            if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                implicit_columns = [col for col in columns if col not in pivot_columns]
                new_selections.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in implicit_columns + pivot_output_columns
                )
                continue

            # EXCEPT / REPLACE registrations are keyed by the identity of the table name
            table_id = id(table)
            excluded = except_columns.get(table_id, set())
            renamed = replace_columns.get(table_id, {})

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    joined_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=joined) for joined in joined_tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in excluded:
                    alias_ = renamed.get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
423
424
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record the star's EXCEPT column names for each of `tables`.

    Entries are keyed by `id(table)`; all tables share the same set object. Nothing is
    recorded when the star has no "except" argument.
    """
    excluded = expression.args.get("except")

    if excluded:
        names = {e.name for e in excluded}
        except_columns.update((id(table), names) for table in tables)
437
438
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record the star's REPLACE mapping (inner expression name -> alias) for each of `tables`.

    Entries are keyed by `id(table)`; all tables share the same dict object. Nothing is
    recorded when the star has no "replace" argument.
    """
    replacements = expression.args.get("replace")

    if replacements:
        mapping = {e.this.name: e.alias for e in replacements}
        replace_columns.update((id(table), mapping) for table in tables)
451
452
def _qualify_outputs(scope: Scope) -> None:
    """
    Ensure all output columns are aliased.

    Projections are paired positionally with `scope.outer_column_list`; when an outer
    column name exists for a position it overrides the projection's alias. Subqueries
    without an output name and un-aliased, non-star projections get `_col_<i>` (or
    their own output name, when they have one).
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if selection is None:
            # zip_longest pads with None once the projections run out, i.e. there are
            # more outer column aliases than projections; there is nothing left to alias
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
474
475
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The dialect's `quote_identifier` is applied across the expression via `transform`
    with `copy=False`; `identify` is forwarded to it unchanged.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
481
482
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; None means "not computed yet"
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # With exactly one source of unknown shape, assume the column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias matching the source name
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in this scope.
            only_visible: Forwarded to the schema lookup when the source is a table.

        Returns:
            The source's column names, with column aliases taking precedence positionally.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, _ in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source exposes it, exactly once. The
        result does not depend on the order of the sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated within the first
        # source: such a name must still clash with the same name in a later source,
        # exactly as it does when the duplicating source is not the first one.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST so that every column is qualified with its source.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema, either a mapping or a `Schema` instance.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing. When left as
            `None`, this falls back to `schema.empty`.

    Returns:
        The qualified expression (the `expression` argument itself is returned).
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Column aliases on CTEs / derived tables are dropped before anything is resolved
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded exactly once per scope: before column
        # qualification when the schema is empty, after it otherwise
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # UDTF scopes get neither star expansion nor output aliasing
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            _qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

def validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only scopes whose expression is a `Select` are inspected. An external column in a
    scope that is neither a correlated subquery nor pivoted raises immediately; columns
    left unqualified are collected across all scopes and reported together at the end.

    Returns:
        The `expression` argument, unchanged, when validation passes.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            table_hint = f" for table: '{column.table}'" if column.table else ''
            raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression

Raise an OptimizeError if any columns aren't qualified

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The dialect's `quote_identifier` is applied across the expression via `transform`
    with `copy=False`; `identify` is forwarded to it unchanged.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The three caches below start empty and are filled on first use.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # When exactly one source has no known columns (or exposes "*"),
            # the column is attributed to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Climb the tree until reaching the ancestor that carries the alias.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above can run past the root and leave `node` as None; in that
        # case fall back to the plain table name instead of dereferencing None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in the current scope.
            only_visible: Forwarded to `Schema.column_names` when the source is a table.

        Returns:
            The column names of the source, with column aliases taking precedence.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (computed once)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once in exactly one source.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated within a source,
        # so that a later source exposing the same name is detected as a clash.
        seen_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all names of this source, not just its unique ones.
            ambiguous = seen_columns.intersection(columns)
            seen_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    """
    Store the scope and schema, and start with empty lazy caches.

    Args:
        scope: Scope whose columns will be resolved.
        schema: Schema stored on the resolver.
        infer_schema: Flag stored as `_infer_schema`.
    """
    self.scope, self.schema = scope, schema
    self._infer_schema = infer_schema

    # Caches start out unset.
    self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
    self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
    self._all_columns: t.Optional[t.Set[str]] = None
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # When exactly one source has no known columns (or exposes "*"),
        # the column is attributed to that source.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Climb the tree until reaching the ancestor that carries the alias.
        while node and node.alias != table_name:
            node = node.parent

    # The climb above can run past the root and leave `node` as None; in that
    # case fall back to the plain table name instead of dereferencing None.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: Name of a source in the current scope.
        only_visible: Forwarded to `Schema.column_names` when the source is a table.

    Returns:
        The column names of the source, with column aliases taking precedence.

    Raises:
        OptimizeError: If `name` is not a source of the scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    # Pick the raw column names according to the kind of source.
    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        columns = source.expression.named_selects

    selected = self.scope.selected_sources.get(name)
    node = selected[0] if selected else None

    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    # If the source's columns are aliased, their aliases shadow the corresponding column names
    resolved = []
    for column, column_alias in itertools.zip_longest(columns, column_aliases):
        resolved.append(column_alias or column)
    return resolved

Resolve the source columns for a given source name.