sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
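
The listing above is the full module source. As a quick illustration of how its pieces fit together, here is a small usage sketch in which _expand_using rewrites a USING join into an explicit ON condition and selects the shared column through COALESCE. The schema and table names are made up for the example, and the exact output string may differ slightly across sqlglot versions.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Hypothetical schema: both tables share the join column "id"
schema = {"a": {"id": "INT", "x": "INT"}, "b": {"id": "INT", "y": "INT"}}
expression = sqlglot.parse_one("SELECT id, x, y FROM a JOIN b USING (id)")

print(qualify_columns(expression, schema).sql())
# Roughly: SELECT COALESCE(a.id, b.id) AS id, a.x AS x, b.y AS y
#          FROM a JOIN b ON a.id = b.id
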
def qualify_columns(expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
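
A short sketch of star expansion and the expand_stars flag, using a made-up two-column schema; the qualified output follows the same pattern as the example above.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"tbl": {"a": "INT", "b": "INT"}}

print(qualify_columns(sqlglot.parse_one("SELECT * FROM tbl"), schema).sql())
# 'SELECT tbl.a AS a, tbl.b AS b FROM tbl'

# With expand_stars=False the star is left untouched (most optimizer rules
# expect expanded stars, so this is mainly useful for custom pipelines)
print(qualify_columns(sqlglot.parse_one("SELECT * FROM tbl"), schema, expand_stars=False).sql())
# 'SELECT * FROM tbl'
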
def validate_qualify_columns(expression: ~E) -> ~E:
Raise an OptimizeError if any columns aren't qualified.
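
A minimal sketch of the failure mode this check catches: a column that stays unqualified because it exists in more than one source. Table and column names are hypothetical.

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

schema = {"a": {"x": "INT"}, "b": {"x": "INT"}}
expression = qualify_columns(sqlglot.parse_one("SELECT x FROM a CROSS JOIN b"), schema)

try:
    validate_qualify_columns(expression)
except OptimizeError as err:
    # The offending column is reported: 'x' cannot be attributed to a single source
    print(err)
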
def qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
Ensure all output columns are aliased
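
A small sketch of the aliasing behavior, assuming a hypothetical table t; the function mutates the expression in place and returns None, and unnamed projections get _col_<i> aliases.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expression = sqlglot.parse_one("SELECT x + 1, y FROM t")
qualify_outputs(expression)

print(expression.sql())
# Roughly: SELECT x + 1 AS _col_0, y AS y FROM t
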
def quote_identifiers(expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
Makes sure all identifiers that need to be quoted are quoted.
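
A brief sketch; with the default identify=True every identifier is quoted using the target dialect's quoting style. The dialect and query below are illustrative.

import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expression = sqlglot.parse_one("SELECT a AS my_alias FROM tbl")
print(quote_identifiers(expression, dialect="duckdb").sql(dialect="duckdb"))
# Roughly: SELECT "a" AS "my_alias" FROM "tbl"
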
def pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
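
A sketch with two alias columns, mirroring the doctest above; the CTE's column list is pushed onto the inner projections as aliases. Names are made up.

import sqlglot
from sqlglot.optimizer.qualify_columns import pushdown_cte_alias_columns

sql = "WITH y (c1, c2) AS (SELECT a, b FROM x) SELECT c1, c2 FROM y"
print(pushdown_cte_alias_columns(sqlglot.parse_one(sql)).sql())
# Roughly: WITH y(c1, c2) AS (SELECT a AS c1, b AS c2 FROM x) SELECT c1, c2 FROM y
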
class Resolver:
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
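
A small construction sketch, assuming a hypothetical two-table schema; the resolver is built from a Scope and a Schema object.

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.schema import ensure_schema

scope = build_scope(parse_one("SELECT x, y FROM a CROSS JOIN b"))
resolver = Resolver(scope, ensure_schema({"a": {"x": "INT"}, "b": {"y": "INT"}}))

print(resolver.all_columns)              # {'x', 'y'}
print(resolver.get_source_columns("b"))  # ['y']
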
Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
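
A sketch of the ambiguity handling, with made-up tables: a column that exists in exactly one source resolves to that source's identifier, while a column present in several sources resolves to None.

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.schema import ensure_schema

scope = build_scope(parse_one("SELECT id FROM a CROSS JOIN b"))
resolver = Resolver(scope, ensure_schema({"a": {"id": "INT"}, "b": {"id": "INT", "z": "INT"}}))

print(resolver.get_table("z"))   # Identifier for source 'b'
print(resolver.get_table("id"))  # None: 'id' is ambiguous between 'a' and 'b'
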
all_columns: Set[str]
All available columns of all sources in this scope
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
Resolve the source columns for a given source name.
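
A short sketch, assuming a hypothetical schema: schema-backed tables resolve through the Schema, derived tables fall back to their named selections, and unknown names raise OptimizeError.

from sqlglot import parse_one
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.schema import ensure_schema

scope = build_scope(parse_one("SELECT * FROM tbl, (SELECT 1 AS one) AS sub"))
resolver = Resolver(scope, ensure_schema({"tbl": {"id": "INT", "name": "TEXT"}}))

print(resolver.get_source_columns("tbl"))  # ['id', 'name']
print(resolver.get_source_columns("sub"))  # ['one']

try:
    resolver.get_source_columns("missing")
except OptimizeError as err:
    print(err)  # Unknown table: missing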