sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import find_new_name, name_sequence


if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.generator import Generator


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """Rewrites UNNEST(GENERATE_DATE_ARRAY(...)) table sources as recursive date-generating CTEs."""
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression


def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select("*", copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression


def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strips the table/db/catalog qualifiers off of every column reference."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Removes UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Converts CREATE TEMPORARY TABLE ... AS into CREATE TEMPORARY VIEW statements."""
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression


def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        for any in expression.find_all(exp.Any):
            this = any.this
            if isinstance(this, exp.Query):
                continue

            binop = any.parent
            if isinstance(binop, exp.Binary):
                lambda_arg = exp.to_identifier("x")
                any.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

Args:
    transforms: sequence of transform functions. These will be called in order.

Returns:
    Function that can be used as a generator transform.
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites UNNEST(GENERATE_DATE_ARRAY(...)) table sources as recursive CTEs.

    Each qualifying UNNEST in a FROM/JOIN is replaced with a subquery selecting from a
    recursive CTE that starts at the series' start date and repeatedly adds the step
    interval until the end date is exceeded — useful for dialects without a native
    GENERATE_DATE_ARRAY.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only handle UNNEST(<single GENERATE_DATE_ARRAY>) used as a table source
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # The step must be an interval so we can build the DATE_ADD below
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            # Use the UNNEST alias' first column name if present, otherwise a default
            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix the CTE name so multiple rewrites in one SELECT don't collide
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member: the start date; recursive member: previous + step, bounded by end
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH and mark it RECURSIVE
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    # Preserve the original table alias as a column alias on the UNNEST source
    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)
