Edit on GitHub

sqlglot.optimizer.scope

  1import itertools
  2from collections import defaultdict
  3from enum import Enum, auto
  4
  5from sqlglot import exp
  6from sqlglot.errors import OptimizeError
  7from sqlglot.helper import find_new_name
  8
  9
 10class ScopeType(Enum):
 11    ROOT = auto()
 12    SUBQUERY = auto()
 13    DERIVED_TABLE = auto()
 14    CTE = auto()
 15    UNION = auto()
 16    UDTF = auto()
 17
 18
 19class Scope:
 20    """
 21    Selection scope.
 22
 23    Attributes:
 24        expression (exp.Select|exp.Union): Root expression of this scope
 25        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 26            a Table expression or another Scope instance. For example:
 27                SELECT * FROM x                     {"x": Table(this="x")}
 28                SELECT * FROM x AS y                {"y": Table(this="x")}
 29                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 30        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 31            For example:
 32                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 33            The LATERAL VIEW EXPLODE gets x as a source.
 34        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 35            defines a column list of it's alias of this scope, this is that list of columns.
 36            For example:
 37                SELECT * FROM (SELECT ...) AS y(col1, col2)
 38            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 39        parent (Scope): Parent scope
 40        scope_type (ScopeType): Type of this scope, relative to it's parent
 41        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 42        cte_scopes (list[Scope]): List of all child scopes for CTEs
 43        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 44        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 45        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 46        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 47            a list of the left and right child scopes.
 48    """
 49
 50    def __init__(
 51        self,
 52        expression,
 53        sources=None,
 54        outer_column_list=None,
 55        parent=None,
 56        scope_type=ScopeType.ROOT,
 57        lateral_sources=None,
 58    ):
 59        self.expression = expression
 60        self.sources = sources or {}
 61        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 62        self.sources.update(self.lateral_sources)
 63        self.outer_column_list = outer_column_list or []
 64        self.parent = parent
 65        self.scope_type = scope_type
 66        self.subquery_scopes = []
 67        self.derived_table_scopes = []
 68        self.table_scopes = []
 69        self.cte_scopes = []
 70        self.union_scopes = []
 71        self.udtf_scopes = []
 72        self.clear_cache()
 73
 74    def clear_cache(self):
 75        self._collected = False
 76        self._raw_columns = None
 77        self._derived_tables = None
 78        self._udtfs = None
 79        self._tables = None
 80        self._ctes = None
 81        self._subqueries = None
 82        self._selected_sources = None
 83        self._columns = None
 84        self._external_columns = None
 85        self._join_hints = None
 86        self._pivots = None
 87
 88    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 89        """Branch from the current scope to a new, inner scope"""
 90        return Scope(
 91            expression=expression.unnest(),
 92            sources={**self.cte_sources, **(chain_sources or {})},
 93            parent=self,
 94            scope_type=scope_type,
 95            **kwargs,
 96        )
 97
 98    def _collect(self):
 99        self._tables = []
