sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name 12 13logger = logging.getLogger("sqlglot") 14 15 16class ScopeType(Enum): 17 ROOT = auto() 18 SUBQUERY = auto() 19 DERIVED_TABLE = auto() 20 CTE = auto() 21 UNION = auto() 22 UDTF = auto() 23 24 25class Scope: 26 """ 27 Selection scope. 28 29 Attributes: 30 expression (exp.Select|exp.Union): Root expression of this scope 31 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 32 a Table expression or another Scope instance. For example: 33 SELECT * FROM x {"x": Table(this="x")} 34 SELECT * FROM x AS y {"y": Table(this="x")} 35 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 36 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 37 For example: 38 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 39 The LATERAL VIEW EXPLODE gets x as a source. 40 cte_sources (dict[str, Scope]): Sources from CTES 41 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 42 defines a column list of it's alias of this scope, this is that list of columns. 43 For example: 44 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 45 The inner query would have `["col1", "col2"]` for its `outer_column_list` 46 parent (Scope): Parent scope 47 scope_type (ScopeType): Type of this scope, relative to it's parent 48 subquery_scopes (list[Scope]): List of all child scopes for subqueries 49 cte_scopes (list[Scope]): List of all child scopes for CTEs 50 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 51 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 52 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 53 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 54 a list of the left and right child scopes. 55 """ 56 57 def __init__( 58 self, 59 expression, 60 sources=None, 61 outer_column_list=None, 62 parent=None, 63 scope_type=ScopeType.ROOT, 64 lateral_sources=None, 65 cte_sources=None, 66 ): 67 self.expression = expression 68 self.sources = sources or {} 69 self.lateral_sources = lateral_sources or {} 70 self.cte_sources = cte_sources or {} 71 self.sources.update(self.lateral_sources) 72 self.sources.update(self.cte_sources) 73 self.outer_column_list = outer_column_list or [] 74 self.parent = parent 75 self.scope_type = scope_type 76 self.subquery_scopes = [] 77 self.derived_table_scopes = [] 78 self.table_scopes = [] 79 self.cte_scopes = [] 80 self.union_scopes = [] 81 self.udtf_scopes = [] 82 self.clear_cache() 83 84 def clear_cache(self): 85 self._collected = False 86 self._raw_columns = None 87 self._derived_tables = None 88 self._udtfs = None 89 self._tables = None 90 self._ctes = None 91 self._subqueries = None 92 self._selected_sources = None 93 self._columns = None 94 self._external_columns = None 95 self._join_hints = None 96 self._pivots = None 97 self._references = None 98 99 def branch( 100 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 101 ): 102 """Branch from the current 
scope to a new, inner scope""" 103 return Scope( 104 expression=expression.unnest(), 105 sources=sources.copy() if sources else None, 106 parent=self, 107 scope_type=scope_type, 108 cte_sources={**self.cte_sources, **(cte_sources or {})}, 109 lateral_sources=lateral_sources.copy() if lateral_sources else None, 110 **kwargs, 111 ) 112 113 def _collect(self): 114 self._tables = [] 115 self._ctes = [] 116 self._subqueries = [] 117 self._derived_tables = [] 118 self._udtfs = [] 119 self._raw_columns = [] 120 self._join_hints = [] 121 122 for node, parent, _ in self.walk(bfs=False): 123 if node is self.expression: 124 continue 125 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 126 self._raw_columns.append(node) 127 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 128 self._tables.append(node) 129 elif isinstance(node, exp.JoinHint): 130 self._join_hints.append(node) 131 elif isinstance(node, exp.UDTF): 132 self._udtfs.append(node) 133 elif isinstance(node, exp.CTE): 134 self._ctes.append(node) 135 elif ( 136 isinstance(node, exp.Subquery) 137 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 138 and _is_derived_table(node) 139 ): 140 self._derived_tables.append(node) 141 elif isinstance(node, exp.Subqueryable): 142 self._subqueries.append(node) 143 144 self._collected = True 145 146 def _ensure_collected(self): 147 if not self._collected: 148 self._collect() 149 150 def walk(self, bfs=True, prune=None): 151 return walk_in_scope(self.expression, bfs=bfs, prune=None) 152 153 def find(self, *expression_types, bfs=True): 154 return find_in_scope(self.expression, expression_types, bfs=bfs) 155 156 def find_all(self, *expression_types, bfs=True): 157 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 158 159 def replace(self, old, new): 160 """ 161 Replace `old` with `new`. 162 163 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 
164 165 Args: 166 old (exp.Expression): old node 167 new (exp.Expression): new node 168 """ 169 old.replace(new) 170 self.clear_cache() 171 172 @property 173 def tables(self): 174 """ 175 List of tables in this scope. 176 177 Returns: 178 list[exp.Table]: tables 179 """ 180 self._ensure_collected() 181 return self._tables 182 183 @property 184 def ctes(self): 185 """ 186 List of CTEs in this scope. 187 188 Returns: 189 list[exp.CTE]: ctes 190 """ 191 self._ensure_collected() 192 return self._ctes 193 194 @property 195 def derived_tables(self): 196 """ 197 List of derived tables in this scope. 198 199 For example: 200 SELECT * FROM (SELECT ...) <- that's a derived table 201 202 Returns: 203 list[exp.Subquery]: derived tables 204 """ 205 self._ensure_collected() 206 return self._derived_tables 207 208 @property 209 def udtfs(self): 210 """ 211 List of "User Defined Tabular Functions" in this scope. 212 213 Returns: 214 list[exp.UDTF]: UDTFs 215 """ 216 self._ensure_collected() 217 return self._udtfs 218 219 @property 220 def subqueries(self): 221 """ 222 List of subqueries in this scope. 223 224 For example: 225 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 226 227 Returns: 228 list[exp.Subqueryable]: subqueries 229 """ 230 self._ensure_collected() 231 return self._subqueries 232 233 @property 234 def columns(self): 235 """ 236 List of columns in this scope. 237 238 Returns: 239 list[exp.Column]: Column instances in this scope, plus any 240 Columns that reference this scope from correlated subqueries. 
241 """ 242 if self._columns is None: 243 self._ensure_collected() 244 columns = self._raw_columns 245 246 external_columns = [ 247 column 248 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 249 for column in scope.external_columns 250 ] 251 252 named_selects = set(self.expression.named_selects) 253 254 self._columns = [] 255 for column in columns + external_columns: 256 ancestor = column.find_ancestor( 257 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table 258 ) 259 if ( 260 not ancestor 261 or column.table 262 or isinstance(ancestor, exp.Select) 263 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 264 or ( 265 isinstance(ancestor, exp.Order) 266 and ( 267 isinstance(ancestor.parent, exp.Window) 268 or column.name not in named_selects 269 ) 270 ) 271 ): 272 self._columns.append(column) 273 274 return self._columns 275 276 @property 277 def selected_sources(self): 278 """ 279 Mapping of nodes and sources that are actually selected from in this scope. 280 281 That is, all tables in a schema are selectable at any point. But a 282 table only becomes a selected source if it's included in a FROM or JOIN clause. 
283 284 Returns: 285 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 286 """ 287 if self._selected_sources is None: 288 result = {} 289 290 for name, node in self.references: 291 if name in result: 292 raise OptimizeError(f"Alias already used: {name}") 293 if name in self.sources: 294 result[name] = (node, self.sources[name]) 295 296 self._selected_sources = result 297 return self._selected_sources 298 299 @property 300 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 301 if self._references is None: 302 self._references = [] 303 304 for table in self.tables: 305 self._references.append((table.alias_or_name, table)) 306 for expression in itertools.chain(self.derived_tables, self.udtfs): 307 self._references.append( 308 ( 309 expression.alias, 310 expression if expression.args.get("pivots") else expression.unnest(), 311 ) 312 ) 313 314 return self._references 315 316 @property 317 def external_columns(self): 318 """ 319 Columns that appear to reference sources in outer scopes. 320 321 Returns: 322 list[exp.Column]: Column instances that don't reference 323 sources in the current scope. 324 """ 325 if self._external_columns is None: 326 self._external_columns = [ 327 c for c in self.columns if c.table not in self.selected_sources 328 ] 329 return self._external_columns 330 331 @property 332 def unqualified_columns(self): 333 """ 334 Unqualified columns in the current scope. 
335 336 Returns: 337 list[exp.Column]: Unqualified columns 338 """ 339 return [c for c in self.columns if not c.table] 340 341 @property 342 def join_hints(self): 343 """ 344 Hints that exist in the scope that reference tables 345 346 Returns: 347 list[exp.JoinHint]: Join hints that are referenced within the scope 348 """ 349 if self._join_hints is None: 350 return [] 351 return self._join_hints 352 353 @property 354 def pivots(self): 355 if not self._pivots: 356 self._pivots = [ 357 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 358 ] 359 360 return self._pivots 361 362 def source_columns(self, source_name): 363 """ 364 Get all columns in the current scope for a particular source. 365 366 Args: 367 source_name (str): Name of the source 368 Returns: 369 list[exp.Column]: Column instances that reference `source_name` 370 """ 371 return [column for column in self.columns if column.table == source_name] 372 373 @property 374 def is_subquery(self): 375 """Determine if this scope is a subquery""" 376 return self.scope_type == ScopeType.SUBQUERY 377 378 @property 379 def is_derived_table(self): 380 """Determine if this scope is a derived table""" 381 return self.scope_type == ScopeType.DERIVED_TABLE 382 383 @property 384 def is_union(self): 385 """Determine if this scope is a union""" 386 return self.scope_type == ScopeType.UNION 387 388 @property 389 def is_cte(self): 390 """Determine if this scope is a common table expression""" 391 return self.scope_type == ScopeType.CTE 392 393 @property 394 def is_root(self): 395 """Determine if this is the root scope""" 396 return self.scope_type == ScopeType.ROOT 397 398 @property 399 def is_udtf(self): 400 """Determine if this scope is a UDTF (User Defined Table Function)""" 401 return self.scope_type == ScopeType.UDTF 402 403 @property 404 def is_correlated_subquery(self): 405 """Determine if this scope is a correlated subquery""" 406 return bool( 407 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 408 and self.external_columns 409 ) 410 411 def rename_source(self, old_name, new_name): 412 """Rename a source in this scope""" 413 columns = self.sources.pop(old_name or "", []) 414 self.sources[new_name] = columns 415 416 def add_source(self, name, source): 417 """Add a source to this scope""" 418 self.sources[name] = source 419 self.clear_cache() 420 421 def remove_source(self, name): 422 """Remove a source from this scope""" 423 self.sources.pop(name, None) 424 self.clear_cache() 425 426 def __repr__(self): 427 return f"Scope<{self.expression.sql()}>" 428 429 def traverse(self): 430 """ 431 Traverse the scope tree from this node. 432 433 Yields: 434 Scope: scope instances in depth-first-search post-order 435 """ 436 for child_scope in itertools.chain( 437 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 438 ): 439 yield from child_scope.traverse() 440 yield self 441 442 def ref_count(self): 443 """ 444 Count the number of times each scope in this tree is referenced. 445 446 Returns: 447 dict[int, int]: Mapping of Scope instance ID to reference count 448 """ 449 scope_ref_count = defaultdict(lambda: 0) 450 451 for scope in self.traverse(): 452 for _, source in scope.selected_sources.values(): 453 scope_ref_count[id(source)] += 1 454 455 return scope_ref_count 456 457 458def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 459 """ 460 Traverse an expression by its "scopes". 461 462 "Scope" represents the current context of a Select statement. 463 464 This is helpful for optimizing queries, where we need more information than 465 the expression tree itself. For example, we might care about the source 466 names within a subquery. Returns a list because a generator could result in 467 incomplete properties which is confusing. 
468 469 Examples: 470 >>> import sqlglot 471 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 472 >>> scopes = traverse_scope(expression) 473 >>> scopes[0].expression.sql(), list(scopes[0].sources) 474 ('SELECT a FROM x', ['x']) 475 >>> scopes[1].expression.sql(), list(scopes[1].sources) 476 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 477 478 Args: 479 expression (exp.Expression): expression to traverse 480 Returns: 481 list[Scope]: scope instances 482 """ 483 if isinstance(expression, exp.Unionable) or ( 484 isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable) 485 ): 486 return list(_traverse_scope(Scope(expression))) 487 488 return [] 489 490 491def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 492 """ 493 Build a scope tree. 494 495 Args: 496 expression (exp.Expression): expression to build the scope tree for 497 Returns: 498 Scope: root scope 499 """ 500 scopes = traverse_scope(expression) 501 if scopes: 502 return scopes[-1] 503 return None 504 505 506def _traverse_scope(scope): 507 if isinstance(scope.expression, exp.Select): 508 yield from _traverse_select(scope) 509 elif isinstance(scope.expression, exp.Union): 510 yield from _traverse_union(scope) 511 elif isinstance(scope.expression, exp.Subquery): 512 if scope.is_root: 513 yield from _traverse_select(scope) 514 else: 515 yield from _traverse_subqueries(scope) 516 elif isinstance(scope.expression, exp.Table): 517 yield from _traverse_tables(scope) 518 elif isinstance(scope.expression, exp.UDTF): 519 yield from _traverse_udtfs(scope) 520 elif isinstance(scope.expression, exp.DDL): 521 yield from _traverse_ddl(scope) 522 else: 523 logger.warning( 524 "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) 525 ) 526 return 527 528 yield scope 529 530 531def _traverse_select(scope): 532 yield from _traverse_ctes(scope) 533 yield from _traverse_tables(scope) 534 yield from _traverse_subqueries(scope) 
def _traverse_union(scope):
    """Yield the scopes of a union's CTEs, left side and right side, and record the sides."""
    yield from _traverse_ctes(scope)

    # The last scope to be yield should be the top most scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    # NOTE(review): either entry stays None if that side could not be traversed.
    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    """Yield the scopes of every CTE in `scope` and register them as (CTE) sources."""
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        with_ = scope.expression.args.get("with")
        if with_ and with_.recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                # CTEs defined earlier in the same WITH are visible to later ones.
                cte_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)
                child_scope.cte_sources[alias] = recursive_scope

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))


def _traverse_tables(scope):
    """
    Resolve the FROM / JOIN / LATERAL items of `scope` into sources.

    Plain tables are registered directly. Derived tables and UDTFs are traversed as
    child scopes, which are yielded and then registered under their alias.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` is deliberately extended while being iterated: nested joins and
    # wrapped tables that are discovered along the way get queued for processing.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset on every iteration: if the traversal below yields nothing, the name would
        # otherwise be unbound or still refer to the scope of a previous expression.
        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        if child_scope is not None:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Yield the scopes of every subquery in `scope` and record each subquery's top scope."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope

        # Nothing is yielded for an untraversable subquery; don't record a None scope.
        if top is not None:
            scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    """Yield the scopes of the derived tables passed to an UNNEST / LATERAL."""
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
            top = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.DERIVED_TABLE,
                    outer_column_list=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                sources[expression.alias] = child_scope

            # Nothing is yielded for an untraversable expression; don't record a None scope.
            if top is not None:
                scope.derived_table_scopes.append(top)
                scope.table_scopes.append(top)

    scope.sources.update(sources)


def _traverse_ddl(scope):
    """Yield the scopes of a DDL statement's CTEs and of the query it wraps."""
    yield from _traverse_ctes(scope)

    query_scope = scope.branch(
        scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
    )
    # Collect eagerly so that the CTEs attached to the DDL node itself can be prepended
    # to the ones found inside the query.
    query_scope._collect()
    query_scope._ctes = scope.ctes + query_scope._ctes

    yield from _traverse_scope(query_scope)


def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the types once instead of once per visited node.
    types = tuple(ensure_collection(expression_types))

    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    # Members are numbered by `auto()` in declaration order, starting at 1.
    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()
An enumeration.
Inherited Members
- enum.Enum
- name
- value
26class Scope: 27 """ 28 Selection scope. 29 30 Attributes: 31 expression (exp.Select|exp.Union): Root expression of this scope 32 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 33 a Table expression or another Scope instance. For example: 34 SELECT * FROM x {"x": Table(this="x")} 35 SELECT * FROM x AS y {"y": Table(this="x")} 36 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 37 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 38 For example: 39 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 40 The LATERAL VIEW EXPLODE gets x as a source. 41 cte_sources (dict[str, Scope]): Sources from CTES 42 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 43 defines a column list of it's alias of this scope, this is that list of columns. 44 For example: 45 SELECT * FROM (SELECT ...) AS y(col1, col2) 46 The inner query would have `["col1", "col2"]` for its `outer_column_list` 47 parent (Scope): Parent scope 48 scope_type (ScopeType): Type of this scope, relative to it's parent 49 subquery_scopes (list[Scope]): List of all child scopes for subqueries 50 cte_scopes (list[Scope]): List of all child scopes for CTEs 51 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 52 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 53 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 54 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 55 a list of the left and right child scopes. 
56 """ 57 58 def __init__( 59 self, 60 expression, 61 sources=None, 62 outer_column_list=None, 63 parent=None, 64 scope_type=ScopeType.ROOT, 65 lateral_sources=None, 66 cte_sources=None, 67 ): 68 self.expression = expression 69 self.sources = sources or {} 70 self.lateral_sources = lateral_sources or {} 71 self.cte_sources = cte_sources or {} 72 self.sources.update(self.lateral_sources) 73 self.sources.update(self.cte_sources) 74 self.outer_column_list = outer_column_list or [] 75 self.parent = parent 76 self.scope_type = scope_type 77 self.subquery_scopes = [] 78 self.derived_table_scopes = [] 79 self.table_scopes = [] 80 self.cte_scopes = [] 81 self.union_scopes = [] 82 self.udtf_scopes = [] 83 self.clear_cache() 84 85 def clear_cache(self): 86 self._collected = False 87 self._raw_columns = None 88 self._derived_tables = None 89 self._udtfs = None 90 self._tables = None 91 self._ctes = None 92 self._subqueries = None 93 self._selected_sources = None 94 self._columns = None 95 self._external_columns = None 96 self._join_hints = None 97 self._pivots = None 98 self._references = None 99 100 def branch( 101 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 102 ): 103 """Branch from the current scope to a new, inner scope""" 104 return Scope( 105 expression=expression.unnest(), 106 sources=sources.copy() if sources else None, 107 parent=self, 108 scope_type=scope_type, 109 cte_sources={**self.cte_sources, **(cte_sources or {})}, 110 lateral_sources=lateral_sources.copy() if lateral_sources else None, 111 **kwargs, 112 ) 113 114 def _collect(self): 115 self._tables = [] 116 self._ctes = [] 117 self._subqueries = [] 118 self._derived_tables = [] 119 self._udtfs = [] 120 self._raw_columns = [] 121 self._join_hints = [] 122 123 for node, parent, _ in self.walk(bfs=False): 124 if node is self.expression: 125 continue 126 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 127 self._raw_columns.append(node) 
128 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 129 self._tables.append(node) 130 elif isinstance(node, exp.JoinHint): 131 self._join_hints.append(node) 132 elif isinstance(node, exp.UDTF): 133 self._udtfs.append(node) 134 elif isinstance(node, exp.CTE): 135 self._ctes.append(node) 136 elif ( 137 isinstance(node, exp.Subquery) 138 and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) 139 and _is_derived_table(node) 140 ): 141 self._derived_tables.append(node) 142 elif isinstance(node, exp.Subqueryable): 143 self._subqueries.append(node) 144 145 self._collected = True 146 147 def _ensure_collected(self): 148 if not self._collected: 149 self._collect() 150 151 def walk(self, bfs=True, prune=None): 152 return walk_in_scope(self.expression, bfs=bfs, prune=None) 153 154 def find(self, *expression_types, bfs=True): 155 return find_in_scope(self.expression, expression_types, bfs=bfs) 156 157 def find_all(self, *expression_types, bfs=True): 158 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 159 160 def replace(self, old, new): 161 """ 162 Replace `old` with `new`. 163 164 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 165 166 Args: 167 old (exp.Expression): old node 168 new (exp.Expression): new node 169 """ 170 old.replace(new) 171 self.clear_cache() 172 173 @property 174 def tables(self): 175 """ 176 List of tables in this scope. 177 178 Returns: 179 list[exp.Table]: tables 180 """ 181 self._ensure_collected() 182 return self._tables 183 184 @property 185 def ctes(self): 186 """ 187 List of CTEs in this scope. 188 189 Returns: 190 list[exp.CTE]: ctes 191 """ 192 self._ensure_collected() 193 return self._ctes 194 195 @property 196 def derived_tables(self): 197 """ 198 List of derived tables in this scope. 199 200 For example: 201 SELECT * FROM (SELECT ...) 
<- that's a derived table 202 203 Returns: 204 list[exp.Subquery]: derived tables 205 """ 206 self._ensure_collected() 207 return self._derived_tables 208 209 @property 210 def udtfs(self): 211 """ 212 List of "User Defined Tabular Functions" in this scope. 213 214 Returns: 215 list[exp.UDTF]: UDTFs 216 """ 217 self._ensure_collected() 218 return self._udtfs 219 220 @property 221 def subqueries(self): 222 """ 223 List of subqueries in this scope. 224 225 For example: 226 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 227 228 Returns: 229 list[exp.Subqueryable]: subqueries 230 """ 231 self._ensure_collected() 232 return self._subqueries 233 234 @property 235 def columns(self): 236 """ 237 List of columns in this scope. 238 239 Returns: 240 list[exp.Column]: Column instances in this scope, plus any 241 Columns that reference this scope from correlated subqueries. 242 """ 243 if self._columns is None: 244 self._ensure_collected() 245 columns = self._raw_columns 246 247 external_columns = [ 248 column 249 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 250 for column in scope.external_columns 251 ] 252 253 named_selects = set(self.expression.named_selects) 254 255 self._columns = [] 256 for column in columns + external_columns: 257 ancestor = column.find_ancestor( 258 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table 259 ) 260 if ( 261 not ancestor 262 or column.table 263 or isinstance(ancestor, exp.Select) 264 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 265 or ( 266 isinstance(ancestor, exp.Order) 267 and ( 268 isinstance(ancestor.parent, exp.Window) 269 or column.name not in named_selects 270 ) 271 ) 272 ): 273 self._columns.append(column) 274 275 return self._columns 276 277 @property 278 def selected_sources(self): 279 """ 280 Mapping of nodes and sources that are actually selected from in this scope. 281 282 That is, all tables in a schema are selectable at any point. 
But a 283 table only becomes a selected source if it's included in a FROM or JOIN clause. 284 285 Returns: 286 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 287 """ 288 if self._selected_sources is None: 289 result = {} 290 291 for name, node in self.references: 292 if name in result: 293 raise OptimizeError(f"Alias already used: {name}") 294 if name in self.sources: 295 result[name] = (node, self.sources[name]) 296 297 self._selected_sources = result 298 return self._selected_sources 299 300 @property 301 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 302 if self._references is None: 303 self._references = [] 304 305 for table in self.tables: 306 self._references.append((table.alias_or_name, table)) 307 for expression in itertools.chain(self.derived_tables, self.udtfs): 308 self._references.append( 309 ( 310 expression.alias, 311 expression if expression.args.get("pivots") else expression.unnest(), 312 ) 313 ) 314 315 return self._references 316 317 @property 318 def external_columns(self): 319 """ 320 Columns that appear to reference sources in outer scopes. 321 322 Returns: 323 list[exp.Column]: Column instances that don't reference 324 sources in the current scope. 325 """ 326 if self._external_columns is None: 327 self._external_columns = [ 328 c for c in self.columns if c.table not in self.selected_sources 329 ] 330 return self._external_columns 331 332 @property 333 def unqualified_columns(self): 334 """ 335 Unqualified columns in the current scope. 
336 337 Returns: 338 list[exp.Column]: Unqualified columns 339 """ 340 return [c for c in self.columns if not c.table] 341 342 @property 343 def join_hints(self): 344 """ 345 Hints that exist in the scope that reference tables 346 347 Returns: 348 list[exp.JoinHint]: Join hints that are referenced within the scope 349 """ 350 if self._join_hints is None: 351 return [] 352 return self._join_hints 353 354 @property 355 def pivots(self): 356 if not self._pivots: 357 self._pivots = [ 358 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 359 ] 360 361 return self._pivots 362 363 def source_columns(self, source_name): 364 """ 365 Get all columns in the current scope for a particular source. 366 367 Args: 368 source_name (str): Name of the source 369 Returns: 370 list[exp.Column]: Column instances that reference `source_name` 371 """ 372 return [column for column in self.columns if column.table == source_name] 373 374 @property 375 def is_subquery(self): 376 """Determine if this scope is a subquery""" 377 return self.scope_type == ScopeType.SUBQUERY 378 379 @property 380 def is_derived_table(self): 381 """Determine if this scope is a derived table""" 382 return self.scope_type == ScopeType.DERIVED_TABLE 383 384 @property 385 def is_union(self): 386 """Determine if this scope is a union""" 387 return self.scope_type == ScopeType.UNION 388 389 @property 390 def is_cte(self): 391 """Determine if this scope is a common table expression""" 392 return self.scope_type == ScopeType.CTE 393 394 @property 395 def is_root(self): 396 """Determine if this is the root scope""" 397 return self.scope_type == ScopeType.ROOT 398 399 @property 400 def is_udtf(self): 401 """Determine if this scope is a UDTF (User Defined Table Function)""" 402 return self.scope_type == ScopeType.UDTF 403 404 @property 405 def is_correlated_subquery(self): 406 """Determine if this scope is a correlated subquery""" 407 return bool( 408 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 409 and self.external_columns 410 ) 411 412 def rename_source(self, old_name, new_name): 413 """Rename a source in this scope""" 414 columns = self.sources.pop(old_name or "", []) 415 self.sources[new_name] = columns 416 417 def add_source(self, name, source): 418 """Add a source to this scope""" 419 self.sources[name] = source 420 self.clear_cache() 421 422 def remove_source(self, name): 423 """Remove a source from this scope""" 424 self.sources.pop(name, None) 425 self.clear_cache() 426 427 def __repr__(self): 428 return f"Scope<{self.expression.sql()}>" 429 430 def traverse(self): 431 """ 432 Traverse the scope tree from this node. 433 434 Yields: 435 Scope: scope instances in depth-first-search post-order 436 """ 437 for child_scope in itertools.chain( 438 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 439 ): 440 yield from child_scope.traverse() 441 yield self 442 443 def ref_count(self): 444 """ 445 Count the number of times each scope in this tree is referenced. 446 447 Returns: 448 dict[int, int]: Mapping of Scope instance ID to reference count 449 """ 450 scope_ref_count = defaultdict(lambda: 0) 451 452 for scope in self.traverse(): 453 for _, source in scope.selected_sources.values(): 454 scope_ref_count[id(source)] += 1 455 456 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
):
    """
    Set up a scope rooted at `expression`.

    Args:
        expression: root expression of this scope
        sources (dict): mapping of source name to source. A non-empty dict passed
            here is used as-is and updated in place with the lateral and CTE sources.
        outer_column_list (list): column list defined by the outer query, if any
        parent (Scope): parent scope
        scope_type (ScopeType): type of this scope
        lateral_sources (dict): sources from laterals
        cte_sources (dict): sources from CTEs
    """
    # Position of this scope in the scope tree.
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.outer_column_list = outer_column_list if outer_column_list else []

    # Lateral sources are merged first, then CTE sources, so a CTE source wins
    # when the same name appears in both.
    self.lateral_sources = lateral_sources if lateral_sources else {}
    self.cte_sources = cte_sources if cte_sources else {}
    merged_sources = sources if sources else {}
    merged_sources.update(self.lateral_sources)
    merged_sources.update(self.cte_sources)
    self.sources = merged_sources

    # Child scopes, grouped by kind; all start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
85 def clear_cache(self): 86 self._collected = False 87 self._raw_columns = None 88 self._derived_tables = None 89 self._udtfs = None 90 self._tables = None 91 self._ctes = None 92 self._subqueries = None 93 self._selected_sources = None 94 self._columns = None 95 self._external_columns = None 96 self._join_hints = None 97 self._pivots = None 98 self._references = None
100 def branch( 101 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 102 ): 103 """Branch from the current scope to a new, inner scope""" 104 return Scope( 105 expression=expression.unnest(), 106 sources=sources.copy() if sources else None, 107 parent=self, 108 scope_type=scope_type, 109 cte_sources={**self.cte_sources, **(cte_sources or {})}, 110 lateral_sources=lateral_sources.copy() if lateral_sources else None, 111 **kwargs, 112 )
Branch from the current scope to a new, inner scope
160 def replace(self, old, new): 161 """ 162 Replace `old` with `new`. 163 164 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 165 166 Args: 167 old (exp.Expression): old node 168 new (exp.Expression): new node 169 """ 170 old.replace(new) 171 self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
173 @property 174 def tables(self): 175 """ 176 List of tables in this scope. 177 178 Returns: 179 list[exp.Table]: tables 180 """ 181 self._ensure_collected() 182 return self._tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
184 @property 185 def ctes(self): 186 """ 187 List of CTEs in this scope. 188 189 Returns: 190 list[exp.CTE]: ctes 191 """ 192 self._ensure_collected() 193 return self._ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
195 @property 196 def derived_tables(self): 197 """ 198 List of derived tables in this scope. 199 200 For example: 201 SELECT * FROM (SELECT ...) <- that's a derived table 202 203 Returns: 204 list[exp.Subquery]: derived tables 205 """ 206 self._ensure_collected() 207 return self._derived_tables
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
209 @property 210 def udtfs(self): 211 """ 212 List of "User Defined Tabular Functions" in this scope. 213 214 Returns: 215 list[exp.UDTF]: UDTFs 216 """ 217 self._ensure_collected() 218 return self._udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
220 @property 221 def subqueries(self): 222 """ 223 List of subqueries in this scope. 224 225 For example: 226 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 227 228 Returns: 229 list[exp.Subqueryable]: subqueries 230 """ 231 self._ensure_collected() 232 return self._subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
@property
def columns(self):
    """
    Columns that belong to this scope (cached after the first access).

    Candidates are the raw columns collected for this scope followed by the
    external columns of its subquery and UDTF child scopes. A candidate is
    kept when any of the following holds, looking at its nearest
    Select/Qualify/Order/Having/Hint/Table ancestor:

    - it has no such ancestor, or it has a table qualifier
    - the ancestor is a Select
    - the ancestor is a Table whose `this` is not a function
    - the ancestor is an Order that sits inside a Window, or the column name
      is not one of this scope's named selects

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()
    own_columns = self._raw_columns

    correlated_columns = []
    for child_scope in itertools.chain(self.subquery_scopes, self.udtf_scopes):
        correlated_columns.extend(child_scope.external_columns)

    select_names = set(self.expression.named_selects)

    def _belongs_to_scope(column):
        ancestor = column.find_ancestor(
            exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
        )
        if not ancestor or column.table or isinstance(ancestor, exp.Select):
            return True
        if isinstance(ancestor, exp.Table):
            return not isinstance(ancestor.this, exp.Func)
        if isinstance(ancestor, exp.Order):
            return isinstance(ancestor.parent, exp.Window) or column.name not in select_names
        return False

    self._columns = []
    for candidate in own_columns + correlated_columns:
        if _belongs_to_scope(candidate):
            self._columns.append(candidate)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
277 @property 278 def selected_sources(self): 279 """ 280 Mapping of nodes and sources that are actually selected from in this scope. 281 282 That is, all tables in a schema are selectable at any point. But a 283 table only becomes a selected source if it's included in a FROM or JOIN clause. 284 285 Returns: 286 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 287 """ 288 if self._selected_sources is None: 289 result = {} 290 291 for name, node in self.references: 292 if name in result: 293 raise OptimizeError(f"Alias already used: {name}") 294 if name in self.sources: 295 result[name] = (node, self.sources[name]) 296 297 self._selected_sources = result 298 return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
300 @property 301 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 302 if self._references is None: 303 self._references = [] 304 305 for table in self.tables: 306 self._references.append((table.alias_or_name, table)) 307 for expression in itertools.chain(self.derived_tables, self.udtfs): 308 self._references.append( 309 ( 310 expression.alias, 311 expression if expression.args.get("pivots") else expression.unnest(), 312 ) 313 ) 314 315 return self._references
317 @property 318 def external_columns(self): 319 """ 320 Columns that appear to reference sources in outer scopes. 321 322 Returns: 323 list[exp.Column]: Column instances that don't reference 324 sources in the current scope. 325 """ 326 if self._external_columns is None: 327 self._external_columns = [ 328 c for c in self.columns if c.table not in self.selected_sources 329 ] 330 return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
332 @property 333 def unqualified_columns(self): 334 """ 335 Unqualified columns in the current scope. 336 337 Returns: 338 list[exp.Column]: Unqualified columns 339 """ 340 return [c for c in self.columns if not c.table]
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
342 @property 343 def join_hints(self): 344 """ 345 Hints that exist in the scope that reference tables 346 347 Returns: 348 list[exp.JoinHint]: Join hints that are referenced within the scope 349 """ 350 if self._join_hints is None: 351 return [] 352 return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
363 def source_columns(self, source_name): 364 """ 365 Get all columns in the current scope for a particular source. 366 367 Args: 368 source_name (str): Name of the source 369 Returns: 370 list[exp.Column]: Column instances that reference `source_name` 371 """ 372 return [column for column in self.columns if column.table == source_name]
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """Whether this scope's type is `ScopeType.SUBQUERY`."""
    kind = self.scope_type
    return kind == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`."""
    kind = self.scope_type
    return kind == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (a common table expression)."""
    kind = self.scope_type
    return kind == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is `ScopeType.UDTF` (User Defined Table Function)."""
    kind = self.scope_type
    return kind == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
412 def rename_source(self, old_name, new_name): 413 """Rename a source in this scope""" 414 columns = self.sources.pop(old_name or "", []) 415 self.sources[new_name] = columns
Rename a source in this scope
417 def add_source(self, name, source): 418 """Add a source to this scope""" 419 self.sources[name] = source 420 self.clear_cache()
Add a source to this scope
422 def remove_source(self, name): 423 """Remove a source from this scope""" 424 self.sources.pop(name, None) 425 self.clear_cache()
Remove a source from this scope
430 def traverse(self): 431 """ 432 Traverse the scope tree from this node. 433 434 Yields: 435 Scope: scope instances in depth-first-search post-order 436 """ 437 for child_scope in itertools.chain( 438 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 439 ): 440 yield from child_scope.traverse() 441 yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
443 def ref_count(self): 444 """ 445 Count the number of times each scope in this tree is referenced. 446 447 Returns: 448 dict[int, int]: Mapping of Scope instance ID to reference count 449 """ 450 scope_ref_count = defaultdict(lambda: 0) 451 452 for scope in self.traverse(): 453 for _, source in scope.selected_sources.values(): 454 scope_ref_count[id(source)] += 1 455 456 return scope_ref_count
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    A "Scope" represents the current context of a Select statement. It carries
    information the expression tree alone does not give directly, such as the
    source names visible inside a subquery.

    A list is returned rather than a generator, because a partially consumed
    generator could leave scopes with incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances; empty when `expression` is neither
            unionable nor a DDL statement wrapping a subqueryable expression
    """
    if isinstance(expression, exp.Unionable):
        traversable = True
    else:
        traversable = isinstance(expression, exp.DDL) and isinstance(
            expression.expression, exp.Subqueryable
        )

    if not traversable:
        return []

    return list(_traverse_scope(Scope(expression)))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and hand back its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope (the last scope produced by `traverse_scope`), or
            None when no scopes were produced
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is still yielded itself; only its subtree is
    skipped. For such nodes that are a Subquery or UDTF, their "joins", "laterals"
    and "pivots" args are walked as well.

    Args:
        expression (exp.Expression): root of the walk; never treated as a scope
            boundary itself.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    # The lambda reads `crossed_scope_boundary` at call time, so the value set
    # further down in the loop body is what `expression.walk` sees.
    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        # Reset for every node; only set again below if this node starts a child scope.
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                # NOTE: `key` is rebound here; the outer loop assigns it again on its next
                # iteration. The caller's `prune` is not forwarded to the nested walk.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression):
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node, parent, arg_key) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Lazily yield every node of this scope that is an instance of one of the
    given expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize once: a single type and a collection of types are both accepted.
    wanted_types = tuple(ensure_collection(expression_types))

    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.