sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when the schema is empty
        # and after qualification otherwise
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                raise OptimizeError(
                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
                )

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` into explicit `ON` equality conditions.

    Unqualified references to a USING column are replaced with a COALESCE over the
    tables that take part in the join.

    Returns:
        Mapping of each USING column name to an ordered dict of the source names involved.

    Raises:
        OptimizeError: if a USING column can't be found on both sides of the join.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases within a SELECT.

    References are expanded in later projections and in the WHERE, GROUP BY, HAVING
    and QUALIFY clauses. Nothing is done for expressions that aren't a `Select`.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of the projection)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Expanding an aggregate alias inside another aggregate would nest aggregates
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)
    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional references in GROUP BY with the projections they point to."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in ORDER BY and qualify columns used inside aggregates.

    When the query has a GROUP BY, ordered expressions that match a projection are
    replaced with a reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals that refer to a projection by position.

    Args:
        scope: Scope whose projections the positions refer to.
        expressions: Expressions to inspect.
        alias: If True, a position is replaced with a column named after the
            projection's alias; otherwise with a copy of the projected expression
            (a projected literal leaves the positional reference untouched).

    Returns:
        The new list of expressions.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.Literal):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection at the 1-based position given by `node`.

    Raises:
        OptimizeError: if the position doesn't refer to an existing projection.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # Positions are 1-based; without this check position 0 would silently
    # pick the last projection through negative indexing
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if columns and "*" not in columns:
                # The EXCEPT / REPLACE mappings are keyed by the id() of the table name object
                table_id = id(table)
                columns_to_exclude = except_columns.get(table_id) or set()

                if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                    implicit_columns = [col for col in columns if col not in pivot_columns]
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                        if name not in columns_to_exclude
                    )
                    continue

                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        tables = using_column_tables[name]
                        coalesce = [exp.column(name, table=table) for table in tables]

                        new_selections.append(
                            alias(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                                alias=name,
                                copy=False,
                            )
                        )
                    elif name not in columns_to_exclude:
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table=table)
                        new_selections.append(
                            alias(column, alias_, copy=False) if alias_ != name else column
                        )
            else:
                # The source's columns are unknown, so the star can't be expanded
                return

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the names in the star's `except` arg for each table, keyed by `id(table)`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's `replace` arg (name -> alias) for each table, keyed by `id(table)`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased

    Args:
        scope_or_expression: A scope, or an expression to build a scope from. Nothing
            is done if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if selection is None:
            # zip_longest pads with None when there are more outer column aliases
            # than projections; there is nothing left to alias
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            projections = cte.this.expressions
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, projections):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)

            # zip() stops at the shorter input: keep the projections that have no
            # matching CTE column alias instead of dropping them from the query
            new_expressions.extend(projections[len(new_expressions) :])
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches that are filled in on first use
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        # The parent walk above ends with `node` set to None when no ancestor carries
        # the alias; fall back to the plain table name instead of failing on None
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` isn't a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        expand_stars: Whether or not to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether or not to infer the schema if missing. Defaults to
            inferring only when the schema is empty.

    Returns:
        The qualified expression (the same object that was passed in).
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema, alias references are expanded before qualification ...
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        # ... otherwise they are expanded after it
        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf and expand_stars:
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
        if not is_udtf:
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- expand_stars: Whether or not to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only scopes whose expression is a `Select` are checked. An external column of a
    scope that is neither a correlated subquery nor pivoted is reported immediately;
    unqualified columns are collected over all scopes and reported together.

    Returns:
        The expression that was passed in, when validation succeeds.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        if not scope.external_columns or scope.is_correlated_subquery or scope.pivots:
            continue

        column = scope.external_columns[0]
        for_table = f" for table: '{column.table}'" if column.table else ""
        raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Projections without an alias get their output name, or `_col_<index>` when they
    have none. Names from the scope's outer column list override the projection
    aliases position by position.

    Args:
        scope_or_expression: A scope, or an expression to build a scope from. Nothing
            is done if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if selection is None:
            # zip_longest pads with None when there are more outer column aliases
            # than projections; there is nothing left to alias
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The dialect's `quote_identifier` is applied to the expression through
    `transform`, without copying.

    Args:
        expression: Expression to transform.
        dialect: Dialect whose `quote_identifier` is used.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            projections = cte.this.expressions
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, projections):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)

            # zip() stops at the shorter input: keep the projections that have no
            # matching CTE column alias instead of dropping them from the query
            new_expressions.extend(projections[len(new_expressions) :])
            cte.this.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches that are filled in on first use
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        # The parent walk above ends with `node` set to None when no ancestor carries
        # the alias; fall back to the plain table name instead of failing on None
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` isn't a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    """Store the scope and schema and start with every lazily computed cache empty."""
    # Caller-supplied collaborators.
    self.scope, self.schema = scope, schema

    # Caches populated on first use by the resolver's lookup methods.
    self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
    self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
    self._all_columns: t.Optional[t.Set[str]] = None

    # Whether table inference is allowed when a column is otherwise unresolved.
    self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # Inference fallback: when exactly one source has no known columns
        # (or exposes "*"), attribute the column to that source.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Climb the parents until reaching the node aliased as `table_name`.
        while node and node.alias != table_name:
            node = node.parent

    # The walk above may run off the top of the tree; in that case there is
    # no alias node to read, so fall back to the plain table name.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """Union of the column names of every source in this scope (computed once, then cached)."""
    known = self._all_columns
    if known is None:
        known = set()
        for source_columns in self._get_all_source_columns().values():
            known.update(source_columns)
        self._all_columns = known
    return known
All available columns of all sources in this scope
def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Return the column names exposed by the source called `name`.

    Table sources are looked up in the schema; other sources report the alias
    column names (for VALUES) or the named selects of their expression. Any
    column aliases declared where the source is selected take precedence,
    position by position.

    Raises:
        OptimizeError: If `name` is not a source of the scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]

    # Step 1: the source's own column names.
    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    else:
        expression = source.expression
        is_values_scope = isinstance(source, Scope) and isinstance(expression, exp.Values)
        columns = expression.alias_column_names if is_values_scope else expression.named_selects

    # Step 2: the column aliases attached to the selected node, if there is one.
    selected = self.scope.selected_sources.get(name)
    node, _ = selected or (None, None)

    column_aliases = []
    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names

    # Step 3: an alias, when present, shadows the column name at the same position.
    resolved = []
    for column, column_alias in itertools.zip_longest(columns, column_aliases):
        resolved.append(column_alias or column)
    return resolved
Resolve the source columns for a given source `name`.