sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when the schema is empty,
        # and after qualification otherwise.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: Expression whose SELECT scopes are checked.

    Returns:
        The same expression, unchanged.

    Raises:
        OptimizeError: if a scope has an external column that cannot be resolved, or if
            any unqualified columns remain.
    """
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Iterate over the columns referenced by an UNPIVOT.

    Yields the column on the left-hand side of the unpivot's IN `field` (when there is
    one), followed by every column found under the unpivot's `expressions`.
    """
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Members of a recursive WITH are left untouched.
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` clauses into explicit `ON` equality conditions.

    Unqualified references to a USING column elsewhere in the scope are replaced by a
    COALESCE over the tables that take part in the join for that column.

    Returns:
        Mapping of each USING column name to an ordered dict whose keys are the names of
        the sources joined on that column.

    Raises:
        OptimizeError: if a USING column cannot be found on both sides of the join while
            the columns of both sides are known.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        # First source (in selection order) that provides each column wins
        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases within a SELECT scope.

    Each projection is processed in order (so a projection may reference the aliases
    defined before it), followed by the WHERE, GROUP BY, HAVING and QUALIFY clauses.
    Non-SELECT scopes are left untouched.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when both the aliased expression and the referencing column involve an
            # aggregate; such references are not expanded
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional references in the scope's GROUP BY with the projections they point to."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand the scope's ORDER BY.

    Positional references are replaced by a column named after the referenced
    projection's alias, and unqualified columns under aggregate functions are given a
    table. When the query also has a GROUP BY, ordered expressions that match a
    projection are replaced by a reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer literals in `expressions` to the projections at those positions.

    Args:
        scope: Scope whose projections are referenced.
        expressions: Nodes to resolve; non-integer nodes are passed through as-is.
        alias: If True, an integer is replaced by a column built from the projection's
            alias. Otherwise it is replaced by a copy of the projected expression, unless
            that expression is itself a literal, in which case the integer is kept.

    Returns:
        The list of resolved nodes, in the same order as `expressions`.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.Literal):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection at the 1-based position given by the integer literal `node`.

    Raises:
        OptimizeError: if the position is past the end of the projection list.
    """
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    Raises:
        OptimizeError: if a column names a known source that does not provide it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    Args:
        scope: Scope whose projections are expanded.
        resolver: Used to look up the columns of each source.
        using_column_tables: Mapping produced by `_expand_using`; columns found in it are
            emitted once as a COALESCE over their tables.
        pseudocolumns: Upper-cased column names that must not appear in the expansion.

    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope.
    """

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star and not isinstance(expression, exp.Dot):
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # The source's columns are unknown, so the star cannot be expanded: leave the
            # projections exactly as they were
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the names in the star's `except` arg in `except_columns`, keyed by `id()` of each table."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's `replace` arg as a column name -> alias mapping, keyed by `id()` of each table."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. If no
            scope can be built from the expression, nothing is done.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # There are more outer column names than projections: nothing left to alias
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        # A name from the outer column list overrides the projection's own alias
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if alias_names:
            projections = cte.this.expressions
            new_expressions = []
            for _alias, projection in zip(alias_names, projections):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)

            # zip() stops at the shorter input. When there are fewer aliases than
            # projections, keep the remaining projections instead of dropping them.
            new_expressions.extend(projections[len(new_expressions) :])
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The three caches below are filled in lazily on first use
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor that carries the alias matching the source name
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in this scope.
            only_visible: Forwarded to the schema's `column_names` lookup for table sources.

        Returns:
            The column names of the source, with any column aliases applied.

        Raises:
            OptimizeError: if `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Qualify every column of a sqlglot AST with the source it comes from.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. Most optimizer rules need this
            step, so only turn it off if you know what you're doing.
        infer_schema: Whether to infer the schema if missing. When left as None, it
            defaults to whether the schema is empty.

    Returns:
        The qualified expression (the same object that was passed in).

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Drop column alias lists from CTEs first, then from derived tables
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded *before* qualification when the schema is empty ...
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        # ... and *after* qualification otherwise.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf and expand_stars:
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
        if not is_udtf:
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Check that every column in `expression` has been qualified.

    Only SELECT scopes are inspected.

    Args:
        expression: The expression to validate.

    Returns:
        The same expression, unchanged.

    Raises:
        OptimizeError: if a scope without pivots that is not a correlated subquery has an
            external column, or if unqualified columns remain in any scope.
    """
    leftovers = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            unresolved = scope.external_columns[0]
            for_table = ""
            if unresolved.table:
                for_table = f" for table: '{unresolved.table}'"
            raise OptimizeError(f"Column '{unresolved}' could not be resolved{for_table}")

        if unqualified and scope.pivots and scope.pivots[0].unpivot:
            # Columns that the UNPIVOT itself produces can't be qualified, whereas columns
            # under its IN clause can and should be. Filter out the former so that only
            # genuinely unqualified columns are reported.
            produced = set(_unpivot_columns(scope.pivots[0]))
            unqualified = [col for col in unqualified if col not in produced]

        leftovers.extend(unqualified)

    if leftovers:
        raise OptimizeError(f"Ambiguous columns: {leftovers}")

    return expression
Raise an OptimizeError
if any columns aren't qualified
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Give every output column of a scope an alias.

    Args:
        scope_or_expression: A scope, or an expression to build a scope from. Nothing
            happens if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    aliased_selections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)

    for position, (selection, outer_name) in enumerate(pairs):
        # More outer column names than projections: nothing left to alias
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                generated = exp.to_identifier(f"_col_{position}")
                selection.set("alias", exp.TableAlias(this=generated))
        elif not (isinstance(selection, exp.Alias) or selection.is_star):
            output_name = selection.output_name or f"_col_{position}"
            selection = alias(selection, alias=output_name, copy=False)

        # A name from the outer column list overrides the projection's own alias
        if outer_name:
            selection.set("alias", exp.to_identifier(outer_name))

        aliased_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Make sure all identifiers that need to be quoted are quoted.

    Args:
        expression: Expression to transform; `transform` is called with `copy=False`.
        dialect: Dialect whose `quote_identifier` is applied to the expression.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of transforming `expression`.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    If a CTE lists fewer column aliases than it has projections, the projections beyond
    the alias list are kept as they are.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = cte.this.expressions
        new_expressions = []
        for _alias, projection in zip(alias_names, projections):
            if isinstance(projection, exp.Alias):
                # Already aliased: overwrite the alias in place
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter input. Previously, projections without a matching
        # alias were silently dropped from the query; keep them unchanged instead.
        new_expressions.extend(projections[len(new_expressions) :])
        cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: Scope whose sources are used to resolve columns.
            schema: Schema used to look up the column names of table sources.
            infer_schema: Whether `get_table` may attribute an unknown column to the
                single source that has no known columns.
        """
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; `None` means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # The column is unknown: if exactly one source has no known columns
            # (or exposes "*"), attribute the column to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        selected_sources = self.scope.selected_sources
        if table_name not in selected_sources:
            return exp.to_identifier(table_name)

        node, _ = selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up the tree until reaching the node aliased as `table_name`.
            while node and node.alias != table_name:
                node = node.parent

        if node is None:
            # The walk ran past the root without finding the aliased node; fall
            # back to the plain name instead of dereferencing `None` below.
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in this resolver's scope.
            only_visible: Forwarded to the schema's `column_names` lookup for table sources.

        Returns:
            The column names of the source, with any column aliases applied positionally.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # Columns already seen in an earlier source are ambiguous from now on.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
