sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When None, inference
            is enabled exactly when the schema is empty.

    Returns:
        The qualified expression (the same object that was passed in).

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before column qualification when there is no
        # schema, and after it when there is one.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Yield the columns introduced by an UNPIVOT.

    These are the column named by its IN `field` (when that is a plain column), followed by
    every column found in its value expressions.
    """
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` into `JOIN ... ON` equality conditions.

    Unqualified references to a USING column elsewhere in the scope are replaced with a
    COALESCE over every table that participates in that column's join.

    Returns:
        A mapping of each USING column name to a dict whose keys are the participating
        source names in join order (the values are always None).

    Raises:
        OptimizeError: if a USING column is missing from either side of the join while
            both sides' column lists are known.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # First source (in selection order) that provides each column, among the sources
        # that are already to the left of this join.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to SELECT aliases with the aliased expressions.

    The projections, WHERE, GROUP BY, HAVING and QUALIFY clauses are rewritten. In GROUP BY,
    a reference to an alias of a literal becomes that projection's 1-based position. In
    HAVING and QUALIFY, unqualified columns that do not refer to an alias (or that would
    create a nested aggregate) are qualified with their source table instead.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Inlining an aggregate alias inside another aggregate would nest aggregates
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        # Each projection is rewritten before its own alias is registered, so it can only
        # see aliases defined by earlier projections.
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional GROUP BY references with the referenced projection expressions."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand the ORDER BY clause.

    Positional references are replaced with a column naming the referenced projection's
    alias, and unqualified columns inside aggregates are qualified. When the query also has
    a GROUP BY, ORDER BY expressions that equal a projection's expression are replaced by a
    reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer literals in `expressions` as 1-based references to the projections.

    With `alias=True` each reference becomes a column naming the projection's alias.
    Otherwise it becomes a copy of the projection's expression, unless that expression is
    itself a literal, in which case the positional reference is kept. Non-integer
    expressions are returned unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.Literal):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. `ORDER BY 2`).

    Raises:
        OptimizeError: if the position is outside 1..len(selects). Positions <= 0 are
            rejected explicitly because plain list indexing would wrap them around to the
            last projections instead of failing.
    """
    selects = scope.expression.selects
    position = int(node.this)
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")
    return selects[position - 1].assert_is(exp.Alias)


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the PIVOT refer to its output, so they are qualified with
                # the pivot's alias; columns under the PIVOT fall through and are resolved
                # against the pivoted source instead.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    # NOTE(review): both maps are keyed by id() of the table-name strings in `tables`; the
    # lookups below use the same string objects, which is what makes the ids match.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # A source with unknown columns makes star expansion impossible: bail out and
            # leave every projection of this scope untouched.
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the names in a star's EXCEPT list for each table, keyed by id() of the table name."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record a star's REPLACE list for each table, keyed by id() of the table name.

    Each entry maps the name of a REPLACE expression to the alias it is given.
    """
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # Outer column names beyond the number of projections are ignored
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
    Projections beyond the length of the CTE's column list are kept unchanged.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                # Use an Identifier (not a bare str) so names needing quotes stay quoted
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence; keep any remaining projections instead of
        # silently dropping them from the CTE.
        new_expressions.extend(projections[len(column_names) :])
        cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches (see _get_all_source_columns, get_table, all_columns)
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to the only source whose columns are unknown, if exactly one exists
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is selected under
            while node and node.alias != table_name:
                node = node.parent

        # The walk above ends at None when no ancestor carries the expected alias
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when its name occurs exactly once across all sources,
        counting repeats within a single source.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Every name seen so far, including names repeated within a single source
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Any name already seen is ambiguous, even if it is repeated within this source
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite a sqlglot AST so that every column reference names its source.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema; it is normalized with `ensure_schema`.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. None means "infer
            only when the schema is empty".

    Returns:
        The qualified expression (the same object that was passed in).

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Strip `AS name(col1, col2, ...)` column lists: CTEs first, then derived tables
        for aliased_sources in (scope.ctes, scope.derived_tables):
            _pop_table_column_aliases(aliased_sources)

        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before columns are qualified;
        # with a schema, they are expanded afterwards.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf and expand_stars:
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
        if not is_udtf:
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Check that every column in `expression` has been qualified.

    Args:
        expression: The (already qualified) expression to validate.

    Returns:
        The same expression, unchanged, when validation succeeds.

    Raises:
        OptimizeError: if a SELECT scope references an external column that it cannot
            resolve (and is neither a correlated subquery nor pivoted), or if any column is
            left unqualified. Columns newly produced by an UNPIVOT are exempt.
    """
    ambiguous = []
    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        pending = scope.unqualified_columns

        if scope.external_columns and not (scope.is_correlated_subquery or scope.pivots):
            offender = scope.external_columns[0]
            table_hint = f" for table: '{offender.table}'" if offender.table else ""
            raise OptimizeError(f"Column '{offender}' could not be resolved{table_hint}")

        if pending and scope.pivots and scope.pivots[0].unpivot:
            # Columns created by the UNPIVOT cannot be qualified, while columns under its IN
            # clause can be; filter out only the former.
            produced_by_unpivot = set(_unpivot_columns(scope.pivots[0]))
            pending = [col for col in pending if col not in produced_by_unpivot]

        ambiguous.extend(pending)

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an OptimizeError if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Accepts either a Scope or an expression. For an expression, a scope is built from it,
    and nothing happens if the result is not a Scope. Subqueries without an output name
    get a `_col_<i>` table alias. Any other non-alias, non-star projection is wrapped in an
    alias that uses its output name, or `_col_<i>` when it has none. When the scope has an
    outer column list, its i-th name overrides the alias of the i-th projection.
    """
    scope = scope_or_expression
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return

    projections = list(scope.expression.selects)
    outer_names = list(scope.outer_column_list)

    aliased_projections = []
    for position, projection in enumerate(projections):
        outer_name = outer_names[position] if position < len(outer_names) else None
        fallback_name = f"_col_{position}"

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(projection, alias=projection.output_name or fallback_name)

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_projections.append(projection)

    scope.expression.set("expressions", aliased_projections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Make sure all identifiers that need to be quoted are quoted.

    Args:
        expression: Expression whose identifiers are transformed (in place, `copy=False`).
        dialect: Dialect whose `quote_identifier` rule is applied.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The transformed expression.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
    Projections beyond the length of the CTE's column list are kept unchanged.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                # Use an Identifier (not a bare str) so names needing quotes stay quoted
                projection.set("alias", exp.to_identifier(_alias))
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter sequence; keep any remaining projections instead of
        # silently dropping them from the CTE.
        new_expressions.extend(projections[len(column_names) :])
        cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches; `None` means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has an unknown schema (no columns, or a `*`),
            # attribute the column to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources[table_name]

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor that carries the alias `table_name`.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above runs past the root (node becomes None) when no ancestor carries
        # the alias; fall back to the plain name instead of raising AttributeError.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column
        # names positionally; zip_longest pads the shorter list with None.
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map each selected and lateral source name to its columns (computed once, then cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when its name occurs exactly once across all sources:
        it is duplicated neither within one source nor between sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column name seen so far, including names duplicated within a source.
        # Seeding this from `unambiguous_columns` would forget the first source's duplicates
        # and let a later source wrongly claim them.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Any name this source shares with an earlier source is ambiguous, even if it
            # is itself duplicated here (and hence missing from `unique`).
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    """Store the scope/schema to resolve against and reset the lazily built caches."""
    # Inputs consulted by the resolution methods.
    self.scope = scope
    self.schema = schema
    self._infer_schema = infer_schema

    # Caches start as None and are filled in on first use.
    self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
    self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
    self._all_columns: t.Optional[t.Set[str]] = None
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table identifier if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # If exactly one source has an unknown schema (no columns, or a `*`),
        # attribute the column to that source.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources[table_name]

    if isinstance(node, exp.Subqueryable):
        # Climb to the ancestor that carries the alias `table_name`.
        while node and node.alias != table_name:
            node = node.parent

    # The climb above runs past the root (node becomes None) when no ancestor carries
    # the alias; fall back to the plain name instead of raising AttributeError.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table identifier if it can be found/inferred.
all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """Every column name exposed by any source in this scope, built once and then cached."""
    if self._all_columns is None:
        per_source = self._get_all_source_columns().values()
        self._all_columns = set(itertools.chain.from_iterable(per_source))
    return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Return the column names exposed by the source called `name` in this scope.

    Column aliases declared on the selected source, if any, replace the underlying
    names position by position.

    Raises:
        OptimizeError: if `name` is not one of the scope's sources.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")
    source = sources[name]

    if isinstance(source, exp.Table):
        base_columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        base_columns = source.expression.alias_column_names
    else:
        base_columns = source.expression.named_selects

    selected = self.scope.selected_sources.get(name)
    node = selected[0] if selected else None

    if isinstance(node, Scope):
        aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        aliases = node.alias_column_names
    else:
        aliases = []

    # Aliases shadow the column names they line up with; zip_longest pads the shorter side with None.
    return [
        column_alias or column_name
        for column_name, column_alias in itertools.zip_longest(base_columns, aliases)
    ]
Resolve the source columns for a given source `name`.