Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 12from sqlglot.schema import Schema, ensure_schema
 13
 14
 15def qualify_columns(
 16    expression: exp.Expression,
 17    schema: t.Dict | Schema,
 18    expand_alias_refs: bool = True,
 19    infer_schema: t.Optional[bool] = None,
 20) -> exp.Expression:
 21    """
 22    Rewrite sqlglot AST to have fully qualified columns.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> schema = {"tbl": {"col": "INT"}}
 27        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 28        >>> qualify_columns(expression, schema).sql()
 29        'SELECT tbl.col AS col FROM tbl'
 30
 31    Args:
 32        expression: expression to qualify
 33        schema: Database schema
 34        expand_alias_refs: whether or not to expand references to aliases
 35        infer_schema: whether or not to infer the schema if missing
 36    Returns:
 37        sqlglot.Expression: qualified expression
 38    """
 39    schema = ensure_schema(schema)
 40    infer_schema = schema.empty if infer_schema is None else infer_schema
 41
 42    for scope in traverse_scope(expression):
 43        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 44        _pop_table_column_aliases(scope.ctes)
 45        _pop_table_column_aliases(scope.derived_tables)
 46        using_column_tables = _expand_using(scope, resolver)
 47
 48        if schema.empty and expand_alias_refs:
 49            _expand_alias_refs(scope, resolver)
 50
 51        _qualify_columns(scope, resolver)
 52
 53        if not schema.empty and expand_alias_refs:
 54            _expand_alias_refs(scope, resolver)
 55
 56        if not isinstance(scope.expression, exp.UDTF):
 57            _expand_stars(scope, resolver, using_column_tables)
 58            _qualify_outputs(scope)
 59        _expand_group_by(scope)
 60        _expand_order_by(scope, resolver)
 61
 62    return expression
 63
 64
 65def validate_qualify_columns(expression: E) -> E:
 66    """Raise an `OptimizeError` if any columns aren't qualified"""
 67    unqualified_columns = []
 68    for scope in traverse_scope(expression):
 69        if isinstance(scope.expression, exp.Select):
 70            unqualified_columns.extend(scope.unqualified_columns)
 71            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 72                column = scope.external_columns[0]
 73                raise OptimizeError(
 74                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 75                )
 76
 77    if unqualified_columns:
 78        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 79    return expression
 80
 81
 82def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 83    """
 84    Remove table column aliases.
 85
 86    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 87    """
 88    for derived_table in derived_tables:
 89        table_alias = derived_table.args.get("alias")
 90        if table_alias:
 91            table_alias.args.pop("columns", None)
 92
 93
 94def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 95    joins = list(scope.find_all(exp.Join))
 96    names = {join.alias_or_name for join in joins}
 97    ordered = [key for key in scope.selected_sources if key not in names]
 98
 99    # Mapping of automatically joined column names to an ordered set of source names (dict).