584 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 585 self.scope = scope 586 self.schema = schema 587 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 588 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 589 self._all_columns: t.Optional[t.Set[str]] = None 590 self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # The column is unknown: if exactly one source has no known columns
        # (or exposes "*"), attribute the column to that source.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    selected_sources = self.scope.selected_sources
    if table_name not in selected_sources:
        return exp.to_identifier(table_name)

    node, _ = selected_sources.get(table_name)

    if isinstance(node, exp.Query):
        # Walk up the tree until reaching the node aliased as `table_name`.
        while node and node.alias != table_name:
            node = node.parent

    if node is None:
        # The walk ran past the root without finding the aliased node; fall
        # back to the plain name instead of dereferencing `None` below.
        return exp.to_identifier(table_name)

    node_alias = node.args.get("alias")
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
632 @property 633 def all_columns(self) -> t.Set[str]: 634 """All available columns of all sources in this scope""" 635 if self._all_columns is None: 636 self._all_columns = { 637 column for columns in self._get_all_source_columns().values() for column in columns 638 } 639 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
    """
    Resolve the source columns for a given source `name`.

    Table sources are looked up in the schema (passing `only_visible` along),
    VALUES scopes use their alias column names and any other source uses its
    named selects. Column aliases found on the selected source are then applied
    positionally on top of those names.

    Raises:
        OptimizeError: If `name` is not a source of the scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    # Step 1: the column names the source itself exposes.
    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        columns = source.expression.named_selects

    # Step 2: the column aliases declared where the source is selected, if any.
    selected = self.scope.selected_sources.get(name)
    node, _ = selected if selected else (None, None)
    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    if not column_aliases:
        return columns

    # If the source's columns are aliased, their aliases shadow the corresponding column names.
    # This can be expensive if there are lots of columns, so it is only done when aliases exist.
    paired = itertools.zip_longest(columns, column_aliases)
    return [column_alias or column_name for column_name, column_alias in paired]
Resolve the source columns for a given source `name`.