sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import case_sensitive, seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: expression to qualify (modified in place)
        schema: Database schema
        expand_alias_refs: whether or not to expand references to aliases
        infer_schema: whether or not to infer the schema if missing; defaults to
            True only when the provided schema is empty
    Returns:
        sqlglot.Expression: qualified expression (the same object that was passed in)
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before qualification;
        # with a schema, they are expanded afterwards.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression


def validate_qualify_columns(expression):
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                raise OptimizeError(
                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
                )

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite ``JOIN ... USING (...)`` into an explicit ``ON`` equality condition.

    Unqualified references to USING columns elsewhere in the scope are replaced by a
    COALESCE over every source that takes part in that USING join.

    Returns:
        Mapping of each USING column name to a dict (used as an ordered set) of the
        source names the column was joined across.

    Raises:
        OptimizeError: if a USING column cannot be found on both sides while both
            sides have known columns.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.this.alias_or_name for join in joins}
    # Non-joined sources in selection order; joined sources are appended as they are
    # processed, so `ordered` always lists the sources to the left of the current join.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # For each column, the first already-processed source (in selection order) that has it.
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fatal when both sides have known columns; otherwise fall back to
                # the most recently processed source below.
                if columns and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to projection aliases with (copies of) the aliased expressions.

    Only SELECT scopes are handled. Aliases are resolved left to right, so a projection
    can only reference aliases defined before it. In HAVING, QUALIFY and ORDER BY,
    unqualified columns inside aggregate functions are qualified with their source
    table instead; ORDER BY never expands alias references.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, exp.Expression] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
    ):
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue
            table = resolver.get_table(column.name) if resolve_agg and not column.table else None
            if table and column.find_ancestor(exp.AggFunc):
                column.set("table", table)
            elif expand and not column.table and column.name in alias_to_expression:
                column.replace(alias_to_expression[column.name].copy())

    for projection in scope.selects:
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = projection.this

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"))
    replace_columns(expression.args.get("having"), resolve_agg=True)
    replace_columns(expression.args.get("qualify"), resolve_agg=True)
    replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
    scope.clear_cache()


