sqlglot.optimizer.qualify_columns
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import alias, exp 7from sqlglot._typing import E 8from sqlglot.dialects.dialect import Dialect, DialectType 9from sqlglot.errors import OptimizeError 10from sqlglot.helper import seq_get 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope 12from sqlglot.optimizer.simplify import simplify_parens 13from sqlglot.schema import Schema, ensure_schema 14 15 16def qualify_columns( 17 expression: exp.Expression, 18 schema: t.Dict | Schema, 19 expand_alias_refs: bool = True, 20 infer_schema: t.Optional[bool] = None, 21) -> exp.Expression: 22 """ 23 Rewrite sqlglot AST to have fully qualified columns. 24 25 Example: 26 >>> import sqlglot 27 >>> schema = {"tbl": {"col": "INT"}} 28 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 29 >>> qualify_columns(expression, schema).sql() 30 'SELECT tbl.col AS col FROM tbl' 31 32 Args: 33 expression: Expression to qualify. 34 schema: Database schema. 35 expand_alias_refs: Whether or not to expand references to aliases. 36 infer_schema: Whether or not to infer the schema if missing. 37 38 Returns: 39 The qualified expression. 
40 """ 41 schema = ensure_schema(schema) 42 infer_schema = schema.empty if infer_schema is None else infer_schema 43 pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS 44 45 for scope in traverse_scope(expression): 46 resolver = Resolver(scope, schema, infer_schema=infer_schema) 47 _pop_table_column_aliases(scope.ctes) 48 _pop_table_column_aliases(scope.derived_tables) 49 using_column_tables = _expand_using(scope, resolver) 50 51 if schema.empty and expand_alias_refs: 52 _expand_alias_refs(scope, resolver) 53 54 _qualify_columns(scope, resolver) 55 56 if not schema.empty and expand_alias_refs: 57 _expand_alias_refs(scope, resolver) 58 59 if not isinstance(scope.expression, exp.UDTF): 60 _expand_stars(scope, resolver, using_column_tables, pseudocolumns) 61 qualify_outputs(scope) 62 63 _expand_group_by(scope) 64 _expand_order_by(scope, resolver) 65 66 return expression 67 68 69def validate_qualify_columns(expression: E) -> E: 70 """Raise an `OptimizeError` if any columns aren't qualified""" 71 unqualified_columns = [] 72 for scope in traverse_scope(expression): 73 if isinstance(scope.expression, exp.Select): 74 unqualified_columns.extend(scope.unqualified_columns) 75 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 76 column = scope.external_columns[0] 77 raise OptimizeError( 78 f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" 79 ) 80 81 if unqualified_columns: 82 raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") 83 return expression 84 85 86def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: 87 """ 88 Remove table column aliases. 89 90 For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) 
AS foo(col1, col2) 91 """ 92 for derived_table in derived_tables: 93 table_alias = derived_table.args.get("alias") 94 if table_alias: 95 table_alias.args.pop("columns", None) 96 97 98def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: 99 joins = list(scope.find_all(exp.Join)) 100 names = {join.alias_or_name for join in joins} 101 ordered = [key for key in scope.selected_sources if key not in names] 102 103 # Mapping of automatically joined column names to an ordered set of source names (dict). 104 column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} 105 106 for join in joins: 107 using = join.args.get("using") 108 109 if not using: 110 continue 111 112 join_table = join.alias_or_name 113 114 columns = {} 115 116 for source_name in scope.selected_sources: 117 if source_name in ordered: 118 for column_name in resolver.get_source_columns(source_name): 119 if column_name not in columns: 120 columns[column_name] = source_name 121 122 source_table = ordered[-1] 123 ordered.append(join_table) 124 join_columns = resolver.get_source_columns(join_table) 125 conditions = [] 126 127 for identifier in using: 128 identifier = identifier.name 129 table = columns.get(identifier) 130 131 if not table or identifier not in join_columns: 132 if (columns and "*" not in columns) and join_columns: 133 raise OptimizeError(f"Cannot automatically join: {identifier}") 134 135 table = table or source_table 136 conditions.append( 137 exp.condition( 138 exp.EQ( 139 this=exp.column(identifier, table=table), 140 expression=exp.column(identifier, table=join_table), 141 ) 142 ) 143 ) 144 145 # Set all values in the dict to None, because we only care about the key ordering 146 tables = column_tables.setdefault(identifier, {}) 147 if table not in tables: 148 tables[table] = None 149 if join_table not in tables: 150 tables[join_table] = None 151 152 join.args.pop("using") 153 join.set("on", exp.and_(*conditions, copy=False)) 154 155 if column_tables: 156 for column in 
scope.columns: 157 if not column.table and column.name in column_tables: 158 tables = column_tables[column.name] 159 coalesce = [exp.column(column.name, table=table) for table in tables] 160 replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) 161 162 # Ensure selects keep their output name 163 if isinstance(column.parent, exp.Select): 164 replacement = alias(replacement, alias=column.name, copy=False) 165 166 scope.replace(column, replacement) 167 168 return column_tables 169 170 171def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: 172 expression = scope.expression 173 174 if not isinstance(expression, exp.Select): 175 return 176 177 alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} 178 179 def replace_columns( 180 node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False 181 ) -> None: 182 if not node: 183 return 184 185 for column, *_ in walk_in_scope(node): 186 if not isinstance(column, exp.Column): 187 continue 188 189 table = resolver.get_table(column.name) if resolve_table and not column.table else None 190 alias_expr, i = alias_to_expression.get(column.name, (None, 1)) 191 double_agg = ( 192 (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) 193 if alias_expr 194 else False 195 ) 196 197 if table and (not alias_expr or double_agg): 198 column.set("table", table) 199 elif not column.table and alias_expr and not double_agg: 200 if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table): 201 if literal_index: 202 column.replace(exp.Literal.number(i)) 203 else: 204 column = column.replace(exp.paren(alias_expr)) 205 simplified = simplify_parens(column) 206 if simplified is not column: 207 column.replace(simplified) 208 209 for i, projection in enumerate(scope.expression.selects): 210 replace_columns(projection) 211 212 if isinstance(projection, exp.Alias): 213 alias_to_expression[projection.alias] = (projection.this, i + 1) 214 215 
replace_columns(expression.args.get("where")) 216 replace_columns(expression.args.get("group"), literal_index=True) 217 replace_columns(expression.args.get("having"), resolve_table=True) 218 replace_columns(expression.args.get("qualify"), resolve_table=True) 219 scope.clear_cache() 220 221 222def _expand_group_by(scope: Scope) -> None: 223 expression = scope.expression 224 group = expression.args.get("group") 225 if not group: 226 return 227 228 group.set("expressions", _expand_positional_references(scope, group.expressions)) 229 expression.set("group", group) 230 231 232def _expand_order_by(scope: Scope, resolver: Resolver) -> None: 233 order = scope.expression.args.get("order") 234 if not order: 235 return 236 237 ordereds = order.expressions 238 for ordered, new_expression in zip( 239 ordereds, 240 _expand_positional_references(scope, (o.this for o in ordereds), alias=True), 241 ): 242 for agg in ordered.find_all(exp.AggFunc): 243 for col in agg.find_all(exp.Column): 244 if not col.table: 245 col.set("table", resolver.get_table(col.name)) 246 247 ordered.set("this", new_expression) 248 249 if scope.expression.args.get("group"): 250 selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects} 251 252 for ordered in ordereds: 253 ordered = ordered.this 254 255 ordered.replace( 256 exp.to_identifier(_select_by_pos(scope, ordered).alias) 257 if ordered.is_int 258 else selects.get(ordered, ordered) 259 ) 260 261 262def _expand_positional_references( 263 scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False 264) -> t.List[exp.Expression]: 265 new_nodes: t.List[exp.Expression] = [] 266 for node in expressions: 267 if node.is_int: 268 select = _select_by_pos(scope, t.cast(exp.Literal, node)) 269 270 if alias: 271 new_nodes.append(exp.column(select.args["alias"].copy())) 272 else: 273 select = select.this 274 275 if isinstance(select, exp.Literal): 276 new_nodes.append(node) 277 else: 278 new_nodes.append(select.copy()) 279 else: 
280 new_nodes.append(node) 281 282 return new_nodes 283 284 285def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: 286 try: 287 return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) 288 except IndexError: 289 raise OptimizeError(f"Unknown output column: {node.name}") 290 291 292def _qualify_columns(scope: Scope, resolver: Resolver) -> None: 293 """Disambiguate columns, ensuring each column specifies a source""" 294 for column in scope.columns: 295 column_table = column.table 296 column_name = column.name 297 298 if column_table and column_table in scope.sources: 299 source_columns = resolver.get_source_columns(column_table) 300 if source_columns and column_name not in source_columns and "*" not in source_columns: 301 raise OptimizeError(f"Unknown column: {column_name}") 302 303 if not column_table: 304 if scope.pivots and not column.find_ancestor(exp.Pivot): 305 # If the column is under the Pivot expression, we need to qualify it 306 # using the name of the pivoted source instead of the pivot's alias 307 column.set("table", exp.to_identifier(scope.pivots[0].alias)) 308 continue 309 310 column_table = resolver.get_table(column_name) 311 312 # column_table can be a '' because bigquery unnest has no table alias 313 if column_table: 314 column.set("table", column_table) 315 elif column_table not in scope.sources and ( 316 not scope.parent 317 or column_table not in scope.parent.sources 318 or not scope.is_correlated_subquery 319 ): 320 # structs are used like tables (e.g. 
"struct"."field"), so they need to be qualified 321 # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) 322 323 root, *parts = column.parts 324 325 if root.name in scope.sources: 326 # struct is already qualified, but we still need to change the AST representation 327 column_table = root 328 root, *parts = parts 329 else: 330 column_table = resolver.get_table(root.name) 331 332 if column_table: 333 column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) 334 335 for pivot in scope.pivots: 336 for column in pivot.find_all(exp.Column): 337 if not column.table and column.name in resolver.all_columns: 338 column_table = resolver.get_table(column.name) 339 if column_table: 340 column.set("table", column_table) 341 342 343def _expand_stars( 344 scope: Scope, 345 resolver: Resolver, 346 using_column_tables: t.Dict[str, t.Any], 347 pseudocolumns: t.Set[str], 348) -> None: 349 """Expand stars to lists of column selections""" 350 351 new_selections = [] 352 except_columns: t.Dict[int, t.Set[str]] = {} 353 replace_columns: t.Dict[int, t.Dict[str, str]] = {} 354 coalesced_columns = set() 355 356 # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future 357 pivot_columns = None 358 pivot_output_columns = None 359 pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) 360 361 has_pivoted_source = pivot and not pivot.args.get("unpivot") 362 if pivot and has_pivoted_source: 363 pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) 364 365 pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] 366 if not pivot_output_columns: 367 pivot_output_columns = [col.alias_or_name for col in pivot.expressions] 368 369 for expression in scope.expression.selects: 370 if isinstance(expression, exp.Star): 371 tables = list(scope.selected_sources) 372 _add_except_columns(expression, tables, except_columns) 373 _add_replace_columns(expression, tables, 
replace_columns) 374 elif expression.is_star: 375 tables = [expression.table] 376 _add_except_columns(expression.this, tables, except_columns) 377 _add_replace_columns(expression.this, tables, replace_columns) 378 else: 379 new_selections.append(expression) 380 continue 381 382 for table in tables: 383 if table not in scope.sources: 384 raise OptimizeError(f"Unknown table: {table}") 385 386 columns = resolver.get_source_columns(table, only_visible=True) 387 388 if pseudocolumns: 389 columns = [name for name in columns if name.upper() not in pseudocolumns] 390 391 if columns and "*" not in columns: 392 table_id = id(table) 393 columns_to_exclude = except_columns.get(table_id) or set() 394 395 if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: 396 implicit_columns = [col for col in columns if col not in pivot_columns] 397 new_selections.extend( 398 exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) 399 for name in implicit_columns + pivot_output_columns 400 if name not in columns_to_exclude 401 ) 402 continue 403 404 for name in columns: 405 if name in using_column_tables and table in using_column_tables[name]: 406 if name in coalesced_columns: 407 continue 408 409 coalesced_columns.add(name) 410 tables = using_column_tables[name] 411 coalesce = [exp.column(name, table=table) for table in tables] 412 413 new_selections.append( 414 alias( 415 exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), 416 alias=name, 417 copy=False, 418 ) 419 ) 420 elif name not in columns_to_exclude: 421 alias_ = replace_columns.get(table_id, {}).get(name, name) 422 column = exp.column(name, table=table) 423 new_selections.append( 424 alias(column, alias_, copy=False) if alias_ != name else column 425 ) 426 else: 427 return 428 429 # Ensures we don't overwrite the initial selections with an empty list 430 if new_selections: 431 scope.expression.set("expressions", new_selections) 432 433 434def _add_except_columns( 435 expression: exp.Expression, 
tables, except_columns: t.Dict[int, t.Set[str]] 436) -> None: 437 except_ = expression.args.get("except") 438 439 if not except_: 440 return 441 442 columns = {e.name for e in except_} 443 444 for table in tables: 445 except_columns[id(table)] = columns 446 447 448def _add_replace_columns( 449 expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]] 450) -> None: 451 replace = expression.args.get("replace") 452 453 if not replace: 454 return 455 456 columns = {e.this.name: e.alias for e in replace} 457 458 for table in tables: 459 replace_columns[id(table)] = columns 460 461 462def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 463 """Ensure all output columns are aliased""" 464 if isinstance(scope_or_expression, exp.Expression): 465 scope = build_scope(scope_or_expression) 466 if not isinstance(scope, Scope): 467 return 468 else: 469 scope = scope_or_expression 470 471 new_selections = [] 472 for i, (selection, aliased_column) in enumerate( 473 itertools.zip_longest(scope.expression.selects, scope.outer_column_list) 474 ): 475 if isinstance(selection, exp.Subquery): 476 if not selection.output_name: 477 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 478 elif not isinstance(selection, exp.Alias) and not selection.is_star: 479 selection = alias( 480 selection, 481 alias=selection.output_name or f"_col_{i}", 482 ) 483 if aliased_column: 484 selection.set("alias", exp.to_identifier(aliased_column)) 485 486 new_selections.append(selection) 487 488 scope.expression.set("expressions", new_selections) 489 490 491def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 492 """Makes sure all identifiers that need to be quoted are quoted.""" 493 return expression.transform( 494 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 495 ) 496 497 498def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 499 """ 500 
Pushes down the CTE alias columns into the projection, 501 502 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 503 504 Example: 505 >>> import sqlglot 506 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 507 >>> pushdown_cte_alias_columns(expression).sql() 508 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 509 510 Args: 511 expression: Expression to pushdown. 512 513 Returns: 514 The expression with the CTE aliases pushed down into the projection. 515 """ 516 for cte in expression.find_all(exp.CTE): 517 if cte.alias_column_names: 518 new_expressions = [] 519 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 520 if isinstance(projection, exp.Alias): 521 projection.set("alias", _alias) 522 else: 523 projection = alias(projection, alias=_alias) 524 new_expressions.append(projection) 525 cte.this.set("expressions", new_expressions) 526 527 return expression 528 529 530class Resolver: 531 """ 532 Helper for resolving columns. 533 534 This is a class so we can lazily load some things and easily share them across functions. 535 """ 536 537 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 538 self.scope = scope 539 self.schema = schema 540 self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None 541 self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None 542 self._all_columns: t.Optional[t.Set[str]] = None 543 self._infer_schema = infer_schema 544 545 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 546 """ 547 Get the table for a column name. 548 549 Args: 550 column_name: The column name to find the table for. 551 Returns: 552 The table name if it can be found/inferred. 
553 """ 554 if self._unambiguous_columns is None: 555 self._unambiguous_columns = self._get_unambiguous_columns( 556 self._get_all_source_columns() 557 ) 558 559 table_name = self._unambiguous_columns.get(column_name) 560 561 if not table_name and self._infer_schema: 562 sources_without_schema = tuple( 563 source 564 for source, columns in self._get_all_source_columns().items() 565 if not columns or "*" in columns 566 ) 567 if len(sources_without_schema) == 1: 568 table_name = sources_without_schema[0] 569 570 if table_name not in self.scope.selected_sources: 571 return exp.to_identifier(table_name) 572 573 node, _ = self.scope.selected_sources.get(table_name) 574 575 if isinstance(node, exp.Subqueryable): 576 while node and node.alias != table_name: 577 node = node.parent 578 579 node_alias = node.args.get("alias") 580 if node_alias: 581 return exp.to_identifier(node_alias.this) 582 583 return exp.to_identifier(table_name) 584 585 @property 586 def all_columns(self) -> t.Set[str]: 587 """All available columns of all sources in this scope""" 588 if self._all_columns is None: 589 self._all_columns = { 590 column for columns in self._get_all_source_columns().values() for column in columns 591 } 592 return self._all_columns 593 594 def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]: 595 """Resolve the source columns for a given source `name`.""" 596 if name not in self.scope.sources: 597 raise OptimizeError(f"Unknown table: {name}") 598 599 source = self.scope.sources[name] 600 601 if isinstance(source, exp.Table): 602 columns = self.schema.column_names(source, only_visible) 603 elif isinstance(source, Scope) and isinstance(source.expression, exp.Values): 604 columns = source.expression.alias_column_names 605 else: 606 columns = source.expression.named_selects 607 608 node, _ = self.scope.selected_sources.get(name) or (None, None) 609 if isinstance(node, Scope): 610 column_aliases = node.expression.alias_column_names 611 elif 
isinstance(node, exp.Expression): 612 column_aliases = node.alias_column_names 613 else: 614 column_aliases = [] 615 616 # If the source's columns are aliased, their aliases shadow the corresponding column names 617 return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)] 618 619 def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]: 620 if self._source_columns is None: 621 self._source_columns = { 622 source_name: self.get_source_columns(source_name) 623 for source_name, source in itertools.chain( 624 self.scope.selected_sources.items(), self.scope.lateral_sources.items() 625 ) 626 } 627 return self._source_columns 628 629 def _get_unambiguous_columns( 630 self, source_columns: t.Dict[str, t.List[str]] 631 ) -> t.Dict[str, str]: 632 """ 633 Find all the unambiguous columns in sources. 634 635 Args: 636 source_columns: Mapping of names to source columns. 637 638 Returns: 639 Mapping of column name to source name. 640 """ 641 if not source_columns: 642 return {} 643 644 source_columns_pairs = list(source_columns.items()) 645 646 first_table, first_columns = source_columns_pairs[0] 647 unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} 648 all_columns = set(unambiguous_columns) 649 650 for table, columns in source_columns_pairs[1:]: 651 unique = self._find_unique_columns(columns) 652 ambiguous = set(all_columns).intersection(unique) 653 all_columns.update(columns) 654 655 for column in ambiguous: 656 unambiguous_columns.pop(column, None) 657 for column in unique.difference(ambiguous): 658 unambiguous_columns[column] = table 659 660 return unambiguous_columns 661 662 @staticmethod 663 def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]: 664 """ 665 Find the unique columns in a list of columns. 666 667 Example: 668 >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"])) 669 ['a', 'c'] 670 671 This is necessary because duplicate column names are ambiguous. 
672 """ 673 counts: t.Dict[str, int] = {} 674 for column in columns: 675 counts[column] = counts.get(column, 0) + 1 676 return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Qualify every column of a sqlglot AST with the source it comes from.

    Each scope of the tree is processed in turn: table column aliases are dropped,
    USING joins are expanded, columns are qualified, alias references and stars are
    expanded, outputs are aliased and positional GROUP BY / ORDER BY references are
    resolved.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing. `None` means
            "infer only when the schema is empty".

    Returns:
        The qualified expression (the object that was passed in).
    """
    db_schema = ensure_schema(schema)
    if infer_schema is None:
        should_infer = db_schema.empty
    else:
        should_infer = infer_schema
    dialect_pseudocolumns = Dialect.get_or_raise(db_schema.dialect).PSEUDOCOLUMNS

    for current in traverse_scope(expression):
        column_resolver = Resolver(current, db_schema, infer_schema=should_infer)

        _pop_table_column_aliases(current.ctes)
        _pop_table_column_aliases(current.derived_tables)
        joined_on_using = _expand_using(current, column_resolver)

        # With an empty schema the alias references go first, otherwise they follow
        # the column qualification.
        if db_schema.empty and expand_alias_refs:
            _expand_alias_refs(current, column_resolver)

        _qualify_columns(current, column_resolver)

        if not db_schema.empty and expand_alias_refs:
            _expand_alias_refs(current, column_resolver)

        is_udtf = isinstance(current.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current, column_resolver, joined_on_using, dialect_pseudocolumns)
            qualify_outputs(current)

        _expand_group_by(current)
        _expand_order_by(current, column_resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Check that column qualification left nothing unresolved.

    Args:
        expression: Expression to validate.

    Returns:
        The `expression` that was passed in, untouched.

    Raises:
        OptimizeError: If a SELECT scope references an external column that could not
            be resolved (unless the scope is a correlated subquery or has pivots), or
            if any SELECT scope still has unqualified columns.
    """
    pending = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        pending.extend(scope.unqualified_columns)

        # Same short-circuit order as the positive form of this condition
        if not scope.external_columns or scope.is_correlated_subquery or scope.pivots:
            continue

        column = scope.external_columns[0]
        table_hint = f" for table: '{column.table}'" if column.table else ""
        raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

    if pending:
        raise OptimizeError(f"Ambiguous columns: {pending}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    The projections are modified in place. A projection without an output name is
    aliased as `_col_<i>`, where `<i>` is its 0-based position. Names coming from the
    scope's outer column list override the projection aliases positionally.

    Args:
        scope_or_expression: A scope, or an expression to build a scope from. Nothing
            happens if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None once the projections run out, i.e. when there are
        # more outer column names than projections. Previously this fell through to
        # `selection.is_star` and failed with an AttributeError.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Quote the identifiers of `expression` according to `dialect`.

    Args:
        expression: Expression whose identifiers are quoted; it is transformed without
            being copied first.
        dialect: Dialect that supplies the quoting rule.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quoting_rule = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoting_rule, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Copy the column names of each CTE alias onto the CTE's projections.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Note that names and projections are paired positionally and the pairing stops at
    the shorter of the two, so projections without a matching alias column name are
    not carried over into the rewritten projection list.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        renamed = []
        for column_name, projection in zip(cte.alias_column_names, cte.this.expressions):
            if not isinstance(projection, exp.Alias):
                # Bare projection: wrap it in a new alias node
                projection = alias(projection, alias=column_name)
            else:
                # Already aliased: overwrite the existing alias
                projection.set("alias", column_name)
            renamed.append(projection)

        cte.this.set("expressions", renamed)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The caches below are filled lazily, on first use
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, the column is attributed to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run past the root of the tree and leave `node` as None;
        # in that case fall back to the plain source name instead of failing.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
538 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 539 self.scope = scope 540 self.schema = schema 541 self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None 542 self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None 543 self._all_columns: t.Optional[t.Set[str]] = None 544 self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # Sources with no known columns (or a star) could contain any column. If there is
        # exactly one such source, the column is attributed to it.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up the tree until we reach the ancestor that carries the source's alias.
        while node and node.alias != table_name:
            node = node.parent

    # The walk above can run off the root of the tree and leave `node` as None; in that
    # case fall through to the plain table name instead of dereferencing None.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
586 @property 587 def all_columns(self) -> t.Set[str]: 588 """All available columns of all sources in this scope""" 589 if self._all_columns is None: 590 self._all_columns = { 591 column for columns in self._get_all_source_columns().values() for column in columns 592 } 593 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: Name of a source in this scope.
        only_visible: Forwarded to `schema.column_names` when the source is a table.
    Returns:
        The source's column names, with any column aliases declared on the selected
        source taking the place of the names at the same positions.
    Raises:
        OptimizeError: If `name` is not a source of this scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    else:
        expression = source.expression
        if isinstance(source, Scope) and isinstance(expression, exp.Values):
            columns = expression.alias_column_names
        else:
            columns = expression.named_selects

    selected = self.scope.selected_sources.get(name) or (None, None)
    node, _unused = selected

    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    # An alias, where present, shadows the column name at the same position; zip_longest
    # pads the shorter sequence with None so unaliased columns keep their own name.
    resolved = []
    for column, column_alias in itertools.zip_longest(columns, column_aliases):
        resolved.append(column_alias or column)
    return resolved
Resolve the source columns for a given source `name`.