Edit on GitHub

sqlglot.optimizer.scope

  1from __future__ import annotations
  2
  3import itertools
  4import logging
  5import typing as t
  6from collections import defaultdict
  7from enum import Enum, auto
  8
  9from sqlglot import exp
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import ensure_collection, find_new_name
 12
 13logger = logging.getLogger("sqlglot")
 14
 15
 16class ScopeType(Enum):
 17    ROOT = auto()
 18    SUBQUERY = auto()
 19    DERIVED_TABLE = auto()
 20    CTE = auto()
 21    UNION = auto()
 22    UDTF = auto()
 23
 24
 25class Scope:
 26    """
 27    Selection scope.
 28
 29    Attributes:
 30        expression (exp.Select|exp.Union): Root expression of this scope
 31        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 32            a Table expression or another Scope instance. For example:
 33                SELECT * FROM x                     {"x": Table(this="x")}
 34                SELECT * FROM x AS y                {"y": Table(this="x")}
 35                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 36        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 37            For example:
 38                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 39            The LATERAL VIEW EXPLODE gets x as a source.
 40        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 41            defines a column list of it's alias of this scope, this is that list of columns.
 42            For example:
 43                SELECT * FROM (SELECT ...) AS y(col1, col2)
 44            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 45        parent (Scope): Parent scope
 46        scope_type (ScopeType): Type of this scope, relative to it's parent
 47        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 48        cte_scopes (list[Scope]): List of all child scopes for CTEs
 49        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 50        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 51        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 52        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 53            a list of the left and right child scopes.
 54    """
 55
 56    def __init__(
 57        self,
 58        expression,
 59        sources=None,
 60        outer_column_list=None,
 61        parent=None,
 62        scope_type=ScopeType.ROOT,
 63        lateral_sources=None,
 64    ):
 65        self.expression = expression
 66        self.sources = sources or {}
 67        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 68        self.sources.update(self.lateral_sources)
 69        self.outer_column_list = outer_column_list or []
 70        self.parent = parent
 71        self.scope_type = scope_type
 72        self.subquery_scopes = []
 73        self.derived_table_scopes = []
 74        self.table_scopes = []
 75        self.cte_scopes = []
 76        self.union_scopes = []
 77        self.udtf_scopes = []
 78        self.clear_cache()
 79
 80    def clear_cache(self):
 81        self._collected = False
 82        self._raw_columns = None
 83        self._derived_tables = None
 84        self._udtfs = None
 85        self._tables = None
 86        self._ctes = None
 87        self._subqueries = None
 88        self._selected_sources = None
 89        self._columns = None
 90        self._external_columns = None
 91        self._join_hints = None
 92        self._pivots = None
 93        self._references = None
 94
 95    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 96        """Branch from the current scope to a new, inner scope"""
 97        return Scope(
 98            expression=expression.unnest(),
 99            sources={**self.cte_sources, **(chain_sources or {})},
100            parent=self,
101            scope_type=scope_type,
102            **kwargs,
103        )
104
105    def _collect(self):
106        self._tables = []
107        self._ctes = []
108        self._subqueries = []
109        self._derived_tables = []
110        self._udtfs = []
111        self._raw_columns = []
112        self._join_hints = []
113
114        for node, parent, _ in self.walk(bfs=False):
115            if node is self.expression:
116                continue
117            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
118                self._raw_columns.append(node)
119            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
120                self._tables.append(node)
121            elif isinstance(node, exp.JoinHint):
122                self._join_hints.append(node)
123            elif isinstance(node, exp.UDTF):
124                self._udtfs.append(node)
125            elif isinstance(node, exp.CTE):
126                self._ctes.append(node)
127            elif (
128                isinstance(node, exp.Subquery)
129                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
130                and _is_derived_table(node)
131            ):
132                self._derived_tables.append(node)
133            elif isinstance(node, exp.Subqueryable):
134                self._subqueries.append(node)
135
136        self._collected = True
137
138    def _ensure_collected(self):
139        if not self._collected:
140            self._collect()
141
142    def walk(self, bfs=True, prune=None):
143        return walk_in_scope(self.expression, bfs=bfs, prune=None)
144
145    def find(self, *expression_types, bfs=True):
146        return find_in_scope(self.expression, expression_types, bfs=bfs)
147
148    def find_all(self, *expression_types, bfs=True):
149        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
150
151    def replace(self, old, new):
152        """
153        Replace `old` with `new`.
154
155        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
156
157        Args:
158            old (exp.Expression): old node
159            new (exp.Expression): new node
160        """
161        old.replace(new)
162        self.clear_cache()
163
164    @property
165    def tables(self):
166        """
167        List of tables in this scope.
168
169        Returns:
170            list[exp.Table]: tables
171        """
172        self._ensure_collected()
173        return self._tables
174
175    @property
176    def ctes(self):
177        """
178        List of CTEs in this scope.
179
180        Returns:
181            list[exp.CTE]: ctes
182        """
183        self._ensure_collected()
184        return self._ctes
185
186    @property
187    def derived_tables(self):
188        """
189        List of derived tables in this scope.
190
191        For example:
192            SELECT * FROM (SELECT ...) <- that's a derived table
193
194        Returns:
195            list[exp.Subquery]: derived tables
196        """
197        self._ensure_collected()
198        return self._derived_tables
199
200    @property
201    def udtfs(self):
202        """
203        List of "User Defined Tabular Functions" in this scope.
204
205        Returns:
206            list[exp.UDTF]: UDTFs
207        """
208        self._ensure_collected()
209        return self._udtfs
210
211    @property
212    def subqueries(self):
213        """
214        List of subqueries in this scope.
215
216        For example:
217            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
218
219        Returns:
220            list[exp.Subqueryable]: subqueries
221        """
222        self._ensure_collected()
223        return self._subqueries
224
225    @property
226    def columns(self):
227        """
228        List of columns in this scope.
229
230        Returns:
231            list[exp.Column]: Column instances in this scope, plus any
232                Columns that reference this scope from correlated subqueries.
233        """
234        if self._columns is None:
235            self._ensure_collected()
236            columns = self._raw_columns
237
238            external_columns = [
239                column
240                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
241                for column in scope.external_columns
242            ]
243
244            named_selects = set(self.expression.named_selects)
245
246            self._columns = []
247            for column in columns + external_columns:
248                ancestor = column.find_ancestor(
249                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
250                )
251                if (
252                    not ancestor
253                    or column.table
254                    or isinstance(ancestor, exp.Select)
255                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
256                    or (
257                        isinstance(ancestor, exp.Order)
258                        and (
259                            isinstance(ancestor.parent, exp.Window)
260                            or column.name not in named_selects
261                        )
262                    )
263                ):
264                    self._columns.append(column)
265
266        return self._columns
267
268    @property
269    def selected_sources(self):
270        """
271        Mapping of nodes and sources that are actually selected from in this scope.
272
273        That is, all tables in a schema are selectable at any point. But a
274        table only becomes a selected source if it's included in a FROM or JOIN clause.
275
276        Returns:
277            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
278        """
279        if self._selected_sources is None:
280            result = {}
281
282            for name, node in self.references:
283                if name in result:
284                    raise OptimizeError(f"Alias already used: {name}")
285                if name in self.sources:
286                    result[name] = (node, self.sources[name])
287
288            self._selected_sources = result
289        return self._selected_sources
290
291    @property
292    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
293        if self._references is None:
294            self._references = []
295
296            for table in self.tables:
297                self._references.append((table.alias_or_name, table))
298            for expression in itertools.chain(self.derived_tables, self.udtfs):
299                self._references.append(
300                    (
301                        expression.alias,
302                        expression if expression.args.get("pivots") else expression.unnest(),
303                    )
304                )
305
306        return self._references
307
308    @property
309    def cte_sources(self):
310        """
311        Sources that are CTEs.
312
313        Returns:
314            dict[str, Scope]: Mapping of source alias to Scope
315        """
316        return {
317            alias: scope
318            for alias, scope in self.sources.items()
319            if isinstance(scope, Scope) and scope.is_cte
320        }
321
322    @property
323    def external_columns(self):
324        """
325        Columns that appear to reference sources in outer scopes.
326
327        Returns:
328            list[exp.Column]: Column instances that don't reference
329                sources in the current scope.
330        """
331        if self._external_columns is None:
332            self._external_columns = [
333                c for c in self.columns if c.table not in self.selected_sources
334            ]
335        return self._external_columns
336
337    @property
338    def unqualified_columns(self):
339        """
340        Unqualified columns in the current scope.
341
342        Returns:
343             list[exp.Column]: Unqualified columns
344        """
345        return [c for c in self.columns if not c.table]
346
347    @property
348    def join_hints(self):
349        """
350        Hints that exist in the scope that reference tables
351
352        Returns:
353            list[exp.JoinHint]: Join hints that are referenced within the scope
354        """
355        if self._join_hints is None:
356            return []
357        return self._join_hints
358
359    @property
360    def pivots(self):
361        if not self._pivots:
362            self._pivots = [
363                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
364            ]
365
366        return self._pivots
367
368    def source_columns(self, source_name):
369        """
370        Get all columns in the current scope for a particular source.
371
372        Args:
373            source_name (str): Name of the source
374        Returns:
375            list[exp.Column]: Column instances that reference `source_name`
376        """
377        return [column for column in self.columns if column.table == source_name]
378
379    @property
380    def is_subquery(self):
381        """Determine if this scope is a subquery"""
382        return self.scope_type == ScopeType.SUBQUERY
383
384    @property
385    def is_derived_table(self):
386        """Determine if this scope is a derived table"""
387        return self.scope_type == ScopeType.DERIVED_TABLE
388
389    @property
390    def is_union(self):
391        """Determine if this scope is a union"""
392        return self.scope_type == ScopeType.UNION
393
394    @property
395    def is_cte(self):
396        """Determine if this scope is a common table expression"""
397        return self.scope_type == ScopeType.CTE
398
399    @property
400    def is_root(self):
401        """Determine if this is the root scope"""
402        return self.scope_type == ScopeType.ROOT
403
404    @property
405    def is_udtf(self):
406        """Determine if this scope is a UDTF (User Defined Table Function)"""
407        return self.scope_type == ScopeType.UDTF
408
409    @property
410    def is_correlated_subquery(self):
411        """Determine if this scope is a correlated subquery"""
412        return bool(
413            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
414            and self.external_columns
415        )
416
417    def rename_source(self, old_name, new_name):
418        """Rename a source in this scope"""
419        columns = self.sources.pop(old_name or "", [])
420        self.sources[new_name] = columns
421
422    def add_source(self, name, source):
423        """Add a source to this scope"""
424        self.sources[name] = source
425        self.clear_cache()
426
427    def remove_source(self, name):
428        """Remove a source from this scope"""
429        self.sources.pop(name, None)
430        self.clear_cache()
431
432    def __repr__(self):
433        return f"Scope<{self.expression.sql()}>"
434
435    def traverse(self):
436        """
437        Traverse the scope tree from this node.
438
439        Yields:
440            Scope: scope instances in depth-first-search post-order
441        """
442        for child_scope in itertools.chain(
443            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
444        ):
445            yield from child_scope.traverse()
446        yield self
447
448    def ref_count(self):
449        """
450        Count the number of times each scope in this tree is referenced.
451
452        Returns:
453            dict[int, int]: Mapping of Scope instance ID to reference count
454        """
455        scope_ref_count = defaultdict(lambda: 0)
456
457        for scope in self.traverse():
458            for _, source in scope.selected_sources.values():
459                scope_ref_count[id(source)] += 1
460
461        return scope_ref_count
462
463
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Only unionable expressions, and DDL statements whose body is a query, are
    traversed; anything else produces an empty list.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    if not isinstance(expression, exp.Unionable):
        wraps_query = isinstance(expression, exp.DDL) and isinstance(
            expression.expression, exp.Subqueryable
        )
        if not wraps_query:
            return []

    root = Scope(expression)
    return list(_traverse_scope(root))