100    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
101
102    for join in joins:
103        using = join.args.get("using")
104
105        if not using:
106            continue
107
108        join_table = join.alias_or_name
109
110        columns = {}
111
112        for k in scope.selected_sources:
113            if k in ordered:
114                for column in resolver.get_source_columns(k):
115                    if column not in columns:
116                        columns[column] = k
117
118        source_table = ordered[-1]
119        ordered.append(join_table)
120        join_columns = resolver.get_source_columns(join_table)
121        conditions = []
122
123        for identifier in using:
124            identifier = identifier.name
125            table = columns.get(identifier)
126
127            if not table or identifier not in join_columns:
128                if columns and join_columns:
129                    raise OptimizeError(f"Cannot automatically join: {identifier}")
130
131            table = table or source_table
132            conditions.append(
133                exp.condition(
134                    exp.EQ(
135                        this=exp.column(identifier, table=table),
136                        expression=exp.column(identifier, table=join_table),
137                    )
138                )
139            )
140
141            # Set all values in the dict to None, because we only care about the key ordering
142            tables = column_tables.setdefault(identifier, {})
143            if table not in tables:
144                tables[table] = None
145            if join_table not in tables:
146                tables[join_table] = None
147
148        join.args.pop("using")
149        join.set("on", exp.and_(*conditions, copy=False))
150
151    if column_tables:
152        for column in scope.columns:
153            if not column.table and column.name in column_tables:
154                tables = column_tables[column.name]
155                coalesce = [exp.column(column.name, table=table) for table in tables]
156                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
157
158                # Ensure selects keep their output name
159                if isinstance(column.parent, exp.Select):
160                    replacement = alias(replacement, alias=column.name, copy=False)
161
162                scope.replace(column, replacement)
163
164    return column_tables
165
166
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases within a SELECT scope.

    Projections are processed in order, so each one can only see the aliases
    defined before it. The WHERE, GROUP BY, HAVING and QUALIFY clauses are then
    processed with all aliases available. Nothing happens for non-SELECT scopes.

    Args:
        scope: The scope to rewrite in place; its cache is cleared afterwards.
        resolver: Used to find the table of a column in HAVING and QUALIFY.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of the projection).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = None
            if resolve_table and not column.table:
                table = resolver.get_table(column.name)

            alias_expr, position = alias_to_expression.get(column.name, (None, 1))

            # An aggregate alias referenced from inside another aggregate is not expanded.
            if alias_expr:
                double_agg = alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)
            else:
                double_agg = False

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if not isinstance(alias_expr, exp.Literal):
                    column.replace(alias_expr.copy())
                elif literal_index:
                    # Literal aliases are only replaced by their projection position.
                    column.replace(exp.Literal.number(position))

    for position, projection in enumerate(scope.selects, start=1):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, position)

    clause_options = (
        ("where", {}),
        ("group", {"literal_index": True}),
        ("having", {"resolve_table": True}),
        ("qualify", {"resolve_table": True}),
    )
    for clause, options in clause_options:
        replace_columns(select.args.get(clause), **options)

    scope.clear_cache()
212
213
def _expand_group_by(scope: Scope):
    """
    Expand positional references in GROUP BY and flag the expressions tied to it.

    The GROUP BY node, every projection containing a group key, and a HAVING
    clause containing a group key all get `meta["final"] = True`. Nothing happens
    when the scope's expression has no GROUP BY.
    """
    select = scope.expression
    group = select.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    select.set("group", group)

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    group_keys = set(group.expressions)
    group.meta["final"] = True

    for projection in select.selects:
        if any(node in group_keys for node, *_ in projection.walk()):
            projection.meta["final"] = True

    having = select.args.get("having")
    if having and any(node in group_keys for node, *_ in having.walk()):
        having.meta["final"] = True
241
242
def _expand_order_by(scope: Scope, resolver: Resolver):
    """
    Expand positional references in ORDER BY and qualify columns inside aggregates.

    When the query also has a GROUP BY, each ordering key that is still a position
    is replaced by the alias of that projection, and each key equal to a projected
    expression is replaced by a column named after that projection.

    Args:
        scope: The scope whose ORDER BY is rewritten in place.
        resolver: Used to find the table of unqualified columns inside aggregates.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, [o.this for o in ordereds])

    for ordered, new_expression in zip(ordereds, expanded):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if col.table:
                    continue
                col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if not scope.expression.args.get("group"):
        return

    selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}

    for ordered in ordereds:
        key = ordered.this

        if key.is_int:
            replacement = exp.to_identifier(_select_by_pos(scope, key).alias)
        else:
            replacement = selects.get(key, key)

        key.replace(replacement)
271
272
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
    """
    Replace integer positions by a copy of the projection they point at.

    Args:
        scope: The scope whose projections are referenced; its cache is cleared
            whenever a position is expanded.
        expressions: The expressions to process.

    Returns:
        A new list where non-integer nodes are kept as they are, positions that
        point at a literal projection are kept as they are, and every other
        position is replaced by a copy of the aliased expression.
    """
    expanded = []

    for node in expressions:
        if not node.is_int:
            expanded.append(node)
            continue

        target = _select_by_pos(scope, t.cast(exp.Literal, node)).this

        if isinstance(target, exp.Literal):
            expanded.append(node)
            continue

        expanded.append(target.copy())
        scope.clear_cache()

    return expanded
288
289
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal.

    Args:
        scope: The scope whose projections are indexed.
        node: The literal holding the position.

    Returns:
        The projection at that position, asserted to be an `exp.Alias`.

    Raises:
        OptimizeError: If the position is not within `1..len(scope.selects)`.
    """
    selects = scope.selects
    position = int(node.this)

    # Positions are 1-based. Checking the range explicitly prevents position 0
    # (or any negative value) from wrapping around to the end of the list.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
