sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    # Default: only infer missing source schemas when no schema was supplied at all.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before column qualification;
        # with a schema, they are expanded afterwards.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Yield the columns introduced by an UNPIVOT: the name column of its IN field (when
    that field's target is a column), followed by every column in its value expressions.
    """
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` into explicit `ON` equality conditions.

    Unqualified references to USING columns in the scope are replaced by a COALESCE
    over every source that participates in the join on that column.

    Returns:
        Mapping of each USING column name to a dict whose keys are the participating
        source names in join order (values are always None).

    Raises:
        OptimizeError: If a USING column cannot be found on both sides of the join
            while both sides have known columns.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # First source (in join order) that exposes each column name seen so far.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                source_names = column_tables[column.name]
                coalesce = [exp.column(column.name, table=name) for name in source_names]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to SELECT aliases with the aliased expressions.

    Projections, WHERE, GROUP BY, HAVING and QUALIFY are visited. In GROUP BY, a reference
    to a literal-valued alias becomes that projection's 1-based position. In HAVING and
    QUALIFY, unqualified non-alias columns are also qualified with their source table.
    Only applies to SELECT scopes.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Expanding an aggregate alias inside another aggregate would nest aggregates,
            # so such references are left untouched.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        # Only aliases defined by earlier projections are visible to later ones.
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional GROUP BY references (e.g. GROUP BY 1) with the projected expressions."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand ORDER BY terms.

    Positional references become the referenced projection's alias column, and
    unqualified columns inside aggregates are qualified. When the query has a GROUP BY,
    terms that match a projection's expression are replaced by its output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer-literal positional references against the scope's projections.

    With `alias=True` each reference becomes a column named after the projection's alias.
    Otherwise it becomes a copy of the projected expression, except that references to
    literal-valued projections are kept as-is. Non-integer expressions pass through.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.Literal):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the aliased projection at the 1-based position given by `node`.

    Raises:
        OptimizeError: If the position is outside 1..len(projections).
    """
    position = int(node.this)
    selects = scope.expression.selects

    # Positions are 1-based. Reject 0 and negatives explicitly: Python's negative indexing
    # would otherwise resolve them silently to a projection counted from the end.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    # Both maps are keyed by id() of the table-name objects in the `tables` list built
    # per star below; the same objects are looked up again in the inner loop.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Columns of this source are unknown: bail out before the selections are
            # overwritten, leaving every star in this scope unexpanded.
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is emitted once, as COALESCE over all joined sources.
                    coalesced_columns.add(name)
                    coalesce_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=source) for source in coalesce_tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the star's EXCEPT column names for each of `tables`, keyed by id(table)."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's REPLACE mapping (column name -> alias) for each of `tables`, keyed by id(table)."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # More outer column names than projections: nothing left to alias.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        # Names from the outer column list take precedence over any existing alias.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
        Projections beyond the end of the CTE's column alias list are kept unchanged.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            projections = cte.this.expressions
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, projections):
                if isinstance(projection, exp.Alias):
                    # Store an Identifier, as the alias() branch below does.
                    projection.set("alias", exp.to_identifier(_alias))
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            # zip() stops at the shorter list; keep any remaining projections
            # instead of silently dropping them from the query.
            new_expressions.extend(projections[len(new_expressions) :])
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches (see the methods/properties that populate them).
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        # If exactly one source has no known columns, assume the column comes from it.
        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that carries the alias matching this source name.
            while node and node.alias != table_name:
                node = node.parent

        # Guard: the walk above can run off the top of the tree.
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                col_alias or col_name
                for (col_name, col_alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map every selected and lateral source name to its columns (computed once)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    # Default: only infer missing source schemas when no schema was supplied at all.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before column qualification;
        # with a schema, they are expanded afterwards.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Raises:
        OptimizeError: If a SELECT scope has an external column that is not explained by a
            correlated subquery or a pivot, or if any unqualified columns remain. New
            columns produced by an UNPIVOT are not counted as unqualified.
    """
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression
Raise an OptimizeError if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Projections without an alias get their output name, or `_col_<i>` when they have none.
    Subquery projections without an output name get a `_col_<i>` table alias. When the
    scope has an outer column list, those names override the projections' aliases. Only
    SELECT expressions have their projection list rewritten.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # More outer column names than projections: nothing left to alias.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        # Names from the outer column list take precedence over any existing alias.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform. It is modified in place (`copy=False`).
        dialect: Dialect whose `quote_identifier` rule is applied to each node.
        identify: Passed through to the dialect's `quote_identifier`.

    Returns:
        The transformed expression.
    """
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
        Projections beyond the end of the CTE's column alias list are kept unchanged.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            projections = cte.this.expressions
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, projections):
                if isinstance(projection, exp.Alias):
                    # Store an Identifier, as the alias() branch below does.
                    projection.set("alias", exp.to_identifier(_alias))
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            # zip() stops at the shorter list; keep any remaining projections
            # instead of silently dropping them from the query.
            new_expressions.extend(projections[len(new_expressions) :])
            cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    The per-scope lookups (source columns, unambiguous columns, all columns) are each computed
    once on first use and then cached on the instance.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; None means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the columns exposed by exactly one source. If that
        fails and schema inference is enabled, the column is attributed to the only source whose
        columns are unknown (empty, or containing "*"), provided there is exactly one such source.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        # Unresolved names (None) and names that are not selected sources are passed straight
        # to `exp.to_identifier`.
        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Climb the ancestors until reaching the node whose alias matches this source name.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above can run off the top of the tree without finding a matching alias;
        # fall back to the source name instead of failing on `None.args`.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of a source in this scope.
            only_visible: Forwarded to `Schema.column_names` when the source is a table.

        Returns:
            The source's column names. Column aliases declared on the selected source shadow
            the corresponding underlying names.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            # zip_longest pads the shorter side with None, so surplus aliases are still exposed and
            # surplus columns keep their own names. The loop variables are deliberately not called
            # `name`/`alias`, to avoid shadowing the parameter and the module-level `alias` import.
            return [
                column_alias or column
                for column, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map every selected and lateral source name to its columns (computed once, then cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source exposes it.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # Columns seen in an earlier source become ambiguous, and stay so because
            # `all_columns` remembers every column seen so far.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    """
    Bind the resolver to a scope and a schema.

    Args:
        scope: The scope whose columns will be resolved.
        schema: The schema consulted for table columns.
        infer_schema: Whether an otherwise unresolved column may be attributed to the single
            source whose columns are unknown.
    """
    self._infer_schema = infer_schema
    self.schema = schema
    self.scope = scope
    # Caches filled in lazily on first use; None marks "not computed yet".
    self._all_columns: t.Optional[t.Set[str]] = None
    self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
    self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    The column is first looked up among the columns exposed by exactly one source. If that
    fails and schema inference is enabled, the column is attributed to the only source whose
    columns are unknown (empty, or containing "*"), provided there is exactly one such source.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table identifier if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    # Unresolved names (None) and names that are not selected sources are passed straight
    # to `exp.to_identifier`.
    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Query):
        # Climb the ancestors until reaching the node whose alias matches this source name.
        while node and node.alias != table_name:
            node = node.parent

    # The climb above can run off the top of the tree without finding a matching alias;
    # fall back to the source name instead of failing on `None.args`.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The identifier of the table, if it can be found or inferred.
all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """Union of the column names of every source in this scope, computed once and cached."""
    cached = self._all_columns
    if cached is not None:
        return cached
    per_source = self._get_all_source_columns().values()
    self._all_columns = set().union(*per_source)
    return self._all_columns
All available columns of all sources in this scope
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: The name of a source in this scope.
        only_visible: Forwarded to `Schema.column_names` when the source is a table.

    Returns:
        The source's column names. Column aliases declared on the selected source shadow
        the corresponding underlying names.

    Raises:
        OptimizeError: If `name` is not a source of this scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        columns = source.expression.named_selects

    node, _ = self.scope.selected_sources.get(name) or (None, None)
    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    if column_aliases:
        # If the source's columns are aliased, their aliases shadow the corresponding column names.
        # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
        # zip_longest pads the shorter side with None, so surplus aliases are still exposed and
        # surplus columns keep their own names. The loop variables are deliberately not called
        # `name`/`alias`, to avoid shadowing the parameter and the module-level `alias` import.
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]
    return columns
Resolve the source columns for a given source `name`.