495
496
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, i.e. the last scope produced by `traverse_scope`, or
            None when the expression yields no scopes
    """
    all_scopes = traverse_scope(expression)
    return all_scopes[-1] if all_scopes else None
510
511
def _traverse_scope(scope):
    """
    Yield every scope nested in `scope`, followed by `scope` itself.

    The traversal routine is picked from the type of `scope.expression`. When the
    type is not supported, a warning is logged and nothing is yielded.
    """
    # Order matters: the first matching type wins, exactly like an if/elif chain.
    handlers = (
        (exp.Select, _traverse_select),
        (exp.Union, _traverse_union),
        (exp.Subquery, _traverse_subqueries),
        (exp.Table, _traverse_tables),
        (exp.UDTF, _traverse_udtfs),
        (exp.DDL, _traverse_ddl),
    )

    for expression_type, handler in handlers:
        if isinstance(scope.expression, expression_type):
            yield from handler(scope)
            break
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    yield scope
532
533
def _traverse_select(scope):
    """Yield the child scopes of a SELECT: CTEs first, then table sources, then subqueries."""
    for traverse_part in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_part(scope)
538
539
def _traverse_union(scope):
    """
    Yield the child scopes of a union: its CTEs, then the left and the right operand.

    Afterwards `scope.union_scopes` holds the last scope yielded for each operand
    (None for an operand that yielded nothing).
    """
    yield from _traverse_ctes(scope)

    operand_scopes = []
    for side in ("left", "right"):
        # Looked up lazily so that each operand is read right before it is traversed.
        operand = getattr(scope.expression, side)

        # The last scope to be yielded is the top-most scope of the operand.
        top = None
        for top in _traverse_scope(scope.branch(operand, scope_type=ScopeType.UNION)):
            yield top

        operand_scopes.append(top)

    scope.union_scopes = operand_scopes
553
554
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE in `scope` and register them as sources.

    Each CTE can see the CTEs defined before it. The last scope yielded for a CTE
    is appended to `scope.cte_scopes`, and once all CTEs are traversed they are
    merged into `scope.sources` under their aliases.
    """
    cte_sources = {}

    for cte in scope.ctes:
        # If the scope is a recursive CTE, it must be in the form of base_case UNION recursive,
        # thus the recursive scope is the first section of the union (`union.this`).
        base_case_scope = None
        with_clause = scope.expression.args.get("with")
        if with_clause and with_clause.recursive:
            cte_query = cte.this
            if isinstance(cte_query, exp.Union):
                base_case_scope = scope.branch(cte_query.this, scope_type=ScopeType.CTE)

        cte_scope = scope.branch(
            cte.this,
            chain_sources=cte_sources,
            outer_column_list=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )

        last_scope = None
        for last_scope in _traverse_scope(cte_scope):
            yield last_scope

            cte_name = cte.alias
            cte_sources[cte_name] = last_scope

            if base_case_scope:
                last_scope.add_source(cte_name, base_case_scope)

        # Only the final scope yielded for the CTE is recorded as a child scope.
        if last_scope:
            scope.cte_scopes.append(last_scope)

    scope.sources.update(cte_sources)