295
296
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    Args:
        scope: The scope whose columns are qualified in place.
        resolver: Used to look up source columns and the table of a column.

    Raises:
        OptimizeError: If a column names a source of this scope whose known
            columns neither include it nor contain "*".
    """
    for column in scope.columns:
        table_name = column.table
        name = column.name

        if not table_name:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the Pivot expression are qualified with the alias
                # of the first pivot.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            resolved = resolver.get_table(name)

            # The resolved table can be a '' because bigquery unnest has no table alias
            if resolved:
                column.set("table", resolved)
            continue

        if table_name in scope.sources:
            known_columns = resolver.get_source_columns(table_name)
            if known_columns and name not in known_columns and "*" not in known_columns:
                raise OptimizeError(f"Unknown column: {name}")
            continue

        if scope.parent and table_name in scope.parent.sources:
            continue

        # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
        root, *parts = column.parts

        if root.name in scope.sources:
            # struct is already qualified, but we still need to change the AST representation
            struct_table = root
            root, *parts = parts
        else:
            struct_table = resolver.get_table(root.name)

        if struct_table:
            column.replace(exp.Dot.build([exp.column(root, table=struct_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if column.table or column.name not in resolver.all_columns:
                continue

            pivot_table = resolver.get_table(column.name)
            if pivot_table:
                column.set("table", pivot_table)
344
345
def _expand_stars(
    scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
) -> None:
    """
    Expand stars to lists of column selections.

    The projections are left untouched if the columns of any starred source are
    unknown (empty, or containing "*").

    Args:
        scope: The scope whose projections are replaced in place.
        resolver: Used to look up the visible columns of each source.
        using_column_tables: Mapping of USING column names to the sources joined on
            them; such columns are emitted once, as a COALESCE over those sources.

    Raises:
        OptimizeError: If a star refers to a table that is not a source of the scope.
    """
    projections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = {col.output_name for col in pivot.find_all(exp.Column)}

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            star = expression
        elif expression.is_star:
            tables = [expression.table]
            star = expression.this
        else:
            projections.append(expression)
            continue

        _add_except_columns(star, tables, except_columns)
        _add_replace_columns(star, tables, replace_columns)

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
            # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
            if resolver.schema.dialect == "bigquery":
                columns = [
                    name
                    for name in columns
                    if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
                ]

            # Without known columns the star cannot be expanded, so give up entirely.
            if not columns or "*" in columns:
                return

            if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                implicit_columns = [col for col in columns if col not in pivot_columns]
                projections.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in implicit_columns + pivot_output_columns
                )
                continue

            # EXCEPT / REPLACE entries are keyed by the identity of the table name object.
            table_id = id(table)

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    joined_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=joined) for joined in joined_tables]

                    projections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in except_columns.get(table_id, set()):
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    if alias_ != name:
                        projections.append(alias(column, alias_, copy=False))
                    else:
                        projections.append(column)

    scope.expression.set("expressions", projections)
433
434
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record the names in the "except" arg of `expression` for each given table.

    Args:
        expression: The expression whose "except" arg is read.
        tables: The tables the names apply to.
        except_columns: Updated in place; keyed by `id(table)`, with every table
            sharing the same set of names. Left untouched when there is no "except".
    """
    excluded = expression.args.get("except")
    if excluded:
        names = {e.name for e in excluded}
        except_columns.update(dict.fromkeys((id(table) for table in tables), names))
