sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def unalias_group(expression: exp.Expression) -> exp.Expression: 13 """ 14 Replace references to select aliases in GROUP BY clauses. 15 16 Example: 17 >>> import sqlglot 18 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 19 'SELECT a AS b FROM x GROUP BY 1' 20 21 Args: 22 expression: the expression that will be transformed. 23 24 Returns: 25 The transformed expression. 26 """ 27 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 28 aliased_selects = { 29 e.alias: i 30 for i, e in enumerate(expression.parent.expressions, start=1) 31 if isinstance(e, exp.Alias) 32 } 33 34 for group_by in expression.expressions: 35 if ( 36 isinstance(group_by, exp.Column) 37 and not group_by.table 38 and group_by.name in aliased_selects 39 ): 40 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 41 42 return expression 43 44 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 46 """ 47 Convert SELECT DISTINCT ON statements to a subquery with a window function. 48 49 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 50 51 Args: 52 expression: the expression that will be transformed. 53 54 Returns: 55 The transformed expression. 
56 """ 57 if ( 58 isinstance(expression, exp.Select) 59 and expression.args.get("distinct") 60 and expression.args["distinct"].args.get("on") 61 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 62 ): 63 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 64 outer_selects = expression.selects 65 row_number = find_new_name(expression.named_selects, "_row_number") 66 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 67 order = expression.args.get("order") 68 69 if order: 70 window.set("order", order.pop()) 71 else: 72 window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) 73 74 window = exp.alias_(window, row_number) 75 expression.select(window, copy=False) 76 77 return ( 78 exp.select(*outer_selects, copy=False) 79 .from_(expression.subquery("_t", copy=False), copy=False) 80 .where(exp.column(row_number).eq(1), copy=False) 81 ) 82 83 return expression 84 85 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 87 """ 88 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 89 90 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 91 https://docs.snowflake.com/en/sql-reference/constructs/qualify 92 93 Some dialects don't support window functions in the WHERE clause, so we need to include them as 94 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 95 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 96 otherwise we won't be able to refer to it in the outer query's WHERE clause. 
97 """ 98 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 99 taken = set(expression.named_selects) 100 for select in expression.selects: 101 if not select.alias_or_name: 102 alias = find_new_name(taken, "_c") 103 select.replace(exp.alias_(select, alias)) 104 taken.add(alias) 105 106 outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) 107 qualify_filters = expression.args["qualify"].pop().this 108 109 select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) 110 for expr in qualify_filters.find_all(select_candidates): 111 if isinstance(expr, exp.Window): 112 alias = find_new_name(expression.named_selects, "_w") 113 expression.select(exp.alias_(expr, alias), copy=False) 114 column = exp.column(alias) 115 116 if isinstance(expr.parent, exp.Qualify): 117 qualify_filters = column 118 else: 119 expr.replace(column) 120 elif expr.name not in expression.named_selects: 121 expression.select(expr.copy(), copy=False) 122 123 return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where( 124 qualify_filters, copy=False 125 ) 126 127 return expression 128 129 130def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 131 """ 132 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 133 other expressions. This transforms removes the precision from parameterized types in expressions. 
134 """ 135 for node in expression.find_all(exp.DataType): 136 node.set( 137 "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] 138 ) 139 140 return expression 141 142 143def unqualify_unnest(expression: exp.Expression) -> exp.Expression: 144 """Remove references to unnest table aliases, added by the optimizer's qualify_columns step.""" 145 from sqlglot.optimizer.scope import find_all_in_scope 146 147 if isinstance(expression, exp.Select): 148 unnest_aliases = { 149 unnest.alias 150 for unnest in find_all_in_scope(expression, exp.Unnest) 151 if isinstance(unnest.parent, (exp.From, exp.Join)) 152 } 153 if unnest_aliases: 154 for column in expression.find_all(exp.Column): 155 if column.table in unnest_aliases: 156 column.set("table", None) 157 elif column.db in unnest_aliases: 158 column.set("db", None) 159 160 return expression 161 162 163def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 164 """Convert cross join unnest into lateral view explode.""" 165 if isinstance(expression, exp.Select): 166 for join in expression.args.get("joins") or []: 167 unnest = join.this 168 169 if isinstance(unnest, exp.Unnest): 170 alias = unnest.args.get("alias") 171 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 172 173 expression.args["joins"].remove(join) 174 175 for e, column in zip(unnest.expressions, alias.columns if alias else []): 176 expression.append( 177 "laterals", 178 exp.Lateral( 179 this=udtf(this=e), 180 view=True, 181 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 182 ), 183 ) 184 185 return expression 186 187 188def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: 189 """Convert explode/posexplode into unnest.""" 190 191 def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: 192 if isinstance(expression, exp.Select): 193 from sqlglot.optimizer.scope import Scope 194 195 taken_select_names = 
set(expression.named_selects) 196 taken_source_names = {name for name, _ in Scope(expression).references} 197 198 def new_name(names: t.Set[str], name: str) -> str: 199 name = find_new_name(names, name) 200 names.add(name) 201 return name 202 203 arrays: t.List[exp.Condition] = [] 204 series_alias = new_name(taken_select_names, "pos") 205 series = exp.alias_( 206 exp.Unnest( 207 expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] 208 ), 209 new_name(taken_source_names, "_u"), 210 table=[series_alias], 211 ) 212 213 # we use list here because expression.selects is mutated inside the loop 214 for select in list(expression.selects): 215 explode = select.find(exp.Explode) 216 217 if explode: 218 pos_alias = "" 219 explode_alias = "" 220 221 if isinstance(select, exp.Alias): 222 explode_alias = select.args["alias"] 223 alias = select 224 elif isinstance(select, exp.Aliases): 225 pos_alias = select.aliases[0] 226 explode_alias = select.aliases[1] 227 alias = select.replace(exp.alias_(select.this, "", copy=False)) 228 else: 229 alias = select.replace(exp.alias_(select, "")) 230 explode = alias.find(exp.Explode) 231 assert explode 232 233 is_posexplode = isinstance(explode, exp.Posexplode) 234 explode_arg = explode.this 235 236 if isinstance(explode, exp.ExplodeOuter): 237 bracket = explode_arg[0] 238 bracket.set("safe", True) 239 bracket.set("offset", True) 240 explode_arg = exp.func( 241 "IF", 242 exp.func( 243 "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array()) 244 ).eq(0), 245 exp.array(bracket, copy=False), 246 explode_arg, 247 ) 248 249 # This ensures that we won't use [POS]EXPLODE's argument as a new selection 250 if isinstance(explode_arg, exp.Column): 251 taken_select_names.add(explode_arg.output_name) 252 253 unnest_source_alias = new_name(taken_source_names, "_u") 254 255 if not explode_alias: 256 explode_alias = new_name(taken_select_names, "col") 257 258 if is_posexplode: 259 pos_alias = new_name(taken_select_names, "pos") 260 261 
if not pos_alias: 262 pos_alias = new_name(taken_select_names, "pos") 263 264 alias.set("alias", exp.to_identifier(explode_alias)) 265 266 series_table_alias = series.args["alias"].this 267 column = exp.If( 268 this=exp.column(series_alias, table=series_table_alias).eq( 269 exp.column(pos_alias, table=unnest_source_alias) 270 ), 271 true=exp.column(explode_alias, table=unnest_source_alias), 272 ) 273 274 explode.replace(column) 275 276 if is_posexplode: 277 expressions = expression.expressions 278 expressions.insert( 279 expressions.index(alias) + 1, 280 exp.If( 281 this=exp.column(series_alias, table=series_table_alias).eq( 282 exp.column(pos_alias, table=unnest_source_alias) 283 ), 284 true=exp.column(pos_alias, table=unnest_source_alias), 285 ).as_(pos_alias), 286 ) 287 expression.set("expressions", expressions) 288 289 if not arrays: 290 if expression.args.get("from"): 291 expression.join(series, copy=False, join_type="CROSS") 292 else: 293 expression.from_(series, copy=False) 294 295 size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) 296 arrays.append(size) 297 298 # trino doesn't support left join unnest with on conditions 299 # if it did, this would be much simpler 300 expression.join( 301 exp.alias_( 302 exp.Unnest( 303 expressions=[explode_arg.copy()], 304 offset=exp.to_identifier(pos_alias), 305 ), 306 unnest_source_alias, 307 table=[explode_alias], 308 ), 309 join_type="CROSS", 310 copy=False, 311 ) 312 313 if index_offset != 1: 314 size = size - 1 315 316 expression.where( 317 exp.column(series_alias, table=series_table_alias) 318 .eq(exp.column(pos_alias, table=unnest_source_alias)) 319 .or_( 320 (exp.column(series_alias, table=series_table_alias) > size).and_( 321 exp.column(pos_alias, table=unnest_source_alias).eq(size) 322 ) 323 ), 324 copy=False, 325 ) 326 327 if arrays: 328 end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) 329 330 if index_offset != 1: 331 end = end - (1 - index_offset) 332 
series.expressions[0].set("end", end) 333 334 return expression 335 336 return _explode_to_unnest 337 338 339PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) 340 341 342def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 343 """Transforms percentiles by adding a WITHIN GROUP clause to them.""" 344 if ( 345 isinstance(expression, PERCENTILES) 346 and not isinstance(expression.parent, exp.WithinGroup) 347 and expression.expression 348 ): 349 column = expression.this.pop() 350 expression.set("this", expression.expression.pop()) 351 order = exp.Order(expressions=[exp.Ordered(this=column)]) 352 expression = exp.WithinGroup(this=expression, expression=order) 353 354 return expression 355 356 357def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 358 """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.""" 359 if ( 360 isinstance(expression, exp.WithinGroup) 361 and isinstance(expression.this, PERCENTILES) 362 and isinstance(expression.expression, exp.Order) 363 ): 364 quantile = expression.this.this 365 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 366 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 367 368 return expression 369 370 371def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 372 """Uses projection output names in recursive CTE definitions to define the CTEs' columns.""" 373 if isinstance(expression, exp.With) and expression.recursive: 374 next_name = name_sequence("_c_") 375 376 for cte in expression.expressions: 377 if not cte.args["alias"].columns: 378 query = cte.this 379 if isinstance(query, exp.Union): 380 query = query.this 381 382 cte.args["alias"].set( 383 "columns", 384 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 385 ) 386 387 return expression 388 389 390def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 391 
"""Replace 'epoch' in casts by the equivalent date literal.""" 392 if ( 393 isinstance(expression, (exp.Cast, exp.TryCast)) 394 and expression.name.lower() == "epoch" 395 and expression.to.this in exp.DataType.TEMPORAL_TYPES 396 ): 397 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 398 399 return expression 400 401 402def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: 403 """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" 404 if isinstance(expression, exp.Select): 405 for join in expression.args.get("joins") or []: 406 on = join.args.get("on") 407 if on and join.kind in ("SEMI", "ANTI"): 408 subquery = exp.select("1").from_(join.this).where(on) 409 exists = exp.Exists(this=subquery) 410 if join.kind == "ANTI": 411 exists = exists.not_(copy=False) 412 413 join.pop() 414 expression.where(exists, copy=False) 415 416 return expression 417 418 419def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: 420 """ 421 Converts a query with a FULL OUTER join to a union of identical queries that 422 use LEFT/RIGHT OUTER joins instead. This transformation currently only works 423 for queries that have a single FULL OUTER join. 424 """ 425 if isinstance(expression, exp.Select): 426 full_outer_joins = [ 427 (index, join) 428 for index, join in enumerate(expression.args.get("joins") or []) 429 if join.side == "FULL" 430 ] 431 432 if len(full_outer_joins) == 1: 433 expression_copy = expression.copy() 434 expression.set("limit", None) 435 index, full_outer_join = full_outer_joins[0] 436 full_outer_join.set("side", "left") 437 expression_copy.args["joins"][index].set("side", "right") 438 expression_copy.args.pop("with", None) # remove CTEs from RIGHT side 439 440 return exp.union(expression, expression_copy, copy=False) 441 442 return expression 443 444 445def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: 446 """ 447 Some dialects (e.g. 
Hive, T-SQL, Spark prior to version 3) only allow CTEs to be 448 defined at the top-level, so for example queries like: 449 450 SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq 451 452 are invalid in those dialects. This transformation can be used to ensure all CTEs are 453 moved to the top level so that the final SQL code is valid from a syntax standpoint. 454 455 TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 456 """ 457 top_level_with = expression.args.get("with") 458 for node in expression.find_all(exp.With): 459 if node.parent is expression: 460 continue 461 462 inner_with = node.pop() 463 if not top_level_with: 464 top_level_with = inner_with 465 expression.set("with", top_level_with) 466 else: 467 if inner_with.recursive: 468 top_level_with.set("recursive", True) 469 470 top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions) 471 472 return expression 473 474 475def ensure_bools(expression: exp.Expression) -> exp.Expression: 476 """Converts numeric values used in conditions into explicit boolean expressions.""" 477 from sqlglot.optimizer.canonicalize import ensure_bools 478 479 def _ensure_bool(node: exp.Expression) -> None: 480 if ( 481 node.is_number 482 or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) 483 or (isinstance(node, exp.Column) and not node.type) 484 ): 485 node.replace(node.neq(0)) 486 487 for node in expression.walk(): 488 ensure_bools(node, _ensure_bool) 489 490 return expression 491 492 493def unqualify_columns(expression: exp.Expression) -> exp.Expression: 494 for column in expression.find_all(exp.Column): 495 # We only wanna pop off the table, db, catalog args 496 for part in column.parts[:-1]: 497 part.pop() 498 499 return expression 500 501 502def remove_unique_constraints(expression: exp.Expression) -> exp.Expression: 503 assert isinstance(expression, exp.Create) 504 for constraint in expression.find_all(exp.UniqueColumnConstraint): 
505 if constraint.parent: 506 constraint.parent.pop() 507 508 return expression 509 510 511def ctas_with_tmp_tables_to_create_tmp_view( 512 expression: exp.Expression, 513 tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, 514) -> exp.Expression: 515 assert isinstance(expression, exp.Create) 516 properties = expression.args.get("properties") 517 temporary = any( 518 isinstance(prop, exp.TemporaryProperty) 519 for prop in (properties.expressions if properties else []) 520 ) 521 522 # CTAS with temp tables map to CREATE TEMPORARY VIEW 523 if expression.kind == "TABLE" and temporary: 524 if expression.expression: 525 return exp.Create( 526 kind="TEMPORARY VIEW", 527 this=expression.this, 528 expression=expression.expression, 529 ) 530 return tmp_storage_provider(expression) 531 532 return expression 533 534 535def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: 536 """ 537 In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the 538 PARTITIONED BY value is an array of column names, they are transformed into a schema. 539 The corresponding columns are removed from the create statement. 
540 """ 541 assert isinstance(expression, exp.Create) 542 has_schema = isinstance(expression.this, exp.Schema) 543 is_partitionable = expression.kind in {"TABLE", "VIEW"} 544 545 if has_schema and is_partitionable: 546 prop = expression.find(exp.PartitionedByProperty) 547 if prop and prop.this and not isinstance(prop.this, exp.Schema): 548 schema = expression.this 549 columns = {v.name.upper() for v in prop.this.expressions} 550 partitions = [col for col in schema.expressions if col.name.upper() in columns] 551 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 552 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 553 expression.set("this", schema) 554 555 return expression 556 557 558def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: 559 """ 560 Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. 561 562 Currently, SQLGlot uses the DATASOURCE format for Spark 3. 563 """ 564 assert isinstance(expression, exp.Create) 565 prop = expression.find(exp.PartitionedByProperty) 566 if ( 567 prop 568 and prop.this 569 and isinstance(prop.this, exp.Schema) 570 and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) 571 ): 572 prop_this = exp.Tuple( 573 expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] 574 ) 575 schema = expression.this 576 for e in prop.this.expressions: 577 schema.append("expressions", e) 578 prop.set("this", prop_this) 579 580 return expression 581 582 583def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: 584 """Converts struct arguments to aliases, e.g. 
STRUCT(1 AS y).""" 585 if isinstance(expression, exp.Struct): 586 expression.set( 587 "expressions", 588 [ 589 exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e 590 for e in expression.expressions 591 ], 592 ) 593 594 return expression 595 596 597def preprocess( 598 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 599) -> t.Callable[[Generator, exp.Expression], str]: 600 """ 601 Creates a new transform by chaining a sequence of transformations and converts the resulting 602 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 603 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 604 605 Args: 606 transforms: sequence of transform functions. These will be called in order. 607 608 Returns: 609 Function that can be used as a generator transform. 610 """ 611 612 def _to_sql(self, expression: exp.Expression) -> str: 613 expression_type = type(expression) 614 615 expression = transforms[0](expression) 616 for transform in transforms[1:]: 617 expression = transform(expression) 618 619 _sql_handler = getattr(self, expression.key + "_sql", None) 620 if _sql_handler: 621 return _sql_handler(expression) 622 623 transforms_handler = self.TRANSFORMS.get(type(expression)) 624 if transforms_handler: 625 if expression_type is type(expression): 626 if isinstance(expression, exp.Func): 627 return self.function_fallback_sql(expression) 628 629 # Ensures we don't enter an infinite loop. This can happen when the original expression 630 # has the same type as the final expression and there's no _sql method available for it, 631 # because then it'd re-enter _to_sql. 632 raise ValueError( 633 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 634 ) 635 636 return transforms_handler(self, expression) 637 638 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 639 640 return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to SELECT aliases for the ordinal position of the projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # 1-based position of every aliased projection, keyed by its alias.
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for term in expression.expressions:
        # Only unqualified column references can point at a SELECT alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        if term.name in positions:
            term.replace(exp.Literal.number(positions.get(term.name)))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery filtered on a ROW_NUMBER() window.

    Useful for dialects that lack SELECT DISTINCT ON but do support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not (on and isinstance(on, exp.Tuple)):
        return expression

    partition_cols = distinct.pop().args["on"].expressions
    # Snapshot of the projections taken before the ROW_NUMBER column is appended.
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    window = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # Reuse the query's ORDER BY when there is one; otherwise order by the DISTINCT ON columns.
    existing_order = expression.args.get("order")
    if existing_order:
        window_order = existing_order.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    window.set("order", window_order)

    expression.select(exp.alias_(window, row_number_alias), copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT that has a QUALIFY clause as an equivalently filtered subquery.

    The approach follows Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Window functions used in QUALIFY are added to the inner query as aliased projections,
    because some dialects reject window functions in WHERE; the outer filter then refers to
    them by alias. Columns referenced in QUALIFY but not selected are projected as well, so
    the outer WHERE clause can see them.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    # Give every unnamed projection an alias so the outer query can select it by name.
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if select_name := projection.alias_or_name:
            continue
        alias = find_new_name(used_names, "_c")
        projection.replace(exp.alias_(projection, alias))
        used_names.add(alias)

    outer_selects = exp.select(*(projection.alias_or_name for projection in expression.selects))
    qualify_filters = expression.args["qualify"].pop().this

    if expression.is_star:
        select_candidates = exp.Window
    else:
        select_candidates = (exp.Window, exp.Column)

    for candidate in qualify_filters.find_all(select_candidates):
        if not isinstance(candidate, exp.Window):
            # Plain column: project it in the subquery unless it is already selected.
            if candidate.name not in expression.named_selects:
                expression.select(candidate.copy(), copy=False)
            continue

        alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(candidate, alias), copy=False)
        reference = exp.column(alias)

        if isinstance(candidate.parent, exp.Qualify):
            # The window was the entire QUALIFY condition.
            qualify_filters = reference
        else:
            candidate.replace(reference)

    subquery = expression.subquery(alias="_t", copy=False)
    return outer_selects.from_(subquery, copy=False).where(qualify_filters, copy=False)
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop the precision parameters from every parameterized data type in the expression.

    Some dialects only accept the precision of parameterized types in DDL, not in other
    expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if isinstance(param, exp.DataTypeParam):
                continue
            kept.append(param)

        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Strip column qualifiers that refer to unnest table aliases (added by the optimizer's qualify_columns step)."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of the UNNESTs that act as sources, i.e. sit directly under FROM or a JOIN.
    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of the SELECT whose target is an UNNEST is removed and one LATERAL VIEW is
    appended per (unnested expression, alias column) pair. POSEXPLODE is used instead of
    EXPLODE when the UNNEST carries an offset.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: the joins list is mutated below, and removing items from
        # a list while iterating over it makes the loop skip the element that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE: an UNNEST without an alias yields no pairs here, so its join is
                # removed without any lateral view being added.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series. It also decides
            whether the array sizes and the series end are adjusted (see the
            `index_offset != 1` branches below).

    Returns:
        A transform function that rewrites SELECT expressions in place and returns them.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use; new_name() below adds every name it hands out.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a name derived from `name` that is not yet in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded projection; used for the series end.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Position series source; its "end" argument is only set at the bottom.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Ensure the projection is wrapped in an Alias node and pick up any
                    # user-provided alias names.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # Look the explode up again inside the newly created alias node.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg).
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # IF(<series pos> = <unnest pos>, <unnested value>) takes the explode's place.
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Insert the position projection right after the exploded value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached once, when the first explode is found.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `arrays` keeps the unadjusted size; only the filter below uses size - 1.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where both positions match, or where the series position is
                    # past `size` and the unnest position equals `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest of the collected array sizes.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a percentile expression in a WITHIN GROUP clause.

    The rewrite applies to instances of the PERCENTILES types that carry a second argument
    and are not already the child of an `exp.WithinGroup`. The first argument becomes the
    ORDER BY key of the new clause and the second one becomes the function's only argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new `exp.WithinGroup` node when the rewrite applies, else the input expression.
    """
    if not isinstance(expression, PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    order_key = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=order_key)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a percentile's WITHIN GROUP clause with an `exp.ApproxQuantile` call.

    The rewrite applies to `exp.WithinGroup` nodes that wrap one of the PERCENTILES types
    and whose clause is an `exp.Order`. The first ordered term becomes the input value and
    the percentile's argument becomes the quantile.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The result of replacing the node when the rewrite applies, else the input expression.
    """
    applies = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not applies:
        return expression

    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=first_ordered.this, quantile=expression.this.this)
    return expression.replace(approx_quantile)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already declares columns are skipped. For the rest, the column names
    come from the projections of the CTE's query (or of the left-hand side, if the query
    is a union). Projections without an output name get one from a `_c_` name sequence.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a recursive `exp.With`.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        anchor = body.this if isinstance(body, exp.Union) else body

        column_names = []
        for projection in anchor.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or fresh_name()))

        table_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts to temporal types with the equivalent date literal.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression. If it is a (try) cast of 'epoch' (case-insensitive) to one of
        `exp.DataType.TEMPORAL_TYPES`, its operand is replaced by '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI/ANTI join that has an ON condition is removed from the SELECT and replaced
    by a `[NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)` predicate
    that is added to the WHERE clause. Joins of other kinds, or without an ON condition,
    are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are popped inside the loop, and removing items
        # from a list that is being iterated can skip the element after each removal.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query that has exactly one FULL OUTER join as a union of two queries, one
    using a LEFT join and the other a RIGHT join in its place.

    Queries with zero or several FULL OUTER joins are returned unchanged. The original
    query becomes the left branch and loses its LIMIT. A copy taken before that change
    becomes the right branch and loses its CTEs.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of both branches when the rewrite applies, else the input expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, candidate in enumerate(expression.args.get("joins") or []):
        if candidate.side == "FULL":
            full_joins.append((position, candidate))

    if len(full_joins) != 1:
        return expression

    # The copy has to be taken before the original is modified
    right_branch = expression.copy()
    expression.set("limit", None)

    position, left_join = full_joins[0]
    left_join.set("side", "left")
    right_branch.args["joins"][position].set("side", "right")
    right_branch.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_branch, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the top-level WITH clause of `expression`.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only accept CTEs at the
    top level, which makes queries such as

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    invalid there. Moving all CTEs to the top level keeps the generated SQL valid from
    a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    hoisted_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        if with_node.parent is not expression:
            nested_with = with_node.pop()

            if hoisted_with:
                if nested_with.recursive:
                    hoisted_with.set("recursive", True)

                # Nested CTEs are placed ahead of the ones gathered so far
                merged_ctes = nested_with.expressions + hoisted_with.expressions
                hoisted_with.set("expressions", merged_ctes)
            else:
                # No top-level WITH yet: the first nested one takes that role as-is
                hoisted_with = nested_with
                expression.set("with", hoisted_with)

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the canonicalizer's `ensure_bools` along with a
    callback. The callback replaces a node with `node <> 0` when the node is a number,
    is of an unknown or numeric type, or is a column without type information.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        numeric_like = node.is_number or node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        )
        if numeric_like or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for visited in expression.walk():
        _canonicalize_ensure_bools(visited, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TABLE statements that carry a temporary property to CREATE TEMPORARY VIEW.

    Args:
        expression: the `exp.Create` expression that will be transformed.
        tmp_storage_provider: called with the expression when it creates a temporary table
            that has no defining query. Its result is returned. Defaults to the identity.

    Returns:
        A new `exp.Create` of kind "TEMPORARY VIEW" for temporary CTAS statements, the
        provider's result for other temporary tables, or the input expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    declared_props = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in declared_props)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move partition columns from the table schema into the PARTITIONED BY property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When
    the property's value is a list of column names rather than a schema, the schema
    columns with those names (compared case-insensitively) are removed from the create
    statement and become the property's schema.

    Args:
        expression: the `exp.Create` expression that will be transformed.

    Returns:
        The same expression, modified in place when the rewrite applies.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {item.name.upper() for item in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the table's schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made only of column definitions that
    have a kind, those definitions are appended to the create statement's schema and the
    property is reduced to a tuple of the column identifiers.

    Args:
        expression: the `exp.Create` expression that will be transformed.

    Returns:
        The same expression, modified in place when the rewrite applies.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(item, exp.ColumnDef) and item.kind for item in partition_defs):
        return expression

    # Build the identifier tuple first; the property is only overwritten at the very end
    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(item.this) for item in partition_defs]
    )

    schema = expression.this
    for item in partition_defs:
        schema.append("expressions", item)

    prop.set("this", partition_names)
    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite key-value struct arguments as aliases, e.g. STRUCT(1 AS y).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression. If it is an `exp.Struct`, each `exp.PropertyEQ` argument is
        replaced by an alias of its value named after its key. Other arguments are kept.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is accepted and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has
            neither a "_sql" method nor a `Generator.TRANSFORMS` entry, or when the only
            available handler would re-enter this transform for a non-function expression.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input type to detect the "same type in, same type out" case below
        expression_type = type(expression)

        # A plain loop also covers an empty `transforms` list (no-op instead of IndexError)
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator method named after the resulting expression's key
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS
function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.