def _expand_group_by(scope, resolver):
    """Replace positional GROUP BY references with the referenced projections.

    `resolver` is not used here.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """Replace positional ORDER BY references with the referenced projections."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Replace 1-based positional references (e.g. ``GROUP BY 1``) with copies of the
    corresponding projections, with any alias unwrapped.

    Raises:
        OptimizeError: if a position does not refer to an existing projection.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            position = int(node.name)
            # Check bounds explicitly: a position of 0 would otherwise index
            # selects[-1] (the last projection) instead of being rejected.
            if not 1 <= position <= len(scope.selects):
                raise OptimizeError(f"Unknown output column: {node.name}")
            select = scope.selects[position - 1]
            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent or column_table not in scope.parent.sources
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
    """
    Expand stars to lists of column selections.

    If any star refers to a source whose columns are unknown, the scope's
    projections are left unchanged.

    Raises:
        OptimizeError: if a star refers to a table that is not a source of this scope.
    """

    new_selections = []
    except_columns = {}
    replace_columns = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = seq_get(scope.pivots, 0)

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if has_pivoted_source:
        # Columns referenced inside the PIVOT are excluded from the implicit projection.
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if columns and "*" not in columns:
                if has_pivoted_source:
                    implicit_columns = [col for col in columns if col not in pivot_columns]
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                    )
                    continue

                # NOTE(review): EXCEPT/REPLACE options are keyed by id() of the source-name
                # string and shared by all star selections in this scope, so an option on one
                # star may also affect a later star over the same string object — confirm.
                table_id = id(table)
                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        # USING columns are emitted once, as a COALESCE over the joined sources.
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        coalesce_tables = using_column_tables[name]
                        coalesce = [
                            exp.column(name, table=source) for source in coalesce_tables
                        ]

                        new_selections.append(
                            alias(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                                alias=name,
                                copy=False,
                            )
                        )
                    elif name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table=table)
                        new_selections.append(
                            alias(column, alias_, copy=False) if alias_ != name else column
                        )
            else:
                # The source's columns are unknown, so the star cannot be expanded.
                return

    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record the star's EXCEPT column names for every table in `tables`, keyed by id()."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record the star's REPLACE mapping (column name -> alias) per table, keyed by id()."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """Ensure all output columns are aliased"""
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if selection is None:
            # The outer column list names more columns than the query projects
            # (zip_longest pads with None); there is nothing left to alias.
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""

    def _quote(expression: E) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify
                or case_sensitive(name, dialect=dialect)
                or not exp.SAFE_IDENTIFIER_RE.match(name),
            )
        return expression

    return expression.transform(_quote, copy=False)


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches (see _get_all_source_columns, get_table, all_columns).
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, attribute the column to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Cached mapping of every selected and lateral source name to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it occurs exactly once across all sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column seen so far, including ones duplicated within a single
        # source, so that a later occurrence of such a column is still ambiguous.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Any already-seen column of this source is ambiguous, even when it is
            # duplicated within this source (and therefore absent from `unique`).
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Qualify every column in the AST with the source it belongs to.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: the expression to qualify; it is modified in place.
        schema: database schema (a dict or a `Schema`).
        expand_alias_refs: whether to expand references to projection aliases.
        infer_schema: whether to infer a source's columns when they are missing;
            when None, this is True only if the schema is empty.
    Returns:
        sqlglot.Expression: the qualified expression (the same object passed in).
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Column aliases on CTEs and derived tables are removed before resolution.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when there is no
        # schema, and after it when there is one.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)
        _qualify_columns(scope, resolver)
        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)

        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: expression to qualify
- schema: Database schema
- expand_alias_refs: whether or not to expand references to aliases
- infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """Raise an `OptimizeError` if any column is ambiguous or cannot be resolved.

    Returns the expression unchanged when every column is qualified.
    """
    ambiguous = []
    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        unresolved = scope.external_columns
        if unresolved and not scope.is_correlated_subquery and not scope.pivots:
            column = unresolved[0]
            raise OptimizeError(
                f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
            )

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")
    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Set the `quoted` flag on every identifier in `expression` (in place).

    An identifier is quoted when `identify` is true, when it is case-sensitive for
    `dialect`, or when it does not match `exp.SAFE_IDENTIFIER_RE`.
    """

    def _quote(node: E) -> E:
        if not isinstance(node, exp.Identifier):
            return node

        text = node.this
        needs_quotes = (
            identify
            or case_sensitive(text, dialect=dialect)
            or not exp.SAFE_IDENTIFIER_RE.match(text)
        )
        node.set("quoted", needs_quotes)
        return node

    return expression.transform(_quote, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches (see _get_all_source_columns, get_table, all_columns).
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, attribute the column to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Cached mapping of every selected and lateral source name to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it occurs exactly once across all sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column seen so far, including ones duplicated within a single
        # source, so that a later occurrence of such a column is still ambiguous.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Any already-seen column of this source is ambiguous, even when it is
            # duplicated within this source (and therefore absent from `unique`).
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Return the identifier of the source that owns `column_name`.

    Args:
        column_name: the column name to look up.
    Returns:
        An identifier for the owning source, preferring the alias of the source node
        when it has one. The result may be None if nothing matches.
    """
    if self._unambiguous_columns is None:
        all_source_columns = self._get_all_source_columns()
        self._unambiguous_columns = self._get_unambiguous_columns(all_source_columns)

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # Fall back to the only source whose columns are unknown, if there is exactly one.
        unknown_sources = [
            source_name
            for source_name, source_cols in self._get_all_source_columns().items()
            if not source_cols or "*" in source_cols
        ]
        if len(unknown_sources) == 1:
            (table_name,) = unknown_sources

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up to the ancestor that carries the alias matching this source name.
        while node and node.alias != table_name:
            node = node.parent

    node_alias = node.args.get("alias")
    return exp.to_identifier(node_alias.this if node_alias else table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """Return the column names exposed by the source called `name` in this scope.

    Raises:
        OptimizeError: if `name` is not a source of the scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    if isinstance(source, exp.Table):
        # Physical tables: ask the schema.
        return self.schema.column_names(source, only_visible)

    source_expression = source.expression
    if isinstance(source, Scope) and isinstance(source_expression, exp.Values):
        # VALUES sources expose the column names declared on their alias.
        return source_expression.alias_column_names

    # Any other scope exposes the names of its projections.
    return source_expression.named_selects
Resolve the source columns for a given source `name`.