447
448
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record the entries in the "replace" arg of `expression` for each given table.

    Args:
        expression: The expression whose "replace" arg is read.
        tables: The tables the replacements apply to.
        replace_columns: Updated in place; keyed by `id(table)`, with every table
            sharing the same mapping of `e.this.name` to `e.alias`. Left untouched
            when there is no "replace".
    """
    replacements = expression.args.get("replace")
    if replacements:
        aliases = {e.this.name: e.alias for e in replacements}
        replace_columns.update(dict.fromkeys((id(table) for table in tables), aliases))
461
462
def _qualify_outputs(scope: Scope):
    """
    Ensure all output columns are aliased.

    Projections are paired with the scope's outer column list. An outer column
    name, when present, overrides the alias of its projection. Unnamed projections
    are aliased `_col_<index>`. Outer column names beyond the last projection are
    ignored.

    Args:
        scope: The scope whose projections are replaced in place.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if selection is None:
            # zip_longest pads with None once the projections run out, which
            # happens when the outer column list is longer than the projections.
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )

        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
484
485
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform (called with `copy=False`).
        dialect: The dialect whose `quote_identifier` is applied to each node.
        identify: Forwarded to `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
491
492
class Resolver:
    """
    Helper for resolving columns.

    It is a class so that the per-scope lookups can be computed lazily, cached,
    and shared across the functions that qualify a scope.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._infer_schema = infer_schema

        # Lazily populated caches.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # With exactly one source of unknown columns, the column is assumed to be its own.
            schemaless = [
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(schemaless) == 1:
                table_name = schemaless[0]

        selected_sources = self.scope.selected_sources
        if table_name in selected_sources:
            node, _ = selected_sources.get(table_name)

            if isinstance(node, exp.Subqueryable):
                # Walk up to the ancestor that carries the alias.
                while node and node.alias != table_name:
                    node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            per_source = self._get_all_source_columns().values()
            self._all_columns = set(itertools.chain.from_iterable(per_source))
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = sources[name]

        # Tables get their columns from the schema.
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        # VALUES scopes expose the column names of their alias.
        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Any other source exposes its named selects.
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return the cached mapping of each selected and lateral source to its columns."""
        if self._source_columns is None:
            source_names = itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            self._source_columns = {name: self.get_source_columns(name) for name in source_names}
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        (first_table, first_columns), *remaining = source_columns.items()

        unambiguous = dict.fromkeys(self._find_unique_columns(first_columns), first_table)
        seen = set(unambiguous)

        for table, columns in remaining:
            unique = self._find_unique_columns(columns)
            ambiguous = seen & unique
            seen.update(columns)

            for column in ambiguous:
                unambiguous.pop(column, None)
            for column in unique - ambiguous:
                unambiguous[column] = table

        return unambiguous

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        # Maps each column to whether it has been seen more than once.
        repeated = {}
        for column in columns:
            repeated[column] = column in repeated
        return {column for column, is_repeated in repeated.items() if not is_repeated}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Fully qualify the columns of a sqlglot AST, mutating it in place.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: The expression whose columns should be qualified.
        schema: The database schema, either as a mapping or as a `Schema`.
        expand_alias_refs: Whether references to projection aliases should be expanded.
        infer_schema: Whether missing schema information may be inferred. When left
            as `None`, this defaults to whether the schema is empty.

    Returns:
        The `expression` that was passed in, with its columns qualified.
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Column lists attached to table aliases are dropped before anything else runs.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when the schema is
        # empty, and after qualification when it is not.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: expression to qualify
  • schema: Database schema
  • expand_alias_refs: whether or not to expand references to aliases
  • infer_schema: whether or not to infer the schema if missing
Returns:

sqlglot.Expression: qualified expression

def validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Check that every column in `expression` has been qualified.

    Args:
        expression: The expression to validate.

    Returns:
        The `expression` that was passed in.

    Raises:
        OptimizeError: If a SELECT scope has an external column and is neither a
            correlated subquery nor a scope with pivots, or if any SELECT scope
            still has unqualified columns.
    """
    pending = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        pending.extend(scope.unqualified_columns)

        if not scope.external_columns or scope.is_correlated_subquery or scope.pivots:
            continue

        column = scope.external_columns[0]
        table_hint = f" for table: '{column.table}'" if column.table else ""
        raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

    if pending:
        raise OptimizeError(f"Ambiguous columns: {pending}")

    return expression

Raise an OptimizeError if any columns aren't qualified

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform (called with `copy=False`).
        dialect: The dialect whose `quote_identifier` is applied to each node.
        identify: Forwarded to `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    It is a class so that the per-scope lookups can be computed lazily, cached,
    and shared across the functions that qualify a scope.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._infer_schema = infer_schema

        # Lazily populated caches.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # With exactly one source of unknown columns, the column is assumed to be its own.
            schemaless = [
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(schemaless) == 1:
                table_name = schemaless[0]

        selected_sources = self.scope.selected_sources
        if table_name in selected_sources:
            node, _ = selected_sources.get(table_name)

            if isinstance(node, exp.Subqueryable):
                # Walk up to the ancestor that carries the alias.
                while node and node.alias != table_name:
                    node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            per_source = self._get_all_source_columns().values()
            self._all_columns = set(itertools.chain.from_iterable(per_source))
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = sources[name]

        # Tables get their columns from the schema.
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        # VALUES scopes expose the column names of their alias.
        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Any other source exposes its named selects.
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return the cached mapping of each selected and lateral source to its columns."""
        if self._source_columns is None:
            source_names = itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            self._source_columns = {name: self.get_source_columns(name) for name in source_names}
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        (first_table, first_columns), *remaining = source_columns.items()

        unambiguous = dict.fromkeys(self._find_unique_columns(first_columns), first_table)
        seen = set(unambiguous)

        for table, columns in remaining:
            unique = self._find_unique_columns(columns)
            ambiguous = seen & unique
            seen.update(columns)

            for column in ambiguous:
                unambiguous.pop(column, None)
            for column in unique - ambiguous:
                unambiguous[column] = table

        return unambiguous

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        # Maps each column to whether it has been seen more than once.
        repeated = {}
        for column in columns:
            repeated[column] = column in repeated
        return {column for column, is_repeated in repeated.items() if not is_repeated}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
501    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
502        self.scope = scope
503        self.schema = schema
504        self._source_columns = None
505        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
506        self._all_columns = None
507        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
509    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
510        """
511        Get the table for a column name.
512
513        Args:
514            column_name: The column name to find the table for.
515        Returns:
516            The table name if it can be found/inferred.
517        """
518        if self._unambiguous_columns is None:
519            self._unambiguous_columns = self._get_unambiguous_columns(
520                self._get_all_source_columns()
521            )
522
523        table_name = self._unambiguous_columns.get(column_name)
524
525        if not table_name and self._infer_schema:
526            sources_without_schema = tuple(
527                source
528                for source, columns in self._get_all_source_columns().items()
529                if not columns or "*" in columns
530            )
531            if len(sources_without_schema) == 1:
532                table_name = sources_without_schema[0]
533
534        if table_name not in self.scope.selected_sources:
535            return exp.to_identifier(table_name)
536
537        node, _ = self.scope.selected_sources.get(table_name)
538
539        if isinstance(node, exp.Subqueryable):
540            while node and node.alias != table_name:
541                node = node.parent
542
543        node_alias = node.args.get("alias")
544        if node_alias:
545            return exp.to_identifier(node_alias.this)
546
547        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns

All available columns of all sources in this scope

def get_source_columns(self, name, only_visible=False):
558    def get_source_columns(self, name, only_visible=False):
559        """Resolve the source columns for a given source `name`"""
560        if name not in self.scope.sources:
561            raise OptimizeError(f"Unknown table: {name}")
562
563        source = self.scope.sources[name]
564
565        # If referencing a table, return the columns from the schema
566        if isinstance(source, exp.Table):
567            return self.schema.column_names(source, only_visible)
568
569        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
570            return source.expression.alias_column_names
571
572        # Otherwise, if referencing another scope, return that scope's named selects
573        return source.expression.named_selects

Resolve the source columns for a given source `name`.