593
594
595def _is_derived_table(expression: exp.Subquery) -> bool:
596    """
597    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
598    as it doesn't introduce a new scope. If an alias is present, it shadows all names
599    under the Subquery, so that's one exception to this rule.
600    """
601    return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))
602
603
def _traverse_tables(scope):
    """
    Yield the scopes of the table sources of `scope` and register its sources.

    Tables, derived tables and UDTFs found in the FROM clause, the joins and the
    laterals are visited in the order they are defined. Plain tables are only
    recorded as sources; derived tables and UDTFs are traversed as child scopes.

    Yields:
        Scope: every scope nested in a derived table or UDTF of this scope
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` doubles as a worklist: nested joins and wrapped tables are appended
    # to it while it is being iterated, and get visited later in this same loop.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            # A UDTF can see the sources defined before it.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset for every expression: if the traversal below yields nothing (unsupported
        # expression type), we must neither fail on an unbound name nor register the
        # scope left over from a previous expression.
        child_scope = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded, if there was one
        if child_scope is not None:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
684
685
def _traverse_subqueries(scope):
    """
    Yield the scopes of every subquery in `scope`.

    The last scope yielded for each subquery is appended to `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        subquery_scope = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)

        last_scope = None
        for last_scope in _traverse_scope(subquery_scope):
            yield last_scope

        scope.subquery_scopes.append(last_scope)
693
694
def _traverse_udtfs(scope):
    """
    Yield the scopes of the derived tables nested in a UDTF scope.

    Only UNNEST (its expressions) and LATERAL (its `this`) are inspected. Every
    derived table found is traversed, recorded in `scope.derived_table_scopes` and
    `scope.table_scopes`, and registered as a source under its alias.
    """
    udtf = scope.expression
    if isinstance(udtf, exp.Unnest):
        candidates = udtf.expressions
    elif isinstance(udtf, exp.Lateral):
        candidates = [udtf.this]
    else:
        candidates = []

    derived_sources = {}
    for candidate in candidates:
        if not isinstance(candidate, exp.Subquery) or not _is_derived_table(candidate):
            continue

        derived_scope = scope.branch(
            candidate,
            scope_type=ScopeType.DERIVED_TABLE,
            outer_column_list=candidate.alias_column_names,
        )

        last_scope = None
        for last_scope in _traverse_scope(derived_scope):
            yield last_scope
            derived_sources[candidate.alias] = last_scope

        scope.derived_table_scopes.append(last_scope)
        scope.table_scopes.append(last_scope)

    scope.sources.update(derived_sources)
722
723
def _traverse_ddl(scope):
    """
    Yield the scopes of a DDL statement: its CTEs, then the query it wraps.

    The query is traversed as a derived table scope that sees the sources of the
    DDL scope, and the CTEs of the DDL statement are listed before the query's own.
    """
    yield from _traverse_ctes(scope)

    inner_scope = scope.branch(
        scope.expression.expression,
        scope_type=ScopeType.DERIVED_TABLE,
        chain_sources=scope.sources,
    )

    # Collect eagerly so that the CTE list can be extended before traversal reads it.
    inner_scope._collect()
    inner_scope._ctes = [*scope.ctes, *inner_scope._ctes]

    yield from _traverse_scope(inner_scope)
734
735
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # State shared with the prune callback handed to `expression.walk`.
    # Whenever it is True, the subtree of the node just yielded is excluded.
    skip_subtree = False

    def _should_prune(*args):
        return skip_subtree or (prune and prune(*args))

    for node, parent, key in expression.walk(bfs=bfs, prune=_should_prune):
        skip_subtree = False

        yield node, parent, key

        if node is expression:
            continue

        starts_child_scope = isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)) or (
            isinstance(node, exp.Subquery)
            and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
            and _is_derived_table(node)
        )
        if not starts_child_scope:
            continue

        skip_subtree = True

        if isinstance(node, (exp.Subquery, exp.UDTF)):
            # The following args are not actually in the inner scope, so we should visit them
            for arg_name in ("joins", "laterals", "pivots"):
                for arg in node.args.get(arg_name) or []:
                    yield from walk_in_scope(arg, bfs=bfs)
781
782
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the types once instead of for every visited node.
    types = tuple(ensure_collection(expression_types))

    # A separate loop name keeps the `expression` argument from being shadowed.
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
801
802
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
logger = <Logger sqlglot (WARNING)>
class ScopeType(enum.Enum):
17class ScopeType(Enum):
18    ROOT = auto()
19    SUBQUERY = auto()
20    DERIVED_TABLE = auto()
21    CTE = auto()
22    UNION = auto()
23    UDTF = auto()

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
Inherited Members
enum.Enum
name
value
class Scope:
 26class Scope:
 27    """
 28    Selection scope.
 29
 30    Attributes:
 31        expression (exp.Select|exp.Union): Root expression of this scope
 32        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 33            a Table expression or another Scope instance. For example:
 34                SELECT * FROM x                     {"x": Table(this="x")}
 35                SELECT * FROM x AS y                {"y": Table(this="x")}
 36                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 37        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 38            For example:
 39                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 40            The LATERAL VIEW EXPLODE gets x as a source.
 41        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 42            defines a column list of it's alias of this scope, this is that list of columns.
 43            For example:
 44                SELECT * FROM (SELECT ...) AS y(col1, col2)
 45            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 46        parent (Scope): Parent scope
 47        scope_type (ScopeType): Type of this scope, relative to it's parent
 48        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 49        cte_scopes (list[Scope]): List of all child scopes for CTEs
 50        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 51        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 52        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 53        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 54            a list of the left and right child scopes.
 55    """
 56
 57    def __init__(
 58        self,
 59        expression,
 60        sources=None,
 61        outer_column_list=None,
 62        parent=None,
 63        scope_type=ScopeType.ROOT,
 64        lateral_sources=None,
 65    ):
 66        self.expression = expression
 67        self.sources = sources or {}
 68        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 69        self.sources.update(self.lateral_sources)
 70        self.outer_column_list = outer_column_list or []
 71        self.parent = parent
 72        self.scope_type = scope_type
 73        self.subquery_scopes = []
 74        self.derived_table_scopes = []
 75        self.table_scopes = []
 76        self.cte_scopes = []
 77        self.union_scopes = []
 78        self.udtf_scopes = []
 79        self.clear_cache()
 80
 81    def clear_cache(self):
 82        self._collected = False
 83        self._raw_columns = None
 84        self._derived_tables = None
 85        self._udtfs = None
 86        self._tables = None
 87        self._ctes = None
 88        self._subqueries = None
 89        self._selected_sources = None
 90        self._columns = None
 91        self._external_columns = None
 92        self._join_hints = None
 93        self._pivots = None
 94        self._references = None
 95
 96    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 97        """Branch from the current scope to a new, inner scope"""
 98        return Scope(
 99            expression=expression.unnest(),
100            sources={**self.cte_sources, **(chain_sources or {})},
101            parent=self,
102            scope_type=scope_type,
103            **kwargs,
104        )
105
106    def _collect(self):
107        self._tables = []
108        self._ctes = []
109        self._subqueries = []
110        self._derived_tables = []
111        self._udtfs = []
112        self._raw_columns = []
113        self._join_hints = []
114
115        for node, parent, _ in self.walk(bfs=False):
116            if node is self.expression:
117                continue
118            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
119                self._raw_columns.append(node)
120            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
121                self._tables.append(node)
122            elif isinstance(node, exp.JoinHint):
123                self._join_hints.append(node)
124            elif isinstance(node, exp.UDTF):
125                self._udtfs.append(node)
126            elif isinstance(node, exp.CTE):
127                self._ctes.append(node)
128            elif (
129                isinstance(node, exp.Subquery)
130                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
131                and _is_derived_table(node)
132            ):
133                self._derived_tables.append(node)
134            elif isinstance(node, exp.Subqueryable):
135                self._subqueries.append(node)
136
137        self._collected = True
138
139    def _ensure_collected(self):
140        if not self._collected:
141            self._collect()
142
143    def walk(self, bfs=True, prune=None):
144        return walk_in_scope(self.expression, bfs=bfs, prune=None)
145
146    def find(self, *expression_types, bfs=True):
147        return find_in_scope(self.expression, expression_types, bfs=bfs)
148
149    def find_all(self, *expression_types, bfs=True):
150        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
151
152    def replace(self, old, new):
153        """
154        Replace `old` with `new`.
155
156        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
157
158        Args:
159            old (exp.Expression): old node
160            new (exp.Expression): new node
161        """
162        old.replace(new)
163        self.clear_cache()
164
165    @property
166    def tables(self):
167        """
168        List of tables in this scope.
169
170        Returns:
171            list[exp.Table]: tables
172        """
173        self._ensure_collected()
174        return self._tables
175
176    @property
177    def ctes(self):
178        """
179        List of CTEs in this scope.
180
181        Returns:
182            list[exp.CTE]: ctes
183        """
184        self._ensure_collected()
185        return self._ctes
186
187    @property
188    def derived_tables(self):
189        """
190        List of derived tables in this scope.
191
192        For example:
193            SELECT * FROM (SELECT ...) <- that's a derived table
194
195        Returns:
196            list[exp.Subquery]: derived tables
197        """
198        self._ensure_collected()
199        return self._derived_tables
200
201    @property
202    def udtfs(self):
203        """
204        List of "User Defined Tabular Functions" in this scope.
205
206        Returns:
207            list[exp.UDTF]: UDTFs
208        """
209        self._ensure_collected()
210        return self._udtfs
211
212    @property
213    def subqueries(self):
214        """
215        List of subqueries in this scope.
216
217        For example:
218            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
219
220        Returns:
221            list[exp.Subqueryable]: subqueries
222        """
223        self._ensure_collected()
224        return self._subqueries
225
226    @property
227    def columns(self):
228        """
229        List of columns in this scope.
230
231        Returns:
232            list[exp.Column]: Column instances in this scope, plus any
233                Columns that reference this scope from correlated subqueries.
234        """
235        if self._columns is None:
236            self._ensure_collected()
237            columns = self._raw_columns
238
239            external_columns = [
240                column
241                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
242                for column in scope.external_columns
243            ]
244
245            named_selects = set(self.expression.named_selects)
246
247            self._columns = []
248            for column in columns + external_columns:
249                ancestor = column.find_ancestor(
250                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
251                )
252                if (
253                    not ancestor
254                    or column.table
255                    or isinstance(ancestor, exp.Select)
256                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
257                    or (
258                        isinstance(ancestor, exp.Order)
259                        and (
260                            isinstance(ancestor.parent, exp.Window)
261                            or column.name not in named_selects
262                        )
263                    )
264                ):
265                    self._columns.append(column)
266
267        return self._columns
268
269    @property
270    def selected_sources(self):
271        """
272        Mapping of nodes and sources that are actually selected from in this scope.
273
274        That is, all tables in a schema are selectable at any point. But a
275        table only becomes a selected source if it's included in a FROM or JOIN clause.
276
277        Returns:
278            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
279        """
280        if self._selected_sources is None:
281            result = {}
282
283            for name, node in self.references:
284                if name in result:
285                    raise OptimizeError(f"Alias already used: {name}")
286                if name in self.sources:
287                    result[name] = (node, self.sources[name])
288
289            self._selected_sources = result
290        return self._selected_sources
291
292    @property
293    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
294        if self._references is None:
295            self._references = []
296
297            for table in self.tables:
298                self._references.append((table.alias_or_name, table))
299            for expression in itertools.chain(self.derived_tables, self.udtfs):
300                self._references.append(
301                    (
302                        expression.alias,
303                        expression if expression.args.get("pivots") else expression.unnest(),
304                    )
305                )
306
307        return self._references
308
309    @property
310    def cte_sources(self):
311        """
312        Sources that are CTEs.
313
314        Returns:
315            dict[str, Scope]: Mapping of source alias to Scope
316        """
317        return {
318            alias: scope
319            for alias, scope in self.sources.items()
320            if isinstance(scope, Scope) and scope.is_cte
321        }
322
323    @property
324    def external_columns(self):
325        """
326        Columns that appear to reference sources in outer scopes.
327
328        Returns:
329            list[exp.Column]: Column instances that don't reference
330                sources in the current scope.
331        """
332        if self._external_columns is None:
333            self._external_columns = [
334                c for c in self.columns if c.table not in self.selected_sources
335            ]
336        return self._external_columns
337
338    @property
339    def unqualified_columns(self):
340        """
341        Unqualified columns in the current scope.
342
343        Returns:
344             list[exp.Column]: Unqualified columns
345        """
346        return [c for c in self.columns if not c.table]
347
348    @property
349    def join_hints(self):
350        """
351        Hints that exist in the scope that reference tables
352
353        Returns:
354            list[exp.JoinHint]: Join hints that are referenced within the scope
355        """
356        if self._join_hints is None:
357            return []
358        return self._join_hints
359
360    @property
361    def pivots(self):
362        if not self._pivots:
363            self._pivots = [
364                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
365            ]
366
367        return self._pivots
368
369    def source_columns(self, source_name):
370        """
371        Get all columns in the current scope for a particular source.
372
373        Args:
374            source_name (str): Name of the source
375        Returns:
376            list[exp.Column]: Column instances that reference `source_name`
377        """
378        return [column for column in self.columns if column.table == source_name]
379
380    @property
381    def is_subquery(self):
382        """Determine if this scope is a subquery"""
383        return self.scope_type == ScopeType.SUBQUERY
384
385    @property
386    def is_derived_table(self):
387        """Determine if this scope is a derived table"""
388        return self.scope_type == ScopeType.DERIVED_TABLE
389
390    @property
391    def is_union(self):
392        """Determine if this scope is a union"""
393        return self.scope_type == ScopeType.UNION
394
395    @property
396    def is_cte(self):
397        """Determine if this scope is a common table expression"""
398        return self.scope_type == ScopeType.CTE
399
400    @property
401    def is_root(self):
402        """Determine if this is the root scope"""
403        return self.scope_type == ScopeType.ROOT
404
405    @property
406    def is_udtf(self):
407        """Determine if this scope is a UDTF (User Defined Table Function)"""
408        return self.scope_type == ScopeType.UDTF
409
410    @property
411    def is_correlated_subquery(self):
412        """Determine if this scope is a correlated subquery"""
413        return bool(
414            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
415            and self.external_columns
416        )
417
418    def rename_source(self, old_name, new_name):
419        """Rename a source in this scope"""
420        columns = self.sources.pop(old_name or "", [])
421        self.sources[new_name] = columns
422
423    def add_source(self, name, source):
424        """Add a source to this scope"""
425        self.sources[name] = source
426        self.clear_cache()
427
428    def remove_source(self, name):
429        """Remove a source from this scope"""
430        self.sources.pop(name, None)
431        self.clear_cache()
432
433    def __repr__(self):
434        return f"Scope<{self.expression.sql()}>"
435
436    def traverse(self):
437        """
438        Traverse the scope tree from this node.
439
440        Yields:
441            Scope: scope instances in depth-first-search post-order
442        """
443        for child_scope in itertools.chain(
444            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
445        ):
446            yield from child_scope.traverse()
447        yield self
448
449    def ref_count(self):
450        """
451        Count the number of times each scope in this tree is referenced.
452
453        Returns:
454            dict[int, int]: Mapping of Scope instance ID to reference count
455        """
456        scope_ref_count = defaultdict(lambda: 0)
457
458        for scope in self.traverse():
459            for _, source in scope.selected_sources.values():
460                scope_ref_count[id(source)] += 1
461
462        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.Union): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
  • outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) AS y(col1, col2) The inner query would have ["col1", "col2"] for its outer_column_list
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_column_list=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None)
57    def __init__(
58        self,
59        expression,
60        sources=None,
61        outer_column_list=None,
62        parent=None,
63        scope_type=ScopeType.ROOT,
64        lateral_sources=None,
65    ):
66        self.expression = expression
67        self.sources = sources or {}
68        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
69        self.sources.update(self.lateral_sources)
70        self.outer_column_list = outer_column_list or []
71        self.parent = parent
72        self.scope_type = scope_type
73        self.subquery_scopes = []
74        self.derived_table_scopes = []
75        self.table_scopes = []
76        self.cte_scopes = []
77        self.union_scopes = []
78        self.udtf_scopes = []
79        self.clear_cache()
expression
sources
lateral_sources
outer_column_list
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
def clear_cache(self):
81    def clear_cache(self):
82        self._collected = False
83        self._raw_columns = None
84        self._derived_tables = None
85        self._udtfs = None
86        self._tables = None
87        self._ctes = None
88        self._subqueries = None
89        self._selected_sources = None
90        self._columns = None
91        self._external_columns = None
92        self._join_hints = None
93        self._pivots = None
94        self._references = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 96    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 97        """Branch from the current scope to a new, inner scope"""
 98        return Scope(
 99            expression=expression.unnest(),
100            sources={**self.cte_sources, **(chain_sources or {})},
101            parent=self,
102            scope_type=scope_type,
103            **kwargs,
104        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True, prune=None):
143    def walk(self, bfs=True, prune=None):
144        return walk_in_scope(self.expression, bfs=bfs, prune=None)
def find(self, *expression_types, bfs=True):
146    def find(self, *expression_types, bfs=True):
147        return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
149    def find_all(self, *expression_types, bfs=True):
150        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
152    def replace(self, old, new):
153        """
154        Replace `old` with `new`.
155
156        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
157
158        Args:
159            old (exp.Expression): old node
160            new (exp.Expression): new node
161        """
162        old.replace(new)
163        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Subqueryable]: subqueries

columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
cte_sources

Sources that are CTEs.

Returns:

dict[str, Scope]: Mapping of source alias to Scope

external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
def source_columns(self, source_name):
369    def source_columns(self, source_name):
370        """
371        Get all columns in the current scope for a particular source.
372
373        Args:
374            source_name (str): Name of the source
375        Returns:
376            list[exp.Column]: Column instances that reference `source_name`
377        """
378        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery

Determine if this scope is a subquery

is_derived_table

Determine if this scope is a derived table

is_union

Determine if this scope is a union

is_cte

Determine if this scope is a common table expression

is_root

Determine if this is the root scope

is_udtf

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
418    def rename_source(self, old_name, new_name):
419        """Rename a source in this scope"""
420        columns = self.sources.pop(old_name or "", [])
421        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
423    def add_source(self, name, source):
424        """Add a source to this scope"""
425        self.sources[name] = source
426        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
428    def remove_source(self, name):
429        """Remove a source from this scope"""
430        self.sources.pop(name, None)
431        self.clear_cache()

Remove a source from this scope

def traverse(self):
436    def traverse(self):
437        """
438        Traverse the scope tree from this node.
439
440        Yields:
441            Scope: scope instances in depth-first-search post-order
442        """
443        for child_scope in itertools.chain(
444            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
445        ):
446            yield from child_scope.traverse()
447        yield self

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
449    def ref_count(self):
450        """
451        Count the number of times each scope in this tree is referenced.
452
453        Returns:
454            dict[int, int]: Mapping of Scope instance ID to reference count
455        """
456        scope_ref_count = defaultdict(lambda: 0)
457
458        for scope in self.traverse():
459            for _, source in scope.selected_sources.values():
460                scope_ref_count[id(source)] += 1
461
462        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[Scope]:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances; empty when the expression is neither unionable
            nor a DDL statement wrapping a subqueryable expression
    """
    traversable = isinstance(expression, exp.Unionable)
    if not traversable:
        # DDL statements only qualify when the expression they wrap is a query.
        traversable = isinstance(expression, exp.DDL) and isinstance(
            expression.expression, exp.Subqueryable
        )

    if not traversable:
        return []

    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression (exp.Expression): expression to traverse
Returns:

list[Scope]: scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[Scope]:
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope (the last scope returned by `traverse_scope`), or None
            when `traverse_scope` returns no scopes
    """
    all_scopes = traverse_scope(expression)
    return all_scopes[-1] if all_scopes else None

Build a scope tree.

Arguments:
  • expression (exp.Expression): expression to build the scope tree for
Returns:

Scope: root scope

def walk_in_scope(expression, bfs=True, prune=None):
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # State shared with the pruning callback handed to `expression.walk`: while it
    # is True the callback asks the walk to prune, excluding a subtree from traversal.
    skip_subtree = False

    def _should_prune(*args):
        return skip_subtree or (prune and prune(*args))

    for node, parent, key in expression.walk(bfs=bfs, prune=_should_prune):
        skip_subtree = False

        yield node, parent, key

        # The root of the walk never counts as the start of a child scope.
        if node is expression:
            continue

        if isinstance(node, exp.CTE):
            starts_child_scope = True
        elif (
            isinstance(node, exp.Subquery)
            and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
            and _is_derived_table(node)
        ):
            starts_child_scope = True
        else:
            starts_child_scope = isinstance(node, (exp.UDTF, exp.Subqueryable))

        if not starts_child_scope:
            continue

        skip_subtree = True

        if isinstance(node, (exp.Subquery, exp.UDTF)):
            # The following args are not actually in the inner scope, so we should visit them
            for arg_name in ("joins", "laterals", "pivots"):
                for arg in node.args.get(arg_name) or []:
                    yield from walk_in_scope(arg, bfs=bfs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
  • prune ((node, parent, arg_key) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:

tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key

def find_all_in_scope(expression, expression_types, bfs=True):
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root expression of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # The set of wanted types does not change while walking, so normalize it once
    # instead of rebuilding the tuple for every visited node.
    wanted_types = tuple(ensure_collection(expression_types))

    # Use a dedicated loop name so the `expression` argument is not shadowed.
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
def find_in_scope(expression, expression_types, bfs=True):
    """
    Look up the first node of this scope whose type is one of the requested types.

    Subscopes are NOT searched.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None when nothing matches.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    # Only the first hit is consumed; the rest of the generator is never advanced.
    return next(matches, None)

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.