sqlglot.optimizer.scope
import itertools
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name


class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        # NOTE: a non-empty `sources` dict is stored as-is (not copied) and is
        # updated in place with the lateral sources below.
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            # Only CTE sources are inherited by the inner scope; `chain_sources`
            # entries take precedence over them on a name clash.
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (depth-first) and bucket its nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless the collected lists are already up to date."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope without descending into child scopes."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
                )
                # An unqualified name under QUALIFY / ORDER BY / HAVING that matches
                # a named select is skipped, as is any unqualified name under a hint.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same name is selected from more than once.
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                referenced_names.append((expression.alias, expression.unnest()))
            result = {}

            for name, node in referenced_names:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other collected properties; otherwise the
        # result would be empty until something else happened to trigger `_collect`.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the tables and derived tables of this scope.

        Returns:
            list: the entries of the "pivots" arg of every table and derived table
        """
        # Compare against None so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot
                for node in self.tables + self.derived_tables
                for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        If `old_name` is not a known source, `new_name` is mapped to an empty list.
        """
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # `selected_sources` (and what is derived from it) depends on the source
        # names, so invalidate the caches like `add_source`/`remove_source` do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    return list(_traverse_scope(Scope(expression)))


def build_scope(expression):
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    # Scopes are yielded children first, so the root scope is the last one.
    return traverse_scope(expression)[-1]


def _traverse_scope(scope):
    """Yield every scope nested in `scope`, then `scope` itself."""
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        pass
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
    yield scope


def _traverse_select(scope):
    """Yield the child scopes of a SELECT: CTEs, then tables, then subqueries."""
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    """Yield the child scopes of a UNION and record its left and right scopes."""
    yield from _traverse_ctes(scope)

    # The last scope to be yielded should be the top most scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    """Yield the scopes of the CTEs of `scope` and register them as its sources."""
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
        # section of the union.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_tables(scope):
    """Yield the scopes of derived tables and UDTFs and register all table sources."""
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        else:
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            alias = expression.alias
            sources[alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Yield the scopes of the subqueries of `scope` and record each top-most one."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    # Outermost scope; also the default type of a newly created Scope.
    ROOT = auto()
    # Subquery used as an expression, e.g. WHERE a IN (SELECT ...).
    SUBQUERY = auto()
    # Subquery selected from in a FROM or JOIN clause.
    DERIVED_TABLE = auto()
    # Common table expression.
    CTE = auto()
    # Left or right side of a UNION.
    UNION = auto()
    # User defined tabular function.
    UDTF = auto()
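An enumeration of the kinds of scope, relative to the parent scope: ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION and UDTF.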
Inherited Members
- enum.Enum
- name
- value
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        # NOTE: a non-empty `sources` dict is stored as-is (not copied) and is
        # updated in place with the lateral sources below.
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            # Only CTE sources are inherited by the inner scope; `chain_sources`
            # entries take precedence over them on a name clash.
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (depth-first) and bucket its nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless the collected lists are already up to date."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope without descending into child scopes."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
                )
                # An unqualified name under QUALIFY / ORDER BY / HAVING that matches
                # a named select is skipped, as is any unqualified name under a hint.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same name is selected from more than once.
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                referenced_names.append((expression.alias, expression.unnest()))
            result = {}

            for name, node in referenced_names:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other collected properties; otherwise the
        # result would be empty until something else happened to trigger `_collect`.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the tables and derived tables of this scope.

        Returns:
            list: the entries of the "pivots" arg of every table and derived table
        """
        # Compare against None so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot
                for node in self.tables + self.derived_tables
                for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        If `old_name` is not a known source, `new_name` is mapped to an empty list.
        """
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # `selected_sources` (and what is derived from it) depends on the source
        # names, so invalidate the caches like `add_source`/`remove_source` do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Set up a scope around `expression`.

    Args:
        expression: Root expression of this scope.
        sources: Mapping of source name to source. A falsy value is replaced by a
            new empty dict.
        outer_column_list: Column names supplied by the outer query's alias, if any.
            A falsy value is replaced by a new empty list.
        parent: Parent scope, or None.
        scope_type: Type of this scope, relative to its parent.
        lateral_sources: Sources from laterals. A copy is stored on the scope and
            its entries are also merged into `sources`.
    """
    self.expression = expression

    # NOTE: a non-empty `sources` mapping is stored by reference, so the merge of
    # the lateral sources below is visible through the caller's object as well.
    self.sources = sources if sources else {}
    if lateral_sources:
        self.lateral_sources = lateral_sources.copy()
    else:
        self.lateral_sources = {}
    self.sources.update(self.lateral_sources)

    self.outer_column_list = outer_column_list if outer_column_list else []
    self.parent = parent
    self.scope_type = scope_type

    # Every kind of child scope starts out as its own empty list.
    for child_list in (
        "subquery_scopes",
        "derived_table_scopes",
        "table_scopes",
        "cte_scopes",
        "union_scopes",
        "udtf_scopes",
    ):
        setattr(self, child_list, [])

    self.clear_cache()
def clear_cache(self):
    """
    Reset every lazily computed attribute of this scope.

    `_collected` is set back to False and each cached collection is set to None.
    """
    self._collected = False
    for cached_attribute in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
    ):
        setattr(self, cached_attribute, None)
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    """
    Create a new inner scope whose parent is this scope.

    Args:
        expression: Expression for the inner scope; it is unnested before use.
        scope_type (ScopeType): Type of the new scope.
        chain_sources: Optional extra sources for the new scope. On a name
            collision these take precedence over this scope's CTE sources.
        **kwargs: Forwarded unchanged to the `Scope` constructor.

    Returns:
        Scope: the new inner scope
    """
    inner_expression = expression.unnest()
    extra_sources = chain_sources if chain_sources else {}
    visible_sources = {**self.cte_sources, **extra_sources}

    return Scope(
        expression=inner_expression,
        sources=visible_sources,
        parent=self,
        scope_type=scope_type,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def find(self, *expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of any of the given types.

    Subscopes are NOT searched.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if nothing matched.
    """
    for match in self.find_all(*expression_types, bfs=bfs):
        # The first match is all we need; stop consuming the generator here.
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.
def find_all(self, *expression_types, bfs=True):
    """
    Lazily yield every node in this scope that is an instance of any of the given types.

    Subscopes are NOT searched.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: matching nodes, in the order `self.walk` produces them
    """
    for visit in self.walk(bfs=bfs):
        # Each item from `walk` leads with the node; the rest is not needed here.
        node, *_ = visit
        if not isinstance(node, expression_types):
            continue
        yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def replace(self, old, new):
    """
    Swap the node `old` for `new` and invalidate this scope's cached state.

    Prefer this over calling `exp.Expression.replace` directly, so that the
    `Scope` does not keep serving values computed before the replacement.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # Anything cached so far may still refer to `old`.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Select expressions of this scope.
For example, for the following expression: SELECT 1 as a, 2 as b FROM x
The outputs are the "1 as a" and "2 as b" expressions.
Returns:
list[exp.Expression]: expressions
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns in this scope that belong to one particular source.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Columns whose `table` equals `source_name`
    """
    matching = []
    for candidate in self.columns:
        if candidate.table == source_name:
            matching.append(candidate)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The entry registered under `old_name` is moved to `new_name`. A falsy
    `old_name` (such as None) is looked up as the empty string. If no entry
    exists under the old name, an empty list is registered under `new_name`.

    Like `add_source` and `remove_source`, this invalidates the scope's cached
    state, because values derived from `self.sources` would otherwise keep
    referring to the old name.

    Args:
        old_name (str): Current name of the source
        new_name (str): Name to register the source under
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """
    Register `source` under `name` in this scope, replacing any existing entry.

    Cached state is invalidated afterwards.
    """
    registry = self.sources
    registry[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """
    Drop the source registered under `name`, if there is one.

    An unknown name is ignored. Cached state is invalidated either way.
    """
    registry = self.sources
    if name in registry:
        del registry[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Children are visited group by group (CTE scopes, then union scopes, then table
    scopes, then subquery scopes), and every child subtree is emitted in full
    before this scope itself.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (
        self.cte_scopes,
        self.union_scopes,
        self.table_scopes,
        self.subquery_scopes,
    )
    for group in child_groups:
        for child in group:
            yield from child.traverse()
    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Tally how often each source is selected from across this scope tree.

    Every scope produced by `self.traverse()` contributes one count for each of
    its selected sources, keyed by the `id()` of the source object.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)

    for scope in self.traverse():
        for selected in scope.selected_sources.values():
            # Each value is a (node, source) pair; only the source is counted.
            _node, source = selected
            counts[id(source)] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    A "Scope" captures the current context of a Select statement.

    This helps when optimizing queries, where more information is needed than the
    expression tree alone provides, such as the source names available inside a
    subquery. The result is an eagerly built list rather than a generator, since
    a partially consumed generator could expose scopes with incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    root_scope = Scope(expression)
    scopes = _traverse_scope(root_scope)
    return list(scopes)
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") >>> scopes = traverse_scope(expression) >>> scopes[0].expression.sql(), list(scopes[0].sources) ('SELECT a FROM x', ['x']) >>> scopes[1].expression.sql(), list(scopes[1].sources) ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression):
    """
    Build the scope tree for an expression and return its root.

    The root is taken to be the last scope returned by `traverse_scope`.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    scopes = traverse_scope(expression)
    return scopes[-1]
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is still yielded; the prune flag is only raised
    after the yield. The root `expression` itself is never treated as such a boundary.

    Args:
        expression (exp.Expression): root expression to start walking from
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the walk generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    # The lambda below reads `prune` through the closure at call time, so it sees
    # whatever value this function assigned most recently.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        # Reset on every iteration so an earlier prune decision does not carry over.
        prune = False

        yield node, parent, key

        # The root of the walk is always descended into, whatever its type.
        if node is expression:
            continue
        # These node kinds start a child scope: CTEs, subqueries used directly in a
        # FROM or JOIN, UDTFs, and any Subqueryable.
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root expression to start walking from.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key