100        self._ctes = []
101        self._subqueries = []
102        self._derived_tables = []
103        self._udtfs = []
104        self._raw_columns = []
105        self._join_hints = []
106
107        for node, parent, _ in self.walk(bfs=False):
108            if node is self.expression:
109                continue
110            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
111                self._raw_columns.append(node)
112            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
113                self._tables.append(node)
114            elif isinstance(node, exp.JoinHint):
115                self._join_hints.append(node)
116            elif isinstance(node, exp.UDTF):
117                self._udtfs.append(node)
118            elif isinstance(node, exp.CTE):
119                self._ctes.append(node)
120            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
121                self._derived_tables.append(node)
122            elif isinstance(node, exp.Subqueryable):
123                self._subqueries.append(node)
124
125        self._collected = True
126
127    def _ensure_collected(self):
128        if not self._collected:
129            self._collect()
130
131    def walk(self, bfs=True):
132        return walk_in_scope(self.expression, bfs=bfs)
133
134    def find(self, *expression_types, bfs=True):
135        """
136        Returns the first node in this scope which matches at least one of the specified types.
137
138        This does NOT traverse into subscopes.
139
140        Args:
141            expression_types (type): the expression type(s) to match.
142            bfs (bool): True to use breadth-first search, False to use depth-first.
143
144        Returns:
145            exp.Expression: the node which matches the criteria or None if no node matching
146            the criteria was found.
147        """
148        return next(self.find_all(*expression_types, bfs=bfs), None)
149
150    def find_all(self, *expression_types, bfs=True):
151        """
152        Returns a generator object which visits all nodes in this scope and only yields those that
153        match at least one of the specified expression types.
154
155        This does NOT traverse into subscopes.
156
157        Args:
158            expression_types (type): the expression type(s) to match.
159            bfs (bool): True to use breadth-first search, False to use depth-first.
160
161        Yields:
162            exp.Expression: nodes
163        """
164        for expression, *_ in self.walk(bfs=bfs):
165            if isinstance(expression, expression_types):
166                yield expression
167
168    def replace(self, old, new):
169        """
170        Replace `old` with `new`.
171
172        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
173
174        Args:
175            old (exp.Expression): old node
176            new (exp.Expression): new node
177        """
178        old.replace(new)
179        self.clear_cache()
180
181    @property
182    def tables(self):
183        """
184        List of tables in this scope.
185
186        Returns:
187            list[exp.Table]: tables
188        """
189        self._ensure_collected()
190        return self._tables
191
192    @property
193    def ctes(self):
194        """
195        List of CTEs in this scope.
196
197        Returns:
198            list[exp.CTE]: ctes
199        """
200        self._ensure_collected()
201        return self._ctes
202
203    @property
204    def derived_tables(self):
205        """
206        List of derived tables in this scope.
207
208        For example:
209            SELECT * FROM (SELECT ...) <- that's a derived table
210
211        Returns:
212            list[exp.Subquery]: derived tables
213        """
214        self._ensure_collected()
215        return self._derived_tables
216
217    @property
218    def udtfs(self):
219        """
220        List of "User Defined Tabular Functions" in this scope.
221
222        Returns:
223            list[exp.UDTF]: UDTFs
224        """
225        self._ensure_collected()
226        return self._udtfs
227
228    @property
229    def subqueries(self):
230        """
231        List of subqueries in this scope.
232
233        For example:
234            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
235
236        Returns:
237            list[exp.Subqueryable]: subqueries
238        """
239        self._ensure_collected()
240        return self._subqueries
241
242    @property
243    def columns(self):
244        """
245        List of columns in this scope.
246
247        Returns:
248            list[exp.Column]: Column instances in this scope, plus any
249                Columns that reference this scope from correlated subqueries.
250        """
251        if self._columns is None:
252            self._ensure_collected()
253            columns = self._raw_columns
254
255            external_columns = [
256                column
257                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
258                for column in scope.external_columns
259            ]
260
261            named_selects = set(self.expression.named_selects)
262
263            self._columns = []
264            for column in columns + external_columns:
265                ancestor = column.find_ancestor(
266                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
267                )
268                if (
269                    not ancestor
270                    or column.table
271                    or isinstance(ancestor, exp.Select)
272                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
273                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
274                ):
275                    self._columns.append(column)
276
277        return self._columns
278
279    @property
280    def selected_sources(self):
281        """
282        Mapping of nodes and sources that are actually selected from in this scope.
283
284        That is, all tables in a schema are selectable at any point. But a
285        table only becomes a selected source if it's included in a FROM or JOIN clause.
286
287        Returns:
288            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
289        """
290        if self._selected_sources is None:
291            referenced_names = []
292
293            for table in self.tables:
294                referenced_names.append((table.alias_or_name, table))
295            for expression in itertools.chain(self.derived_tables, self.udtfs):
296                referenced_names.append((expression.alias, expression.unnest()))
297            result = {}
298
299            for name, node in referenced_names:
300                if name in result:
301                    raise OptimizeError(f"Alias already used: {name}")
302                if name in self.sources:
303                    result[name] = (node, self.sources[name])
304
305            self._selected_sources = result
306        return self._selected_sources
307
308    @property
309    def cte_sources(self):
310        """
311        Sources that are CTEs.
312
313        Returns:
314            dict[str, Scope]: Mapping of source alias to Scope
315        """
316        return {
317            alias: scope
318            for alias, scope in self.sources.items()
319            if isinstance(scope, Scope) and scope.is_cte
320        }
321
322    @property
323    def selects(self):
324        """
325        Select expressions of this scope.
326
327        For example, for the following expression:
328            SELECT 1 as a, 2 as b FROM x
329
330        The outputs are the "1 as a" and "2 as b" expressions.
331
332        Returns:
333            list[exp.Expression]: expressions
334        """
335        if isinstance(self.expression, exp.Union):
336            return self.expression.unnest().selects
337        return self.expression.selects
338
339    @property
340    def external_columns(self):
341        """
342        Columns that appear to reference sources in outer scopes.
343
344        Returns:
345            list[exp.Column]: Column instances that don't reference
346                sources in the current scope.
347        """
348        if self._external_columns is None:
349            self._external_columns = [
350                c for c in self.columns if c.table not in self.selected_sources
351            ]
352        return self._external_columns
353
354    @property
355    def unqualified_columns(self):
356        """
357        Unqualified columns in the current scope.
358
359        Returns:
360             list[exp.Column]: Unqualified columns
361        """
362        return [c for c in self.columns if not c.table]
363
364    @property
365    def join_hints(self):
366        """
367        Hints that exist in the scope that reference tables
368
369        Returns:
370            list[exp.JoinHint]: Join hints that are referenced within the scope
371        """
372        if self._join_hints is None:
373            return []
374        return self._join_hints
375
376    @property
377    def pivots(self):
378        if not self._pivots:
379            self._pivots = [
380                pivot
381                for node in self.tables + self.derived_tables
382                for pivot in node.args.get("pivots") or []
383            ]
384
385        return self._pivots
386
387    def source_columns(self, source_name):
388        """
389        Get all columns in the current scope for a particular source.
390
391        Args:
392            source_name (str): Name of the source
393        Returns:
394            list[exp.Column]: Column instances that reference `source_name`
395        """
396        return [column for column in self.columns if column.table == source_name]
397
398    @property
399    def is_subquery(self):
400        """Determine if this scope is a subquery"""
401        return self.scope_type == ScopeType.SUBQUERY
402
403    @property
404    def is_derived_table(self):
405        """Determine if this scope is a derived table"""
406        return self.scope_type == ScopeType.DERIVED_TABLE
407
408    @property
409    def is_union(self):
410        """Determine if this scope is a union"""
411        return self.scope_type == ScopeType.UNION
412
413    @property
414    def is_cte(self):
415        """Determine if this scope is a common table expression"""
416        return self.scope_type == ScopeType.CTE
417
418    @property
419    def is_root(self):
420        """Determine if this is the root scope"""
421        return self.scope_type == ScopeType.ROOT
422
423    @property
424    def is_udtf(self):
425        """Determine if this scope is a UDTF (User Defined Table Function)"""
426        return self.scope_type == ScopeType.UDTF
427
428    @property
429    def is_correlated_subquery(self):
430        """Determine if this scope is a correlated subquery"""
431        return bool(self.is_subquery and self.external_columns)
432
433    def rename_source(self, old_name, new_name):
434        """Rename a source in this scope"""
435        columns = self.sources.pop(old_name or "", [])
436        self.sources[new_name] = columns
437
438    def add_source(self, name, source):
439        """Add a source to this scope"""
440        self.sources[name] = source
441        self.clear_cache()
442
443    def remove_source(self, name):
444        """Remove a source from this scope"""
445        self.sources.pop(name, None)
446        self.clear_cache()
447
448    def __repr__(self):
449        return f"Scope<{self.expression.sql()}>"
450
451    def traverse(self):
452        """
453        Traverse the scope tree from this node.
454
455        Yields:
456            Scope: scope instances in depth-first-search post-order
457        """
458        for child_scope in itertools.chain(
459            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
460        ):
461            yield from child_scope.traverse()
462        yield self
463
464    def ref_count(self):
465        """
466        Count the number of times each scope in this tree is referenced.
467
468        Returns:
469            dict[int, int]: Mapping of Scope instance ID to reference count
470        """
471        scope_ref_count = defaultdict(lambda: 0)
472
473        for scope in self.traverse():
474            for _, source in scope.selected_sources.values():
475                scope_ref_count[id(source)] += 1
476
477        return scope_ref_count
478
479
def traverse_scope(expression):
    """
    Collect all "scopes" of an expression into a list.

    A "Scope" captures the context of a Select statement, which carries more
    information than the bare expression tree (e.g. the source names that are
    visible within a subquery) and is therefore useful when optimizing queries.

    The scopes are materialized into a list instead of being handed out lazily,
    because a partially consumed generator would expose scopes whose properties
    are still incomplete.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances, with the scope of `expression` itself last
    """
    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))
506
507
def build_scope(expression):
    """
    Build the scope tree of an expression and hand back its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    # The scope of `expression` itself is the last one produced by the traversal.
    scopes = traverse_scope(expression)
    return scopes[-1]
518
519
def _traverse_scope(scope):
    """Yield every scope nested in `scope`, followed by `scope` itself."""
    expression = scope.expression

    if isinstance(expression, exp.Select):
        nested = _traverse_select(scope)
    elif isinstance(expression, exp.Union):
        nested = _traverse_union(scope)
    elif isinstance(expression, exp.Subquery):
        nested = _traverse_subqueries(scope)
    elif isinstance(expression, exp.Table):
        # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
        nested = _traverse_tables(scope)
    elif isinstance(expression, exp.UDTF):
        # Nothing is traversed inside a UDTF.
        nested = ()
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")

    yield from nested
    yield scope
535
536
def _traverse_select(scope):
    """Yield the child scopes of a SELECT: CTEs first, then tables, then subqueries."""
    for traverse in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse(scope)
541
542
def _traverse_union(scope):
    """Yield the child scopes of a union and record the top scope of each operand."""
    yield from _traverse_ctes(scope)

    operand_scopes = []
    for side in ("left", "right"):
        operand = getattr(scope.expression, side)

        # The last scope to be yielded for an operand is its top most scope
        top = None
        for top in _traverse_scope(scope.branch(operand, scope_type=ScopeType.UNION)):
            yield top

        operand_scopes.append(top)

    scope.union_scopes = operand_scopes
556
557
def _traverse_ctes(scope):
    """Yield the scopes of every CTE of `scope` and register the CTEs as sources."""
    cte_sources = {}

    for cte in scope.ctes:
        cte_query = cte.this

        # A recursive CTE must have the form `base_case UNION recursive`, so the
        # recursive scope is built from the first section of the union.
        if scope.expression.args["with"].recursive and isinstance(cte_query, exp.Union):
            recursive_scope = scope.branch(cte_query.this, scope_type=ScopeType.CTE)
        else:
            recursive_scope = None

        cte_scope = scope.branch(
            cte.this,
            chain_sources=cte_sources,
            outer_column_list=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )

        for child_scope in _traverse_scope(cte_scope):
            yield child_scope

            alias = cte.alias
            cte_sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)

        # The final scope yielded above is the scope of the CTE itself
        scope.cte_scopes.append(child_scope)

    scope.sources.update(cte_sources)
593
594
def _traverse_tables(scope):
    """Yield the scopes of derived tables and UDTFs, and register all table-like sources."""
    sources = {}
    query = scope.expression

    # Gather FROMs, JOINs, and LATERALs in the order they are defined
    candidates = []
    from_ = query.args.get("from")
    if from_:
        candidates.append(from_.this)

    candidates.extend(join.this for join in query.args.get("joins") or [])

    if isinstance(query, exp.Table):
        candidates.append(query)

    candidates.extend(query.args.get("laterals") or [])

    for expression in candidates:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name not in scope.sources:
                # A plain table; pick a fresh name if its name is already taken
                if source_name in sources:
                    sources[find_new_name(sources, table_name)] = expression
                else:
                    sources[source_name] = expression
            else:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            continue

        is_udtf = isinstance(expression, exp.UDTF)
        sibling_scopes = scope.udtf_scopes if is_udtf else scope.derived_table_scopes

        branched = scope.branch(
            expression,
            # Only UDTFs get to see the sources gathered so far
            lateral_sources=sources if is_udtf else None,
            outer_column_list=expression.alias_column_names,
            scope_type=ScopeType.UDTF if is_udtf else ScopeType.DERIVED_TABLE,
        )

        for child_scope in _traverse_scope(branched):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            alias = expression.alias
            sources[alias] = child_scope

        # The final scope yielded above is the scope of `expression` itself
        sibling_scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
662
663
def _traverse_subqueries(scope):
    """Yield the scopes of every subquery of `scope` and record each subquery's top scope."""
    for subquery in scope.subqueries:
        subquery_scope = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)

        # The last scope yielded for a subquery is its top most scope
        top = None
        for top in _traverse_scope(subquery_scope):
            yield top

        scope.subquery_scopes.append(top)
671
672
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # State shared with the walk below: while it is True, the subtree of the
    # node that was just yielded is excluded from the traversal.
    skip_subtree = False

    def should_prune(*_):
        return skip_subtree

    for node, parent, key in expression.walk(bfs=bfs, prune=should_prune):
        skip_subtree = False

        yield node, parent, key

        # The root is always descended into; any other node that opens a child
        # scope is yielded, but its subtree is skipped.
        if node is not expression and (
            isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable))
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
        ):
            skip_subtree = True
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """Kind of a scope, relative to the scope that contains it."""

    # Scope that was not branched from any other scope.
    ROOT = auto()
    # Scope created for a subquery found outside of FROM / JOIN.
    SUBQUERY = auto()
    # Scope created for a subquery used as a table in FROM / JOIN.
    DERIVED_TABLE = auto()
    # Scope created for the query of a common table expression.
    CTE = auto()
    # Scope created for one operand of a union.
    UNION = auto()
    # Scope created for a user defined tabular function.
    UDTF = auto()

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
Inherited Members
enum.Enum
name
value
class Scope:
 20class Scope:
 21    """
 22    Selection scope.
 23
 24    Attributes:
 25        expression (exp.Select|exp.Union): Root expression of this scope
 26        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 27            a Table expression or another Scope instance. For example:
 28                SELECT * FROM x                     {"x": Table(this="x")}
 29                SELECT * FROM x AS y                {"y": Table(this="x")}
 30                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 31        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 32            For example:
 33                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 34            The LATERAL VIEW EXPLODE gets x as a source.
 35        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 36            defines a column list of it's alias of this scope, this is that list of columns.
 37            For example:
 38                SELECT * FROM (SELECT ...) AS y(col1, col2)
 39            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 40        parent (Scope): Parent scope
 41        scope_type (ScopeType): Type of this scope, relative to it's parent
 42        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 43        cte_scopes (list[Scope]): List of all child scopes for CTEs
 44        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 45        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 46        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 47        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 48            a list of the left and right child scopes.
 49    """
 50
 51    def __init__(
 52        self,
 53        expression,
 54        sources=None,
 55        outer_column_list=None,
 56        parent=None,
 57        scope_type=ScopeType.ROOT,
 58        lateral_sources=None,
 59    ):
 60        self.expression = expression
 61        self.sources = sources or {}
 62        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 63        self.sources.update(self.lateral_sources)
 64        self.outer_column_list = outer_column_list or []
 65        self.parent = parent
 66        self.scope_type = scope_type
 67        self.subquery_scopes = []
 68        self.derived_table_scopes = []
 69        self.table_scopes = []
 70        self.cte_scopes = []
 71        self.union_scopes = []
 72        self.udtf_scopes = []
 73        self.clear_cache()
 74
 75    def clear_cache(self):
 76        self._collected = False
 77        self._raw_columns = None
 78        self._derived_tables = None
 79        self._udtfs = None
 80        self._tables = None
 81        self._ctes = None
 82        self._subqueries = None
 83        self._selected_sources = None
 84        self._columns = None
 85        self._external_columns = None
 86        self._join_hints = None
 87        self._pivots = None
 88
 89    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 90        """Branch from the current scope to a new, inner scope"""
 91        return Scope(
 92            expression=expression.unnest(),
 93            sources={**self.cte_sources, **(chain_sources or {})},
 94            parent=self,
 95            scope_type=scope_type,
 96            **kwargs,
 97        )
 98
 99    def _collect(self):
100        self._tables = []
101        self._ctes = []
102        self._subqueries = []
103        self._derived_tables = []
104        self._udtfs = []
105        self._raw_columns = []
106        self._join_hints = []
107
108        for node, parent, _ in self.walk(bfs=False):
109            if node is self.expression:
110                continue
111            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
112                self._raw_columns.append(node)
113            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
114                self._tables.append(node)
115            elif isinstance(node, exp.JoinHint):
116                self._join_hints.append(node)
117            elif isinstance(node, exp.UDTF):
118                self._udtfs.append(node)
119            elif isinstance(node, exp.CTE):
120                self._ctes.append(node)
121            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
122                self._derived_tables.append(node)
123            elif isinstance(node, exp.Subqueryable):
124                self._subqueries.append(node)
125
126        self._collected = True
127
128    def _ensure_collected(self):
129        if not self._collected:
130            self._collect()
131
132    def walk(self, bfs=True):
133        return walk_in_scope(self.expression, bfs=bfs)
134
135    def find(self, *expression_types, bfs=True):
136        """
137        Returns the first node in this scope which matches at least one of the specified types.
138
139        This does NOT traverse into subscopes.
140
141        Args:
142            expression_types (type): the expression type(s) to match.
143            bfs (bool): True to use breadth-first search, False to use depth-first.
144
145        Returns:
146            exp.Expression: the node which matches the criteria or None if no node matching
147            the criteria was found.
148        """
149        return next(self.find_all(*expression_types, bfs=bfs), None)
150
151    def find_all(self, *expression_types, bfs=True):
152        """
153        Returns a generator object which visits all nodes in this scope and only yields those that
154        match at least one of the specified expression types.
155
156        This does NOT traverse into subscopes.
157
158        Args:
159            expression_types (type): the expression type(s) to match.
160            bfs (bool): True to use breadth-first search, False to use depth-first.
161
162        Yields:
163            exp.Expression: nodes
164        """
165        for expression, *_ in self.walk(bfs=bfs):
166            if isinstance(expression, expression_types):
167                yield expression
168
169    def replace(self, old, new):
170        """
171        Replace `old` with `new`.
172
173        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
174
175        Args:
176            old (exp.Expression): old node
177            new (exp.Expression): new node
178        """
179        old.replace(new)
180        self.clear_cache()
181
182    @property
183    def tables(self):
184        """
185        List of tables in this scope.
186
187        Returns:
188            list[exp.Table]: tables
189        """
190        self._ensure_collected()
191        return self._tables
192
193    @property
194    def ctes(self):
195        """
196        List of CTEs in this scope.
197
198        Returns:
199            list[exp.CTE]: ctes
200        """
201        self._ensure_collected()
202        return self._ctes
203
204    @property
205    def derived_tables(self):
206        """
207        List of derived tables in this scope.
208
209        For example:
210            SELECT * FROM (SELECT ...) <- that's a derived table
211
212        Returns:
213            list[exp.Subquery]: derived tables
214        """
215        self._ensure_collected()
216        return self._derived_tables
217
218    @property
219    def udtfs(self):
220        """
221        List of "User Defined Tabular Functions" in this scope.
222
223        Returns:
224            list[exp.UDTF]: UDTFs
225        """
226        self._ensure_collected()
227        return self._udtfs
228
229    @property
230    def subqueries(self):
231        """
232        List of subqueries in this scope.
233
234        For example:
235            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
236
237        Returns:
238            list[exp.Subqueryable]: subqueries
239        """
240        self._ensure_collected()
241        return self._subqueries
242
243    @property
244    def columns(self):
245        """
246        List of columns in this scope.
247
248        Returns:
249            list[exp.Column]: Column instances in this scope, plus any
250                Columns that reference this scope from correlated subqueries.
251        """
252        if self._columns is None:
253            self._ensure_collected()
254            columns = self._raw_columns
255
256            external_columns = [
257                column
258                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
259                for column in scope.external_columns
260            ]
261
262            named_selects = set(self.expression.named_selects)
263
264            self._columns = []
265            for column in columns + external_columns:
266                ancestor = column.find_ancestor(
267                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
268                )
269                if (
270                    not ancestor
271                    or column.table
272                    or isinstance(ancestor, exp.Select)
273                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
274                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
275                ):
276                    self._columns.append(column)
277
278        return self._columns
279
280    @property
281    def selected_sources(self):
282        """
283        Mapping of nodes and sources that are actually selected from in this scope.
284
285        That is, all tables in a schema are selectable at any point. But a
286        table only becomes a selected source if it's included in a FROM or JOIN clause.
287
288        Returns:
289            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
290        """
291        if self._selected_sources is None:
292            referenced_names = []
293
294            for table in self.tables:
295                referenced_names.append((table.alias_or_name, table))
296            for expression in itertools.chain(self.derived_tables, self.udtfs):
297                referenced_names.append((expression.alias, expression.unnest()))
298            result = {}
299
300            for name, node in referenced_names:
301                if name in result:
302                    raise OptimizeError(f"Alias already used: {name}")
303                if name in self.sources:
304                    result[name] = (node, self.sources[name])
305
306            self._selected_sources = result
307        return self._selected_sources
308
309    @property
310    def cte_sources(self):
311        """
312        Sources that are CTEs.
313
314        Returns:
315            dict[str, Scope]: Mapping of source alias to Scope
316        """
317        return {
318            alias: scope
319            for alias, scope in self.sources.items()
320            if isinstance(scope, Scope) and scope.is_cte
321        }
322
323    @property
324    def selects(self):
325        """
326        Select expressions of this scope.
327
328        For example, for the following expression:
329            SELECT 1 as a, 2 as b FROM x
330
331        The outputs are the "1 as a" and "2 as b" expressions.
332
333        Returns:
334            list[exp.Expression]: expressions
335        """
336        if isinstance(self.expression, exp.Union):
337            return self.expression.unnest().selects
338        return self.expression.selects
339
340    @property
341    def external_columns(self):
342        """
343        Columns that appear to reference sources in outer scopes.
344
345        Returns:
346            list[exp.Column]: Column instances that don't reference
347                sources in the current scope.
348        """
349        if self._external_columns is None:
350            self._external_columns = [
351                c for c in self.columns if c.table not in self.selected_sources
352            ]
353        return self._external_columns
354
355    @property
356    def unqualified_columns(self):
357        """
358        Unqualified columns in the current scope.
359
360        Returns:
361             list[exp.Column]: Unqualified columns
362        """
363        return [c for c in self.columns if not c.table]
364
365    @property
366    def join_hints(self):
367        """
368        Hints that exist in the scope that reference tables
369
370        Returns:
371            list[exp.JoinHint]: Join hints that are referenced within the scope
372        """
373        if self._join_hints is None:
374            return []
375        return self._join_hints
376
377    @property
378    def pivots(self):
379        if not self._pivots:
380            self._pivots = [
381                pivot
382                for node in self.tables + self.derived_tables
383                for pivot in node.args.get("pivots") or []
384            ]
385
386        return self._pivots
387
388    def source_columns(self, source_name):
389        """
390        Get all columns in the current scope for a particular source.
391
392        Args:
393            source_name (str): Name of the source
394        Returns:
395            list[exp.Column]: Column instances that reference `source_name`
396        """
397        return [column for column in self.columns if column.table == source_name]
398
399    @property
400    def is_subquery(self):
401        """Determine if this scope is a subquery"""
402        return self.scope_type == ScopeType.SUBQUERY
403
404    @property
405    def is_derived_table(self):
406        """Determine if this scope is a derived table"""
407        return self.scope_type == ScopeType.DERIVED_TABLE
408
409    @property
410    def is_union(self):
411        """Determine if this scope is a union"""
412        return self.scope_type == ScopeType.UNION
413
414    @property
415    def is_cte(self):
416        """Determine if this scope is a common table expression"""
417        return self.scope_type == ScopeType.CTE
418
419    @property
420    def is_root(self):
421        """Determine if this is the root scope"""
422        return self.scope_type == ScopeType.ROOT
423
424    @property
425    def is_udtf(self):
426        """Determine if this scope is a UDTF (User Defined Table Function)"""
427        return self.scope_type == ScopeType.UDTF
428
429    @property
430    def is_correlated_subquery(self):
431        """Determine if this scope is a correlated subquery"""
432        return bool(self.is_subquery and self.external_columns)
433
434    def rename_source(self, old_name, new_name):
435        """Rename a source in this scope"""
436        columns = self.sources.pop(old_name or "", [])
437        self.sources[new_name] = columns
438
439    def add_source(self, name, source):
440        """Add a source to this scope"""
441        self.sources[name] = source
442        self.clear_cache()
443
444    def remove_source(self, name):
445        """Remove a source from this scope"""
446        self.sources.pop(name, None)
447        self.clear_cache()
448
449    def __repr__(self):
450        return f"Scope<{self.expression.sql()}>"
451
452    def traverse(self):
453        """
454        Traverse the scope tree from this node.
455
456        Yields:
457            Scope: scope instances in depth-first-search post-order
458        """
459        for child_scope in itertools.chain(
460            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
461        ):
462            yield from child_scope.traverse()
463        yield self
464
465    def ref_count(self):
466        """
467        Count the number of times each scope in this tree is referenced.
468
469        Returns:
470            dict[int, int]: Mapping of Scope instance ID to reference count
471        """
472        scope_ref_count = defaultdict(lambda: 0)
473
474        for scope in self.traverse():
475            for _, source in scope.selected_sources.values():
476                scope_ref_count[id(source)] += 1
477
478        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.Union): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
  • outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) AS y(col1, col2) The inner query would have ["col1", "col2"] for its outer_column_list
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_column_list=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None)
51    def __init__(
52        self,
53        expression,
54        sources=None,
55        outer_column_list=None,
56        parent=None,
57        scope_type=ScopeType.ROOT,
58        lateral_sources=None,
59    ):
60        self.expression = expression
61        self.sources = sources or {}
62        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
63        self.sources.update(self.lateral_sources)
64        self.outer_column_list = outer_column_list or []
65        self.parent = parent
66        self.scope_type = scope_type
67        self.subquery_scopes = []
68        self.derived_table_scopes = []
69        self.table_scopes = []
70        self.cte_scopes = []
71        self.union_scopes = []
72        self.udtf_scopes = []
73        self.clear_cache()
def clear_cache(self):
75    def clear_cache(self):
76        self._collected = False
77        self._raw_columns = None
78        self._derived_tables = None
79        self._udtfs = None
80        self._tables = None
81        self._ctes = None
82        self._subqueries = None
83        self._selected_sources = None
84        self._columns = None
85        self._external_columns = None
86        self._join_hints = None
87        self._pivots = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
89    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
90        """Branch from the current scope to a new, inner scope"""
91        return Scope(
92            expression=expression.unnest(),
93            sources={**self.cte_sources, **(chain_sources or {})},
94            parent=self,
95            scope_type=scope_type,
96            **kwargs,
97        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True):
132    def walk(self, bfs=True):
133        return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
135    def find(self, *expression_types, bfs=True):
136        """
137        Returns the first node in this scope which matches at least one of the specified types.
138
139        This does NOT traverse into subscopes.
140
141        Args:
142            expression_types (type): the expression type(s) to match.
143            bfs (bool): True to use breadth-first search, False to use depth-first.
144
145        Returns:
146            exp.Expression: the node which matches the criteria or None if no node matching
147            the criteria was found.
148        """
149        return next(self.find_all(*expression_types, bfs=bfs), None)

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression_types (type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.

def find_all(self, *expression_types, bfs=True):
151    def find_all(self, *expression_types, bfs=True):
152        """
153        Returns a generator object which visits all nodes in this scope and only yields those that
154        match at least one of the specified expression types.
155
156        This does NOT traverse into subscopes.
157
158        Args:
159            expression_types (type): the expression type(s) to match.
160            bfs (bool): True to use breadth-first search, False to use depth-first.
161
162        Yields:
163            exp.Expression: nodes
164        """
165        for expression, *_ in self.walk(bfs=bfs):
166            if isinstance(expression, expression_types):
167                yield expression

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression_types (type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def replace(self, old, new):
169    def replace(self, old, new):
170        """
171        Replace `old` with `new`.
172
173        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
174
175        Args:
176            old (exp.Expression): old node
177            new (exp.Expression): new node
178        """
179        old.replace(new)
180        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Subqueryable]: subqueries

columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

cte_sources

Sources that are CTEs.

Returns:

dict[str, Scope]: Mapping of source alias to Scope

selects

Select expressions of this scope.

For example, for the following expression: SELECT 1 as a, 2 as b FROM x

The outputs are the "1 as a" and "2 as b" expressions.

Returns:

list[exp.Expression]: expressions

external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

def source_columns(self, source_name):
388    def source_columns(self, source_name):
389        """
390        Get all columns in the current scope for a particular source.
391
392        Args:
393            source_name (str): Name of the source
394        Returns:
395            list[exp.Column]: Column instances that reference `source_name`
396        """
397        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery

Determine if this scope is a subquery

is_derived_table

Determine if this scope is a derived table

is_union

Determine if this scope is a union

is_cte

Determine if this scope is a common table expression

is_root

Determine if this is the root scope

is_udtf

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
434    def rename_source(self, old_name, new_name):
435        """Rename a source in this scope"""
436        columns = self.sources.pop(old_name or "", [])
437        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
439    def add_source(self, name, source):
440        """Add a source to this scope"""
441        self.sources[name] = source
442        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
444    def remove_source(self, name):
445        """Remove a source from this scope"""
446        self.sources.pop(name, None)
447        self.clear_cache()

Remove a source from this scope

def traverse(self):
452    def traverse(self):
453        """
454        Traverse the scope tree from this node.
455
456        Yields:
457            Scope: scope instances in depth-first-search post-order
458        """
459        for child_scope in itertools.chain(
460            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
461        ):
462            yield from child_scope.traverse()
463        yield self

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
465    def ref_count(self):
466        """
467        Count the number of times each scope in this tree is referenced.
468
469        Returns:
470            dict[int, int]: Mapping of Scope instance ID to reference count
471        """
472        scope_ref_count = defaultdict(lambda: 0)
473
474        for scope in self.traverse():
475            for _, source in scope.selected_sources.values():
476                scope_ref_count[id(source)] += 1
477
478        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope(expression):
def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. A list is returned rather than a generator, because
    a partially consumed generator could expose scopes with incomplete properties,
    which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    root_scope = Scope(expression)
    return [scope for scope in _traverse_scope(root_scope)]

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression (exp.Expression): expression to traverse
Returns:

list[Scope]: scope instances

def build_scope(expression):
def build_scope(expression):
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    # The root scope is the last entry of the list returned by `traverse_scope`.
    scopes = traverse_scope(expression)
    return scopes[-1]

Build a scope tree.

Arguments:
  • expression (exp.Expression): expression to build the scope tree for
Returns:

Scope: root scope

def walk_in_scope(expression, bfs=True):
def walk_in_scope(expression, bfs=True):
    """
    Walk the syntax tree rooted at `expression` without descending into child scopes.

    Nodes that start a child scope are still yielded themselves, but their subtrees
    are excluded from the traversal.

    Args:
        expression (exp.Expression): root expression to walk from
        bfs (bool): if True the traversal is breadth-first, otherwise depth-first.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # Flag shared with the `prune` callback handed to `expression.walk`: it is reset
    # for every node and raised right after yielding a node whose subtree is skipped.
    skip_subtree = False

    def should_prune(*_):
        return skip_subtree

    for node, parent, key in expression.walk(bfs=bfs, prune=should_prune):
        skip_subtree = False

        yield node, parent, key

        # The root expression itself is never pruned.
        if node is expression:
            continue

        is_derived_table = isinstance(node, exp.Subquery) and isinstance(
            parent, (exp.From, exp.Join)
        )
        if is_derived_table or isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)):
            skip_subtree = True

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:

tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key