Unnests GENERATE_SERIES or SEQUENCE table references.
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Map each aliased projection to its 1-based position in the SELECT list
        alias_positions = {}
        for position, projection in enumerate(expression.parent.expressions, start=1):
            if isinstance(projection, exp.Alias):
                alias_positions[projection.alias] = position

        for item in expression.expressions:
            # Only rewrite unqualified columns that refer to a select alias
            is_bare_column = isinstance(item, exp.Column) and not item.table
            if is_bare_column and item.name in alias_positions:
                item.replace(exp.Literal.number(alias_positions.get(item.name)))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Pop DISTINCT so the inner query becomes a plain SELECT; the ON columns
        # become the PARTITION BY of a ROW_NUMBER() window
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # Move the query's ORDER BY into the window so "first row per partition"
            # matches DISTINCT ON semantics
            window.set("order", order.pop())
        else:
            # No explicit ordering: order by the DISTINCT ON columns themselves
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # Outer query keeps only the first row of each partition
        return (
            exp.select("*", copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                # Give unnamed projections a synthetic alias so the outer query can reference them
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Preserve identifier quoting when re-projecting the column in the outer query
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window, so the
                    # window doesn't point at names that only exist in the outer scope
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window function inside the subquery and refer to it by alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window IS the entire filter, so the filter becomes the alias column
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column referenced by QUALIFY but not selected: expose it in the subquery
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip precision/scale parameters from parameterized types.

    Some dialects only allow the precision for parameterized types to be defined in the
    DDL and not in other expressions, so this transform drops all DataTypeParam nodes.
    """
    for data_type in expression.find_all(exp.DataType):
        remaining = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", remaining)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Collect aliases of UNNESTs that appear as table sources in this scope
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if unnest_aliases:
        for column in expression.find_all(exp.Column):
            if column.table in unnest_aliases:
                column.set("table", None)
            elif column.db in unnest_aliases:
                column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: when True, an UNNEST over multiple arrays is rewritten
            as INLINE(ARRAYS_ZIP(...)); when False, such input raises UnsupportedError.

    Returns:
        The transformed expression.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapse multiple unnested arrays into a single ARRAYS_ZIP expression
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Pick the table-generating function: POSEXPLODE when an offset (position)
        # column is requested, INLINE for zipped multi-arrays, EXPLODE otherwise
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: UNNEST is the FROM source itself -> rewrite in place as a table UDTF
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: UNNEST appears as a join (possibly wrapped in LATERAL) -> move it
        # out of the joins and into LATERAL VIEW clauses
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type
                # column, unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the dialect's array index base (0 or 1), used to align the
            generated series and the array-size bounds.

    Returns:
        A transform function to be used with a generator.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a collision-free name
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE per explode; also signals whether the series was joined yet
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series; its "end" bound is filled in after the loop
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an exp.Alias we can rewrite
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE aliased as (pos, col)
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # EXPLODE_OUTER: emit one row for empty/NULL arrays by substituting
                        # a singleton array when the (coalesced) input is empty
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the value only on the row where series pos == unnest pos
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Add the position projection right after the value projection
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode: attach the shared position series to the query
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions line up, or pad shorter arrays with
                    # their last element's row so all explodes span the same series
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must reach the longest array's size (offset-adjusted)
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_bare_percentile:
        return expression

    # Swap the percentile's arguments: the input column moves into an ORDER BY inside
    # WITHIN GROUP, and the quantile becomes the function's sole argument.
    input_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=input_column)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    # WITHIN GROUP (ORDER BY <input>) carries the input value; the wrapped
    # percentile function carries the quantile.
    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        if cte.args["alias"].columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            # For set operations the first operand determines the output columns
            query = query.this

        # Fall back to generated names for projections without an output name
        columns = [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects]
        cte.args["alias"].set("columns", columns)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        # 'epoch' is shorthand for the Unix epoch timestamp
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        on = join.args.get("on")
        if not on or join.kind not in ("SEMI", "ANTI"):
            continue

        # SEMI -> WHERE EXISTS (SELECT 1 FROM <joined> WHERE <on>)
        # ANTI -> WHERE NOT EXISTS (...)
        predicate: exp.Expression = exp.Exists(this=exp.select("1").from_(join.this).where(on))
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # NOTE(review): assumes the join carries either "on" or "using"; a FULL
            # JOIN with neither would fail when iterating "using" below — confirm
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            # LEFT branch keeps all left-side rows; the RIGHT branch contributes only
            # rows with no left-side match, enforced via NOT EXISTS to avoid duplicates
            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            # Already at the top level
            continue

        if not top_level_with:
            # No top-level WITH yet: promote the inner one as-is
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the moved CTEs right before the CTE that contained them, so
                # they are defined before they're referenced
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # A node needs to be "booleanized" when it's a numeric literal, is typed as
        # unknown/numeric (subquery predicates excluded), or is an untyped column
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            # <node> -> <node> <> 0, making the boolean intent explicit
            node.replace(node.neq(0))

    for node in expression.walk():
        # The canonicalize helper decides which positions require booleans and invokes
        # our callback on them (note: it intentionally shadows this function's name)
        ensure_bools(node, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrites temporary CREATE TABLE statements.

    A temporary CTAS (CREATE TEMPORARY TABLE ... AS SELECT) becomes CREATE TEMPORARY VIEW.
    A temporary table without a source query is passed through `tmp_storage_provider`,
    which dialects can use to attach storage properties (defaults to the identity function).

    Args:
        expression: the expression that will be transformed (must be an exp.Create).
        tmp_storage_provider: callback applied to temporary CREATE TABLE statements
            that have no AS SELECT part.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when PARTITIONED BY is a list of names, not already a schema
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Match partition names case-insensitively against the schema columns
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Drop the partition columns from the main schema and move their full
            # definitions into PARTITIONED BY (...)
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    # Only rewrite when PARTITIONED BY is a schema of fully-typed column definitions
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        # Inverse of move_schema_columns_to_partitioned_by: move the full column
        # definitions back into the table schema and keep only bare identifiers
        # in the PARTITIONED BY property
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        converted = []
        for arg in expression.expressions:
            if isinstance(arg, exp.PropertyEQ):
                # STRUCT(y := 1) -> STRUCT(1 AS y)
                converted.append(exp.alias_(arg.expression, arg.this))
            else:
                converted.append(arg)

        expression.set("expressions", converted)

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)   -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Join marks can only appear in WHERE predicates over joined tables
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # The smallest predicate containing the marked column becomes the ON clause
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the (single) table they refer to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # The marked table becomes the right side of a LEFT JOIN
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                # AND the predicates together when the same table is marked multiple times
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            # The FROM table itself became a join target; pick a remaining table as FROM
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            # All predicates were consumed as join conditions
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify
first.
For example, SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported
    for this transformation

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Fix: local loop variable was named `any`, shadowing the builtin any();
        # renamed to `any_expr` (behavior unchanged)
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            if isinstance(this, exp.Query):
                # ANY over a subquery is left untouched; only array operands are rewritten
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Replace the ANY node with the lambda variable, then wrap the whole
                # comparison into EXISTS(<array>, x -> <comparison>)
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
Transform ANY operator to Spark's EXISTS
For example, - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation