sqlglot.optimizer.scope
from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get

logger = logging.getLogger("sqlglot")


class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also visible through the general `sources` mapping.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Invalidate every lazily computed property so it is rebuilt from the AST on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            # Inner scopes inherit every CTE visible here, plus any extra ones passed in.
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (depth-first) and bucket the nodes the lazy properties expose."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue

            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and isinstance(parent, (exp.From, exp.Join, exp.Subquery)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope without descending into child scopes.

        Args:
            bfs (bool): True for breadth-first order, False for depth-first.
            prune ((node, parent, arg_key) -> bool): optional callable; returning True
                stops traversal of that branch.

        Returns:
            A generator of (node, parent, arg_key) tuples, as produced by `walk_in_scope`.
        """
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.Union]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            # Columns of child subqueries / UDTFs that do not resolve inside those scopes.
            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
                )
                # Keep the column when it is qualified, has no relevant ancestor, sits directly
                # under a SELECT or a non-function table, or sits in an ORDER BY that is either a
                # window's or does not name one of this scope's projections. Other unqualified
                # columns (under QUALIFY, HAVING, hints, stars, function tables) are dropped;
                # presumably because they may be projection-alias references — TODO confirm.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same alias is referenced twice in this scope.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (name, node) pairs for every table, derived table and UDTF referenced in this scope.

        Tables come first, keyed by alias-or-name. Derived tables and UDTFs follow, keyed by
        alias; their node is the unnested expression unless it carries pivots.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.Union):
                # A union has no sources of its own; combine both branches.
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c for c in self.columns if c.table not in self.selected_sources
                ]

        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like every other lazy property; previously this returned []
        # whenever no other property had triggered collection yet.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivot expressions attached to the sources referenced in this scope.

        Returns:
            list[exp.Pivot]: pivots, in reference order
        """
        # `is None` so that an empty result is cached too, like the other properties.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        # A subquery (or a scope nested directly under a LATERAL) that references outer columns.
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Cached references / selected sources depend on `sources`; keep them consistent,
        # as add_source and remove_source do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances
    """
    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
        # We ignore the DDL expression and build a scope for its query instead
        ddl_with = expression.args.get("with")
        expression = expression.expression

        # If the DDL has CTEs attached, we need to add them to the query, or
        # prepend them if the query itself already has CTEs attached to it
        if ddl_with:
            ddl_with.pop()
            query_ctes = expression.ctes
            if not query_ctes:
                expression.set("with", ddl_with)
            else:
                expression.args["with"].set("recursive", ddl_with.recursive)
                expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])

    if isinstance(expression, exp.Query):
        return list(_traverse_scope(Scope(expression)))

    return []


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, or None if `expression` yields no scopes.
    """
    # traverse_scope yields in post-order, so the root scope is the last element.
    return seq_get(traverse_scope(expression), -1)


def _traverse_scope(scope):
    """Dispatch on the scope's expression type, yielding child scopes and then `scope` itself."""
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        if scope.is_root:
            yield from _traverse_select(scope)
        else:
            yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        yield from _traverse_udtfs(scope)
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    yield scope


def _traverse_select(scope):
    """Yield the CTE, table and subquery child scopes of a SELECT-like scope."""
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    """Yield the CTE scopes, then the left and right branch scopes of a union."""
    yield from _traverse_ctes(scope)

    # The last scope to be yielded should be the top most scope
    left = None
    for left in _traverse_scope(
        scope.branch(
            scope.expression.left,
            outer_column_list=scope.outer_column_list,
            scope_type=ScopeType.UNION,
        )
    ):
        yield left

    right = None
    for right in _traverse_scope(
        scope.branch(
            scope.expression.right,
            outer_column_list=scope.outer_column_list,
            scope_type=ScopeType.UNION,
        )
    ):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    """Yield the scopes of each CTE in order, registering each as a source of `scope`."""
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # A recursive CTE has the form `base_case UNION recursive_part`. Self-references inside
        # it are resolved against a scope built from the base case (the left side of the union).
        with_ = scope.expression.args.get("with")
        if with_ and with_.recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)
                child_scope.cte_sources[alias] = recursive_scope

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Expression) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return isinstance(expression, exp.Subquery) and bool(
        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
    )


