sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When None, schema
            inference is enabled only if `schema` is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before qualification;
        # with a schema, they are expanded after columns have been qualified.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are checked. The first external column of a scope that is neither
    a correlated subquery nor pivoted raises immediately; all other unqualified columns are
    collected across scopes and reported together.

    Returns:
        The unchanged `expression`.
    """
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Yield the columns introduced by an UNPIVOT: the name column of its `field` (when the field
    is an `IN` over a column), followed by every column found in its value expressions.
    """
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every `JOIN ... USING (...)` in the scope into an equivalent `ON` conjunction of
    equality conditions, then replace unqualified references to the USING columns with a
    COALESCE over all sources that provide them.

    Returns:
        Mapping from each USING column name to a dict whose keys (values are always None) are
        the source names joined on that column, in join order.

    Raises:
        OptimizeError: if a USING column can't be found on both sides and both sides have
            known, non-star column lists.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # First source (in selected-source order) that provides each column name so far
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            # Fall back to the most recently joined source when the column's origin is unknown
            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to projection aliases with the aliased expressions.

    Aliases are substituted in the SELECT list (a projection only sees aliases defined by
    earlier projections), WHERE, GROUP BY, HAVING and QUALIFY. In GROUP BY, a reference to an
    aliased literal becomes the projection's 1-based position; in HAVING and QUALIFY, such
    references are left as-is, and unqualified columns that don't refer to an alias (or whose
    substitution would nest an aggregate inside an aggregate) are qualified with their table.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Substituting an aggregate alias inside another aggregate would nest aggregates
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional GROUP BY references (e.g. `GROUP BY 1`) with the projected expressions."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Replace positional ORDER BY references with the referenced projection's alias and qualify
    unqualified columns inside aggregates. When the query has a GROUP BY, ordered expressions
    that match a projection are rewritten to that projection's alias.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals in `expressions` with what they reference in the SELECT list:
    a column named after the projection's alias when `alias` is True, otherwise a copy of the
    aliased expression. A position that refers to a literal projection is kept unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.Literal):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. `ORDER BY 2`).

    Raises:
        OptimizeError: if the position is outside `1..len(selects)`.
    """
    selects = scope.expression.selects
    position = int(node.this)
    # Checked explicitly: a non-positive position would otherwise index from the end
    # of the list (e.g. 0 -> selects[-1]) instead of being reported as unknown.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")
    return selects[position - 1].assert_is(exp.Alias)


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    Raises:
        OptimizeError: if a column is qualified with a source of this scope whose known
            columns don't include it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the Pivot expression refer to the pivot's output, so they are
                # qualified with the pivot's alias. Columns under the Pivot are qualified with
                # the pivoted source's name in the loop at the end of this function.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    `* EXCEPT (...)` and `* REPLACE (...)` modifiers are honored, columns joined via USING are
    emitted once as a COALESCE over their sources, dialect pseudocolumns are dropped, and a
    single PIVOT/UNPIVOT is accounted for. If any source referenced by a star has unknown
    columns, the projection list is left untouched.

    Raises:
        OptimizeError: if a star references a table that isn't a source of this scope.
    """

    new_selections = []
    # Both maps are keyed by id() of the table-name objects in each star's `tables` list;
    # they are looked up below with those same objects.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot):
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = {c.output_name for c in pivot.find_all(exp.Column)}

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Columns can't be listed for this source: keep the original projections
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot and pivot_output_columns and pivot_exclude_columns:
                implicit_columns = [c for c in columns if c not in pivot_exclude_columns]
                if implicit_columns:
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is emitted only once, as COALESCE over all its sources
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    coalesce_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=source) for source in coalesce_tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in columns_to_exclude:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the names in the star's EXCEPT clause under id() of each table name in `tables`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's REPLACE mapping (column name -> alias) under id() of each table name."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    An Expression argument is first turned into a Scope with `build_scope`; nothing happens if
    that doesn't yield a Scope. Unnamed subquery projections get a `_col_<i>` alias, other
    non-aliased, non-star projections are wrapped in an alias (their output name, or
    `_col_<i>`), and names from the scope's outer column list override the i-th alias.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # The outer column list may be longer than the projection list
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The expression is transformed in place (`copy=False`) using the dialect's
    `quote_identifier` callback; `identify` is passed along to `transform`.
    """
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection. Projections
        beyond the CTE's column alias list are kept unchanged.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    # Store an Identifier, not a bare str, so the alias is quoted/generated
                    # like the aliases created by `alias(...)` in the branch below.
                    projection.set("alias", exp.to_identifier(_alias))
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            # zip() stops at the shorter list; don't drop projections without a column alias
            new_expressions.extend(cte.this.expressions[len(new_expressions) :])
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches (see the methods/properties that fill them)
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column comes from it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is registered under
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the root; fall back to the plain source name then
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Columns of every selected and lateral source, keyed by source name (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous if exactly one source provides it, and provides it once.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column seen so far, including ones duplicated within the first source;
        # otherwise a later source providing such a column would wrongly be considered its
        # only provider.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. When None, inference
            is enabled exactly when `schema` is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Column alias lists on CTEs and derived tables are dropped up front
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when there is no schema,
        # and after qualification when there is one.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf:
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. In a scope that is neither a correlated subquery nor
    pivoted, the first external column raises immediately. Every other unqualified column is
    gathered across scopes and reported in a single error. Columns produced by an UNPIVOT are
    excluded because they can't be qualified.

    Returns:
        The unchanged `expression`.
    """
    ambiguous = []
    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        candidates = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            first = scope.external_columns[0]
            for_table = f" for table: '{first.table}'" if first.table else ""
            raise OptimizeError(f"Column '{first}' could not be resolved{for_table}")

        if candidates and scope.pivots and scope.pivots[0].unpivot:
            # Columns created by the UNPIVOT can never be qualified, while columns under its
            # IN clause can; filter out only the former.
            produced = set(_unpivot_columns(scope.pivots[0]))
            candidates = [c for c in candidates if c not in produced]

        ambiguous.extend(candidates)

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    An Expression argument is first converted with `build_scope`; if that doesn't produce a
    Scope, nothing is done. For each projection (position `i`):
      - an unnamed subquery gets a `_col_<i>` table alias;
      - any other projection that is neither an alias nor a star is wrapped in an alias
        named after its output name, or `_col_<i>` if it has none;
      - a name at position `i` of the scope's outer column list overrides the alias.
    The scope's SELECT list is then replaced with the resulting projections.
    """
    scope = scope_or_expression
    if isinstance(scope, exp.Expression):
        scope = build_scope(scope)
        if not isinstance(scope, Scope):
            return

    aliased = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    for index, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            # More outer column names than projections
            break

        fallback = f"_col_{index}"
        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(projection, alias=projection.output_name or fallback)

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased.append(projection)

    scope.expression.set("expressions", aliased)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The tree is transformed in place (`copy=False`) with the resolved dialect's
    `quote_identifier` callback; `identify` is passed along to `transform`.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection. Projections
        beyond the CTE's column alias list are kept unchanged.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    # Store an Identifier, not a bare str, so the alias is quoted/generated
                    # like the aliases created by `alias(...)` in the branch below.
                    projection.set("alias", exp.to_identifier(_alias))
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            # zip() stops at the shorter list; don't drop projections without a column alias
            new_expressions.extend(cte.this.expressions[len(new_expressions) :])
            cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-computed caches; each is populated on first use by the methods below.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has no known columns (or exposes a star), attribute the
            # column to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources[table_name]

        if isinstance(node, exp.Subqueryable):
            # Walk up the parents until we reach the node carrying the alias `table_name`.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above may run past the root (node becomes None); previously this crashed
        # with AttributeError, now we fall back to the plain table name.
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not one of this scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column
        # names (positionally). Loop variables are named so they don't shadow `name`/`alias`.
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous only if it occurs exactly once across all of the sources'
        column lists: duplicates within one source, or occurrences in several sources,
        make it ambiguous.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track *every* column seen so far (including duplicated ones), otherwise a column
        # duplicated in an earlier source could be re-registered as unambiguous later.
        all_columns = set(first_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # Any column already seen is ambiguous, whether or not it is unique in this source.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    """Bind the resolver to a scope and schema, starting with empty (None) caches."""
    self.scope, self.schema = scope, schema
    self._infer_schema = infer_schema
    # Caches that the resolution helpers fill on first use.
    self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
    self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
    self._all_columns: t.Optional[t.Set[str]] = None
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table identifier if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # If exactly one source has no known columns (or exposes a star), attribute the
        # column to that source.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources[table_name]

    if isinstance(node, exp.Subqueryable):
        # Walk up the parents until we reach the node carrying the alias `table_name`.
        while node and node.alias != table_name:
            node = node.parent

    # The walk above may run past the root (node becomes None); previously this crashed
    # with AttributeError, now we fall back to the plain table name.
    node_alias = node.args.get("alias") if node else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table identifier, if it can be found/inferred.
all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """Union of the column names of every source in this scope, computed once and cached."""
    if self._all_columns is None:
        merged: t.Set[str] = set()
        for source_cols in self._get_all_source_columns().values():
            merged.update(source_cols)
        self._all_columns = merged
    return self._all_columns
All available columns of all sources in this scope
def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Resolve the column names exposed by the source registered under `name`.

    Column aliases attached to the selected source replace the underlying names
    position by position.

    Raises:
        OptimizeError: if `name` is not one of this scope's sources.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]
    if isinstance(source, exp.Table):
        base_columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        base_columns = source.expression.alias_column_names
    else:
        base_columns = source.expression.named_selects

    node, _unused = self.scope.selected_sources.get(name) or (None, None)
    if isinstance(node, Scope):
        renames = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        renames = node.alias_column_names
    else:
        renames = []

    # Positional aliases win; where an alias is missing/empty, keep the underlying name.
    return [
        renamed if renamed else column
        for column, renamed in itertools.zip_longest(base_columns, renames)
    ]
Resolve the source columns for a given source `name`.