def _traverse_tables(scope):
    """Register FROM/JOIN/LATERAL sources of `scope`, yielding derived-table and UDTF scopes."""
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # NOTE: `expressions` is extended while being iterated, so nested joins are visited too.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Yield the scopes of each subquery, recording the top-most one in `subquery_scopes`."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    """Yield derived-table scopes nested inside an UNNEST or LATERAL scope."""
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if _is_derived_table(expression):
            top = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.DERIVED_TABLE,
                    outer_column_list=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                sources[expression.alias] = child_scope

            scope.derived_table_scopes.append(top)
            scope.table_scopes.append(top)

    scope.sources.update(sources)


def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node, parent, arg_key) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    # The lambda below reads it at call time (late binding), which is what makes this work.
    crossed_scope_boundary = False

    for node, parent, key in expression.walk(
        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
    ):
        crossed_scope_boundary = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (_is_derived_table(node) and isinstance(parent, (exp.From, exp.Join, exp.Subquery)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for arg_key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    types = tuple(ensure_collection(expression_types))
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a scope, relative to the scope that contains it."""

    # Explicit values identical to what auto() assigns (1, 2, 3, ... in declaration order).
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
An enumeration of scope kinds, relative to the parent scope: root, subquery, derived table, CTE, union, or UDTF.
Inherited Members
- enum.Enum
- name
- value
26class Scope: 27 """ 28 Selection scope. 29 30 Attributes: 31 expression (exp.Select|exp.Union): Root expression of this scope 32 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 33 a Table expression or another Scope instance. For example: 34 SELECT * FROM x {"x": Table(this="x")} 35 SELECT * FROM x AS y {"y": Table(this="x")} 36 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 37 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 38 For example: 39 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 40 The LATERAL VIEW EXPLODE gets x as a source. 41 cte_sources (dict[str, Scope]): Sources from CTES 42 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 43 defines a column list of it's alias of this scope, this is that list of columns. 44 For example: 45 SELECT * FROM (SELECT ...) AS y(col1, col2) 46 The inner query would have `["col1", "col2"]` for its `outer_column_list` 47 parent (Scope): Parent scope 48 scope_type (ScopeType): Type of this scope, relative to it's parent 49 subquery_scopes (list[Scope]): List of all child scopes for subqueries 50 cte_scopes (list[Scope]): List of all child scopes for CTEs 51 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 52 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 53 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 54 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 55 a list of the left and right child scopes. 
56 """ 57 58 def __init__( 59 self, 60 expression, 61 sources=None, 62 outer_column_list=None, 63 parent=None, 64 scope_type=ScopeType.ROOT, 65 lateral_sources=None, 66 cte_sources=None, 67 ): 68 self.expression = expression 69 self.sources = sources or {} 70 self.lateral_sources = lateral_sources or {} 71 self.cte_sources = cte_sources or {} 72 self.sources.update(self.lateral_sources) 73 self.sources.update(self.cte_sources) 74 self.outer_column_list = outer_column_list or [] 75 self.parent = parent 76 self.scope_type = scope_type 77 self.subquery_scopes = [] 78 self.derived_table_scopes = [] 79 self.table_scopes = [] 80 self.cte_scopes = [] 81 self.union_scopes = [] 82 self.udtf_scopes = [] 83 self.clear_cache() 84 85 def clear_cache(self): 86 self._collected = False 87 self._raw_columns = None 88 self._derived_tables = None 89 self._udtfs = None 90 self._tables = None 91 self._ctes = None 92 self._subqueries = None 93 self._selected_sources = None 94 self._columns = None 95 self._external_columns = None 96 self._join_hints = None 97 self._pivots = None 98 self._references = None 99 100 def branch( 101 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 102 ): 103 """Branch from the current scope to a new, inner scope""" 104 return Scope( 105 expression=expression.unnest(), 106 sources=sources.copy() if sources else None, 107 parent=self, 108 scope_type=scope_type, 109 cte_sources={**self.cte_sources, **(cte_sources or {})}, 110 lateral_sources=lateral_sources.copy() if lateral_sources else None, 111 **kwargs, 112 ) 113 114 def _collect(self): 115 self._tables = [] 116 self._ctes = [] 117 self._subqueries = [] 118 self._derived_tables = [] 119 self._udtfs = [] 120 self._raw_columns = [] 121 self._join_hints = [] 122 123 for node, parent, _ in self.walk(bfs=False): 124 if node is self.expression: 125 continue 126 127 if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 128 self._raw_columns.append(node) 
129 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 130 self._tables.append(node) 131 elif isinstance(node, exp.JoinHint): 132 self._join_hints.append(node) 133 elif isinstance(node, exp.UDTF): 134 self._udtfs.append(node) 135 elif isinstance(node, exp.CTE): 136 self._ctes.append(node) 137 elif _is_derived_table(node) and isinstance(parent, (exp.From, exp.Join, exp.Subquery)): 138 self._derived_tables.append(node) 139 elif isinstance(node, exp.UNWRAPPED_QUERIES): 140 self._subqueries.append(node) 141 142 self._collected = True 143 144 def _ensure_collected(self): 145 if not self._collected: 146 self._collect() 147 148 def walk(self, bfs=True, prune=None): 149 return walk_in_scope(self.expression, bfs=bfs, prune=None) 150 151 def find(self, *expression_types, bfs=True): 152 return find_in_scope(self.expression, expression_types, bfs=bfs) 153 154 def find_all(self, *expression_types, bfs=True): 155 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 156 157 def replace(self, old, new): 158 """ 159 Replace `old` with `new`. 160 161 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 162 163 Args: 164 old (exp.Expression): old node 165 new (exp.Expression): new node 166 """ 167 old.replace(new) 168 self.clear_cache() 169 170 @property 171 def tables(self): 172 """ 173 List of tables in this scope. 174 175 Returns: 176 list[exp.Table]: tables 177 """ 178 self._ensure_collected() 179 return self._tables 180 181 @property 182 def ctes(self): 183 """ 184 List of CTEs in this scope. 185 186 Returns: 187 list[exp.CTE]: ctes 188 """ 189 self._ensure_collected() 190 return self._ctes 191 192 @property 193 def derived_tables(self): 194 """ 195 List of derived tables in this scope. 196 197 For example: 198 SELECT * FROM (SELECT ...) 
<- that's a derived table 199 200 Returns: 201 list[exp.Subquery]: derived tables 202 """ 203 self._ensure_collected() 204 return self._derived_tables 205 206 @property 207 def udtfs(self): 208 """ 209 List of "User Defined Tabular Functions" in this scope. 210 211 Returns: 212 list[exp.UDTF]: UDTFs 213 """ 214 self._ensure_collected() 215 return self._udtfs 216 217 @property 218 def subqueries(self): 219 """ 220 List of subqueries in this scope. 221 222 For example: 223 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 224 225 Returns: 226 list[exp.Select | exp.Union]: subqueries 227 """ 228 self._ensure_collected() 229 return self._subqueries 230 231 @property 232 def columns(self): 233 """ 234 List of columns in this scope. 235 236 Returns: 237 list[exp.Column]: Column instances in this scope, plus any 238 Columns that reference this scope from correlated subqueries. 239 """ 240 if self._columns is None: 241 self._ensure_collected() 242 columns = self._raw_columns 243 244 external_columns = [ 245 column 246 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 247 for column in scope.external_columns 248 ] 249 250 named_selects = set(self.expression.named_selects) 251 252 self._columns = [] 253 for column in columns + external_columns: 254 ancestor = column.find_ancestor( 255 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 256 ) 257 if ( 258 not ancestor 259 or column.table 260 or isinstance(ancestor, exp.Select) 261 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 262 or ( 263 isinstance(ancestor, exp.Order) 264 and ( 265 isinstance(ancestor.parent, exp.Window) 266 or column.name not in named_selects 267 ) 268 ) 269 ): 270 self._columns.append(column) 271 272 return self._columns 273 274 @property 275 def selected_sources(self): 276 """ 277 Mapping of nodes and sources that are actually selected from in this scope. 
278 279 That is, all tables in a schema are selectable at any point. But a 280 table only becomes a selected source if it's included in a FROM or JOIN clause. 281 282 Returns: 283 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 284 """ 285 if self._selected_sources is None: 286 result = {} 287 288 for name, node in self.references: 289 if name in result: 290 raise OptimizeError(f"Alias already used: {name}") 291 if name in self.sources: 292 result[name] = (node, self.sources[name]) 293 294 self._selected_sources = result 295 return self._selected_sources 296 297 @property 298 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 299 if self._references is None: 300 self._references = [] 301 302 for table in self.tables: 303 self._references.append((table.alias_or_name, table)) 304 for expression in itertools.chain(self.derived_tables, self.udtfs): 305 self._references.append( 306 ( 307 expression.alias, 308 expression if expression.args.get("pivots") else expression.unnest(), 309 ) 310 ) 311 312 return self._references 313 314 @property 315 def external_columns(self): 316 """ 317 Columns that appear to reference sources in outer scopes. 318 319 Returns: 320 list[exp.Column]: Column instances that don't reference 321 sources in the current scope. 322 """ 323 if self._external_columns is None: 324 if isinstance(self.expression, exp.Union): 325 left, right = self.union_scopes 326 self._external_columns = left.external_columns + right.external_columns 327 else: 328 self._external_columns = [ 329 c for c in self.columns if c.table not in self.selected_sources 330 ] 331 332 return self._external_columns 333 334 @property 335 def unqualified_columns(self): 336 """ 337 Unqualified columns in the current scope. 
338 339 Returns: 340 list[exp.Column]: Unqualified columns 341 """ 342 return [c for c in self.columns if not c.table] 343 344 @property 345 def join_hints(self): 346 """ 347 Hints that exist in the scope that reference tables 348 349 Returns: 350 list[exp.JoinHint]: Join hints that are referenced within the scope 351 """ 352 if self._join_hints is None: 353 return [] 354 return self._join_hints 355 356 @property 357 def pivots(self): 358 if not self._pivots: 359 self._pivots = [ 360 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 361 ] 362 363 return self._pivots 364 365 def source_columns(self, source_name): 366 """ 367 Get all columns in the current scope for a particular source. 368 369 Args: 370 source_name (str): Name of the source 371 Returns: 372 list[exp.Column]: Column instances that reference `source_name` 373 """ 374 return [column for column in self.columns if column.table == source_name] 375 376 @property 377 def is_subquery(self): 378 """Determine if this scope is a subquery""" 379 return self.scope_type == ScopeType.SUBQUERY 380 381 @property 382 def is_derived_table(self): 383 """Determine if this scope is a derived table""" 384 return self.scope_type == ScopeType.DERIVED_TABLE 385 386 @property 387 def is_union(self): 388 """Determine if this scope is a union""" 389 return self.scope_type == ScopeType.UNION 390 391 @property 392 def is_cte(self): 393 """Determine if this scope is a common table expression""" 394 return self.scope_type == ScopeType.CTE 395 396 @property 397 def is_root(self): 398 """Determine if this is the root scope""" 399 return self.scope_type == ScopeType.ROOT 400 401 @property 402 def is_udtf(self): 403 """Determine if this scope is a UDTF (User Defined Table Function)""" 404 return self.scope_type == ScopeType.UDTF 405 406 @property 407 def is_correlated_subquery(self): 408 """Determine if this scope is a correlated subquery""" 409 return bool( 410 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 411 and self.external_columns 412 ) 413 414 def rename_source(self, old_name, new_name): 415 """Rename a source in this scope""" 416 columns = self.sources.pop(old_name or "", []) 417 self.sources[new_name] = columns 418 419 def add_source(self, name, source): 420 """Add a source to this scope""" 421 self.sources[name] = source 422 self.clear_cache() 423 424 def remove_source(self, name): 425 """Remove a source from this scope""" 426 self.sources.pop(name, None) 427 self.clear_cache() 428 429 def __repr__(self): 430 return f"Scope<{self.expression.sql()}>" 431 432 def traverse(self): 433 """ 434 Traverse the scope tree from this node. 435 436 Yields: 437 Scope: scope instances in depth-first-search post-order 438 """ 439 for child_scope in itertools.chain( 440 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 441 ): 442 yield from child_scope.traverse() 443 yield self 444 445 def ref_count(self): 446 """ 447 Count the number of times each scope in this tree is referenced. 448 449 Returns: 450 dict[int, int]: Mapping of Scope instance ID to reference count 451 """ 452 scope_ref_count = defaultdict(lambda: 0) 453 454 for scope in self.traverse(): 455 for _, source in scope.selected_sources.values(): 456 scope_ref_count[id(source)] += 1 457 458 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` gives `{"x": Table(this="x")}`; `SELECT * FROM x AS y` gives `{"y": Table(this="x")}`; `SELECT * FROM (SELECT ...) AS y` gives `{"y": Scope(...)}`.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets `x` as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for this scope's alias, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
58 def __init__( 59 self, 60 expression, 61 sources=None, 62 outer_column_list=None, 63 parent=None, 64 scope_type=ScopeType.ROOT, 65 lateral_sources=None, 66 cte_sources=None, 67 ): 68 self.expression = expression 69 self.sources = sources or {} 70 self.lateral_sources = lateral_sources or {} 71 self.cte_sources = cte_sources or {} 72 self.sources.update(self.lateral_sources) 73 self.sources.update(self.cte_sources) 74 self.outer_column_list = outer_column_list or [] 75 self.parent = parent 76 self.scope_type = scope_type 77 self.subquery_scopes = [] 78 self.derived_table_scopes = [] 79 self.table_scopes = [] 80 self.cte_scopes = [] 81 self.union_scopes = [] 82 self.udtf_scopes = [] 83 self.clear_cache()
85 def clear_cache(self): 86 self._collected = False 87 self._raw_columns = None 88 self._derived_tables = None 89 self._udtfs = None 90 self._tables = None 91 self._ctes = None 92 self._subqueries = None 93 self._selected_sources = None 94 self._columns = None 95 self._external_columns = None 96 self._join_hints = None 97 self._pivots = None 98 self._references = None
100 def branch( 101 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 102 ): 103 """Branch from the current scope to a new, inner scope""" 104 return Scope( 105 expression=expression.unnest(), 106 sources=sources.copy() if sources else None, 107 parent=self, 108 scope_type=scope_type, 109 cte_sources={**self.cte_sources, **(cte_sources or {})}, 110 lateral_sources=lateral_sources.copy() if lateral_sources else None, 111 **kwargs, 112 )
Branch from the current scope to a new, inner scope
157 def replace(self, old, new): 158 """ 159 Replace `old` with `new`. 160 161 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 162 163 Args: 164 old (exp.Expression): old node 165 new (exp.Expression): new node 166 """ 167 old.replace(new) 168 self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
170 @property 171 def tables(self): 172 """ 173 List of tables in this scope. 174 175 Returns: 176 list[exp.Table]: tables 177 """ 178 self._ensure_collected() 179 return self._tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
181 @property 182 def ctes(self): 183 """ 184 List of CTEs in this scope. 185 186 Returns: 187 list[exp.CTE]: ctes 188 """ 189 self._ensure_collected() 190 return self._ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
192 @property 193 def derived_tables(self): 194 """ 195 List of derived tables in this scope. 196 197 For example: 198 SELECT * FROM (SELECT ...) <- that's a derived table 199 200 Returns: 201 list[exp.Subquery]: derived tables 202 """ 203 self._ensure_collected() 204 return self._derived_tables
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
206 @property 207 def udtfs(self): 208 """ 209 List of "User Defined Tabular Functions" in this scope. 210 211 Returns: 212 list[exp.UDTF]: UDTFs 213 """ 214 self._ensure_collected() 215 return self._udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
217 @property 218 def subqueries(self): 219 """ 220 List of subqueries in this scope. 221 222 For example: 223 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 224 225 Returns: 226 list[exp.Select | exp.Union]: subqueries 227 """ 228 self._ensure_collected() 229 return self._subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.Union]: subqueries
231 @property 232 def columns(self): 233 """ 234 List of columns in this scope. 235 236 Returns: 237 list[exp.Column]: Column instances in this scope, plus any 238 Columns that reference this scope from correlated subqueries. 239 """ 240 if self._columns is None: 241 self._ensure_collected() 242 columns = self._raw_columns 243 244 external_columns = [ 245 column 246 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 247 for column in scope.external_columns 248 ] 249 250 named_selects = set(self.expression.named_selects) 251 252 self._columns = [] 253 for column in columns + external_columns: 254 ancestor = column.find_ancestor( 255 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 256 ) 257 if ( 258 not ancestor 259 or column.table 260 or isinstance(ancestor, exp.Select) 261 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 262 or ( 263 isinstance(ancestor, exp.Order) 264 and ( 265 isinstance(ancestor.parent, exp.Window) 266 or column.name not in named_selects 267 ) 268 ) 269 ): 270 self._columns.append(column) 271 272 return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
274 @property 275 def selected_sources(self): 276 """ 277 Mapping of nodes and sources that are actually selected from in this scope. 278 279 That is, all tables in a schema are selectable at any point. But a 280 table only becomes a selected source if it's included in a FROM or JOIN clause. 281 282 Returns: 283 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 284 """ 285 if self._selected_sources is None: 286 result = {} 287 288 for name, node in self.references: 289 if name in result: 290 raise OptimizeError(f"Alias already used: {name}") 291 if name in self.sources: 292 result[name] = (node, self.sources[name]) 293 294 self._selected_sources = result 295 return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
297 @property 298 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 299 if self._references is None: 300 self._references = [] 301 302 for table in self.tables: 303 self._references.append((table.alias_or_name, table)) 304 for expression in itertools.chain(self.derived_tables, self.udtfs): 305 self._references.append( 306 ( 307 expression.alias, 308 expression if expression.args.get("pivots") else expression.unnest(), 309 ) 310 ) 311 312 return self._references
314 @property 315 def external_columns(self): 316 """ 317 Columns that appear to reference sources in outer scopes. 318 319 Returns: 320 list[exp.Column]: Column instances that don't reference 321 sources in the current scope. 322 """ 323 if self._external_columns is None: 324 if isinstance(self.expression, exp.Union): 325 left, right = self.union_scopes 326 self._external_columns = left.external_columns + right.external_columns 327 else: 328 self._external_columns = [ 329 c for c in self.columns if c.table not in self.selected_sources 330 ] 331 332 return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
334 @property 335 def unqualified_columns(self): 336 """ 337 Unqualified columns in the current scope. 338 339 Returns: 340 list[exp.Column]: Unqualified columns 341 """ 342 return [c for c in self.columns if not c.table]
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
344 @property 345 def join_hints(self): 346 """ 347 Hints that exist in the scope that reference tables 348 349 Returns: 350 list[exp.JoinHint]: Join hints that are referenced within the scope 351 """ 352 if self._join_hints is None: 353 return [] 354 return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
365 def source_columns(self, source_name): 366 """ 367 Get all columns in the current scope for a particular source. 368 369 Args: 370 source_name (str): Name of the source 371 Returns: 372 list[exp.Column]: Column instances that reference `source_name` 373 """ 374 return [column for column in self.columns if column.table == source_name]
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
376 @property 377 def is_subquery(self): 378 """Determine if this scope is a subquery""" 379 return self.scope_type == ScopeType.SUBQUERY
Determine if this scope is a subquery
381 @property 382 def is_derived_table(self): 383 """Determine if this scope is a derived table""" 384 return self.scope_type == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
386 @property 387 def is_union(self): 388 """Determine if this scope is a union""" 389 return self.scope_type == ScopeType.UNION
Determine if this scope is a union
391 @property 392 def is_cte(self): 393 """Determine if this scope is a common table expression""" 394 return self.scope_type == ScopeType.CTE
Determine if this scope is a common table expression
396 @property 397 def is_root(self): 398 """Determine if this is the root scope""" 399 return self.scope_type == ScopeType.ROOT
Determine if this is the root scope
401 @property 402 def is_udtf(self): 403 """Determine if this scope is a UDTF (User Defined Table Function)""" 404 return self.scope_type == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
414 def rename_source(self, old_name, new_name): 415 """Rename a source in this scope""" 416 columns = self.sources.pop(old_name or "", []) 417 self.sources[new_name] = columns
Rename a source in this scope
419 def add_source(self, name, source): 420 """Add a source to this scope""" 421 self.sources[name] = source 422 self.clear_cache()
Add a source to this scope
424 def remove_source(self, name): 425 """Remove a source from this scope""" 426 self.sources.pop(name, None) 427 self.clear_cache()
Remove a source from this scope
432 def traverse(self): 433 """ 434 Traverse the scope tree from this node. 435 436 Yields: 437 Scope: scope instances in depth-first-search post-order 438 """ 439 for child_scope in itertools.chain( 440 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 441 ): 442 yield from child_scope.traverse() 443 yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
445 def ref_count(self): 446 """ 447 Count the number of times each scope in this tree is referenced. 448 449 Returns: 450 dict[int, int]: Mapping of Scope instance ID to reference count 451 """ 452 scope_ref_count = defaultdict(lambda: 0) 453 454 for scope in self.traverse(): 455 for _, source in scope.selected_sources.values(): 456 scope_ref_count[id(source)] += 1 457 458 return scope_ref_count
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
461def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 462 """ 463 Traverse an expression by its "scopes". 464 465 "Scope" represents the current context of a Select statement. 466 467 This is helpful for optimizing queries, where we need more information than 468 the expression tree itself. For example, we might care about the source 469 names within a subquery. Returns a list because a generator could result in 470 incomplete properties which is confusing. 471 472 Examples: 473 >>> import sqlglot 474 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 475 >>> scopes = traverse_scope(expression) 476 >>> scopes[0].expression.sql(), list(scopes[0].sources) 477 ('SELECT a FROM x', ['x']) 478 >>> scopes[1].expression.sql(), list(scopes[1].sources) 479 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 480 481 Args: 482 expression: Expression to traverse 483 484 Returns: 485 A list of the created scope instances 486 """ 487 if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query): 488 # We ignore the DDL expression and build a scope for its query instead 489 ddl_with = expression.args.get("with") 490 expression = expression.expression 491 492 # If the DDL has CTEs attached, we need to add them to the query, or 493 # prepend them if the query itself already has CTEs attached to it 494 if ddl_with: 495 ddl_with.pop() 496 query_ctes = expression.ctes 497 if not query_ctes: 498 expression.set("with", ddl_with) 499 else: 500 expression.args["with"].set("recursive", ddl_with.recursive) 501 expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes]) 502 503 if isinstance(expression, exp.Query): 504 return list(_traverse_scope(Scope(expression))) 505 506 return []
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
509def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 510 """ 511 Build a scope tree. 512 513 Args: 514 expression: Expression to build the scope tree for. 515 516 Returns: 517 The root scope 518 """ 519 return seq_get(traverse_scope(expression), -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
751def walk_in_scope(expression, bfs=True, prune=None): 752 """ 753 Returns a generator object which visits all nodes in the syntrax tree, stopping at 754 nodes that start child scopes. 755 756 Args: 757 expression (exp.Expression): 758 bfs (bool): if set to True the BFS traversal order will be applied, 759 otherwise the DFS traversal will be used instead. 760 prune ((node, parent, arg_key) -> bool): callable that returns True if 761 the generator should stop traversing this branch of the tree. 762 763 Yields: 764 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 765 """ 766 # We'll use this variable to pass state into the dfs generator. 767 # Whenever we set it to True, we exclude a subtree from traversal. 768 crossed_scope_boundary = False 769 770 for node, parent, key in expression.walk( 771 bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args)) 772 ): 773 crossed_scope_boundary = False 774 775 yield node, parent, key 776 777 if node is expression: 778 continue 779 if ( 780 isinstance(node, exp.CTE) 781 or (_is_derived_table(node) and isinstance(parent, (exp.From, exp.Join, exp.Subquery))) 782 or isinstance(node, exp.UDTF) 783 or isinstance(node, exp.UNWRAPPED_QUERIES) 784 ): 785 crossed_scope_boundary = True 786 787 if isinstance(node, (exp.Subquery, exp.UDTF)): 788 # The following args are not actually in the inner scope, so we should visit them 789 for key in ("joins", "laterals", "pivots"): 790 for arg in node.args.get(key) or []: 791 yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root expression to walk.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node, parent, arg_key) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
794def find_all_in_scope(expression, expression_types, bfs=True): 795 """ 796 Returns a generator object which visits all nodes in this scope and only yields those that 797 match at least one of the specified expression types. 798 799 This does NOT traverse into subscopes. 800 801 Args: 802 expression (exp.Expression): 803 expression_types (tuple[type]|type): the expression type(s) to match. 804 bfs (bool): True to use breadth-first search, False to use depth-first. 805 806 Yields: 807 exp.Expression: nodes 808 """ 809 for expression, *_ in walk_in_scope(expression, bfs=bfs): 810 if isinstance(expression, tuple(ensure_collection(expression_types))): 811 yield expression
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root expression to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
814def find_in_scope(expression, expression_types, bfs=True): 815 """ 816 Returns the first node in this scope which matches at least one of the specified types. 817 818 This does NOT traverse into subscopes. 819 820 Args: 821 expression (exp.Expression): 822 expression_types (tuple[type]|type): the expression type(s) to match. 823 bfs (bool): True to use breadth-first search, False to use depth-first. 824 825 Returns: 826 exp.Expression: the node which matches the criteria or None if no node matching 827 the criteria was found. 828 """ 829 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root expression to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.