Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.errors import UnsupportedError
  7from sqlglot.helper import find_new_name, name_sequence
  8
  9
 10if t.TYPE_CHECKING:
 11    from sqlglot._typing import E
 12    from sqlglot.generator import Generator
 13
 14
 15def preprocess(
 16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 17) -> t.Callable[[Generator, exp.Expression], str]:
 18    """
 19    Creates a new transform by chaining a sequence of transformations and converts the resulting
 20    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 21    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 22
 23    Args:
 24        transforms: sequence of transform functions. These will be called in order.
 25
 26    Returns:
 27        Function that can be used as a generator transform.
 28    """
 29
 30    def _to_sql(self, expression: exp.Expression) -> str:
 31        expression_type = type(expression)
 32
 33        try:
 34            expression = transforms[0](expression)
 35            for transform in transforms[1:]:
 36                expression = transform(expression)
 37        except UnsupportedError as unsupported_error:
 38            self.unsupported(str(unsupported_error))
 39
 40        _sql_handler = getattr(self, expression.key + "_sql", None)
 41        if _sql_handler:
 42            return _sql_handler(expression)
 43
 44        transforms_handler = self.TRANSFORMS.get(type(expression))
 45        if transforms_handler:
 46            if expression_type is type(expression):
 47                if isinstance(expression, exp.Func):
 48                    return self.function_fallback_sql(expression)
 49
 50                # Ensures we don't enter an infinite loop. This can happen when the original expression
 51                # has the same type as the final expression and there's no _sql method available for it,
 52                # because then it'd re-enter _to_sql.
 53                raise ValueError(
 54                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 55                )
 56
 57            return transforms_handler(self, expression)
 58
 59        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 60
 61    return _to_sql
 62
 63
 64def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 65    if isinstance(expression, exp.Select):
 66        count = 0
 67        recursive_ctes = []
 68
 69        for unnest in expression.find_all(exp.Unnest):
 70            if (
 71                not isinstance(unnest.parent, (exp.From, exp.Join))
 72                or len(unnest.expressions) != 1
 73                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 74            ):
 75                continue
 76
 77            generate_date_array = unnest.expressions[0]
 78            start = generate_date_array.args.get("start")
 79            end = generate_date_array.args.get("end")
 80            step = generate_date_array.args.get("step")
 81
 82            if not start or not end or not isinstance(step, exp.Interval):
 83                continue
 84
 85            alias = unnest.args.get("alias")
 86            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 87
 88            start = exp.cast(start, "date")
 89            date_add = exp.func(
 90                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 91            )
 92            cast_date_add = exp.cast(date_add, "date")
 93
 94            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 95
 96            base_query = exp.select(start.as_(column_name))
 97            recursive_query = (
 98                exp.select(cast_date_add)
 99                .from_(cte_name)
100                .where(cast_date_add <= exp.cast(end, "date"))
101            )
102            cte_query = base_query.union(recursive_query, distinct=False)
103
104            generate_dates_query = exp.select(column_name).from_(cte_name)
105            unnest.replace(generate_dates_query.subquery(cte_name))
106
107            recursive_ctes.append(
108                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
109            )
110            count += 1
111
112        if recursive_ctes:
113            with_expression = expression.args.get("with") or exp.With()
114            with_expression.set("recursive", True)
115            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
116            expression.set("with", with_expression)
117
118    return expression
119
120
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a GENERATE_SERIES table reference in an UNNEST.

    Args:
        expression: the expression that will be transformed.

    Returns:
        An UNNEST of the series (aliased as `_u` with the table's alias as its single
        column when the table is aliased), or the input expression if it isn't a table
        whose `this` is a GENERATE_SERIES.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnested

    # The original table alias becomes the name of the column produced by the UNNEST.
    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)
132
133
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the ordinal position of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map each projection alias to its 1-based position in the SELECT list.
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for grouped in expression.expressions:
        # Only unqualified columns can refer to a projection alias.
        if not isinstance(grouped, exp.Column) or grouped.table:
            continue

        name = grouped.name
        if name in position_by_alias:
            grouped.replace(exp.Literal.number(position_by_alias[name]))

    return expression
165
166
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery that is filtered on a ROW_NUMBER() window.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach the DISTINCT clause; its ON columns become the window's partition.
    distinct.pop()
    partition_cols = on.expressions

    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    ranked = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # Reuse the query's ORDER BY for the window if there is one, otherwise order by
    # (copies of) the DISTINCT ON columns.
    existing_order = expression.args.get("order")
    if existing_order:
        window_order = existing_order.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    ranked.set("order", window_order)

    expression.select(exp.alias_(ranked, row_number_alias), copy=False)

    # Keep only the first row of every partition.
    inner = expression.subquery("_t", copy=False)
    outer = exp.select("*", copy=False).from_(inner, copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)
205
206
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (wrapped as subquery `_t`) when the input
        is a SELECT with a QUALIFY clause, otherwise the input expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c"-based alias so the outer query can
        # refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer projection for `select`: a column that carries over the
            # identifier's "quoted" flag when an identifier is available, else the bare name.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are captured before any helper projection is added below,
        # so the helpers do not show up in the final result.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection only windows are looked for; otherwise columns are too.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query under a fresh "_w"-based alias and
                # refer to it by that alias in the filter.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window's parent is the (already popped) QUALIFY node itself, i.e. the
                    # window is the whole filter, so the filter becomes the alias column.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A filtered column that isn't projected yet: add a copy to the inner query.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
269
270
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip type parameters (e.g. precision) from every data type found in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions, so this transform drops all `DataTypeParam` children of the data
    types it finds.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)

        data_type.set("expressions", kept)

    return expression
282
283
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Collect the aliases of UNNESTs that are used as FROM / JOIN sources in this scope.
    source_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            source_aliases.add(unnest.alias)

    if not source_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        # The alias can appear either in the table or in the db part of the column.
        if column.table in source_aliases:
            column.set("table", None)
        elif column.db in source_aliases:
            column.set("db", None)

    return expression
302
303
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table wrapping the matching UDTF, and
    UNNESTs used as join sources (optionally wrapped in a LATERAL) are removed from the
    joins and re-added to the SELECT's "laterals".

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: whether an UNNEST with several input arrays may be rewritten
            using ARRAYS_ZIP; when False such an UNNEST raises an UnsupportedError.

    Returns:
        The transformed expression (mutated in place when it is a SELECT).

    Raises:
        UnsupportedError: for a multi-array UNNEST when `unnest_using_arrays_zip` is False, or
            for a single-array joined UNNEST that doesn't have exactly 1 or 2 column aliases.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapse several UNNEST arguments into a single ARRAYS_ZIP(...) call (also stored
        # back on `u`); a single argument is returned untouched.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Pick the table function class: an offset takes precedence (Posexplode), then
        # multiple arrays (Inline), otherwise Explode.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: the UNNEST is the FROM source itself.
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            # Recorded up front because _unnest_zip_exprs may rewrite unnest.expressions.
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: UNNESTs used as join sources. Iterate over a copy since matching joins
        # are removed from `joins` inside the loop.
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For a LATERAL the alias is read from the LATERAL node, not from the UNNEST.
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                # zip() stops at the shorter input, so one lateral is added per (expression,
                # alias column) pair; `column` itself is not used in the loop body.
                # NOTE(review): with multiple arrays and no alias columns nothing is appended
                # even though the join was already removed above -- confirm this is intended.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
395
396
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the first value of the generated position series; it also adjusts the
            array-size based bounds computed below.

    Returns:
        A transform function that rewrites the explode projections of SELECT expressions
        and returns any other expression unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a name based on `name` that isn't in `names` yet.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Sizes of the exploded arrays; also used to tell whether anything was rewritten.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # A position series shared by all explodes; it starts at `index_offset` and its
            # "end" is only filled in at the bottom, once all array sizes are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node (`alias`) and pick up any
                    # user-provided names for the exploded value / position.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # The projection was replaced above, so look the explode up again
                        # in the new node.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Replace the argument with IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0,
                        # ARRAY(<arg[0] with safe/offset set>), arg).
                        # NOTE(review): presumably this keeps a row for empty / NULL arrays -- confirm.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever names the user didn't provide.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes: IF(series.pos = unnest.pos, unnest.col), i.e. the
                    # value is only produced when the series position matches the unnest's.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position: insert it as an extra
                        # projection right after the exploded value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The shared series source is attached once, for the first explode found.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Note that `arrays` above received the unadjusted size; only the bound
                    # used in the WHERE clause below is shifted.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series has moved past
                    # `size` while the unnest position equals `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must span the largest of the exploded arrays.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
546
547
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A WITHIN GROUP node wrapping the percentile when the input is a two-argument
        percentile that isn't already inside a WITHIN GROUP, otherwise the input expression.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # The first argument moves into the ORDER BY, the second one takes its place.
    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
561
562
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        An APPROX_QUANTILE node that replaced the WITHIN GROUP in the tree, or the input
        expression if it isn't a percentile's WITHIN GROUP with an ORDER BY.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES) or not isinstance(
        expression.expression, exp.Order
    ):
        return expression

    quantile = percentile.this
    # The value being aggregated is the first ordered term found under the WITHIN GROUP.
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(replacement)
575
576
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # A single generator is shared by all CTEs, so generated names never repeat.
    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        # For set operations the column names come from the left-hand query.
        anchor = cte.this
        if isinstance(anchor, exp.SetOperation):
            anchor = anchor.this

        column_names = []
        for projection in anchor.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        cte_alias.set("columns", column_names)

    return expression
594
595
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        # Only the cast's operand changes; the cast and its target type are kept.
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
606
607
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.

    Every SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced
    by an `EXISTS (SELECT 1 FROM <join source> WHERE <on>)` predicate (negated for ANTI
    joins) that is added to the WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (mutated in place when it is a SELECT).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: `join.pop()` detaches the join from the
        # SELECT, and shrinking the list that is being iterated can make the loop skip
        # the join that follows a removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
623
624
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A UNION ALL of the LEFT-join and RIGHT-join variants when the input is a SELECT with
        exactly one FULL join, otherwise the input expression.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # The copy is taken before the original is modified, so it still carries the
            # original's LIMIT and FULL join; it becomes the RIGHT-join half of the union.
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Use the ON condition if present, otherwise build equalities from USING.
            # NOTE(review): a FULL join with neither ON nor USING would fail when iterating
            # over `args.get("using")` -- confirm such joins cannot reach this point.
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # The RIGHT-join half is restricted to rows with no match in the FROM source
            # (NOT EXISTS below), so rows produced by the LEFT-join half aren't repeated.
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
661
662
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with every nested WITH clause merged into its top-level one.
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # Skip the WITH clause that already belongs to the top-level expression.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: the first nested one found is promoted as a whole.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Looked up before popping, while the nested WITH is still attached to the tree.
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # The nested WITH was inside a CTE: insert its CTEs right before that CTE.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # Otherwise append them after the existing top-level CTEs.
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
700
701
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    # The optimizer helper has the same name as this transform, hence the local alias.
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _compare_to_zero(node: exp.Expression) -> None:
        # The checks run in a fixed order and stop at the first one that holds.
        needs_comparison = node.is_number

        if not needs_comparison and not isinstance(node, exp.SubqueryPredicate):
            needs_comparison = node.is_type(
                exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
            )

        if not needs_comparison and isinstance(node, exp.Column):
            needs_comparison = not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for visited in expression.walk():
        _canonicalize_bools(visited, _compare_to_zero)

    return expression
721
722
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip every qualifier from the columns found in the expression.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for column in expression.find_all(exp.Column):
        # Everything but the last part (the column name itself) is a qualifier,
        # i.e. the table, db and catalog args.
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
730
731
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parents of all unique column constraints found in a CREATE statement.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        # The node that wraps the constraint is what gets detached from the tree.
        wrapper = unique.parent
        if wrapper:
            wrapper.pop()

    return expression
739
740
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE ... AS statements to CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with the expression when it creates a temporary table
            without a defining query; its result is returned in that case.

    Returns:
        A new CREATE TEMPORARY VIEW for temporary CTAS statements, the provider's result for
        other temporary tables, and the input expression for everything else.
    """
    assert isinstance(expression, exp.Create)

    props = expression.args.get("properties")
    prop_nodes = props.expressions if props else []

    is_temporary = False
    for prop in prop_nodes:
        if isinstance(prop, exp.TemporaryProperty):
            is_temporary = True
            break

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    return tmp_storage_provider(expression)
763
764
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names (rather than a schema), the matching column
    definitions are moved out of the CREATE statement's schema and into the property.

    Column names are matched case-insensitively.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {name.name.upper() for name in prop.this.expressions}
    partition_defs = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_defs = [column for column in schema.expressions if column not in partition_defs]

    schema.set("expressions", remaining_defs)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_defs)))
    expression.set("this", schema)

    return expression
786
787
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3: when PARTITIONED BY holds a
    schema made exclusively of typed column definitions, these definitions are appended to
    the table's schema and the property is reduced to a tuple of the column names.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop:
        return expression

    partition_schema = prop.this
    if not partition_schema or not isinstance(partition_schema, exp.Schema):
        return expression

    column_defs = partition_schema.expressions
    if not all(isinstance(column, exp.ColumnDef) and column.kind for column in column_defs):
        return expression

    # The names must be collected before the definitions are re-attached to the table schema
    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column.this) for column in column_defs]
    )

    table_schema = expression.this
    for column in column_defs:
        table_schema.append("expressions", column)

    prop.set("this", partition_names)
    return expression
811
812
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Turn the key-value arguments of a struct into aliased values, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            # The key becomes the alias of the value
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression
825
826
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Each scope that has both a WHERE clause and joins is processed on its own: every predicate
    that contains a marked column is moved from the WHERE clause into the ON condition of a
    LEFT JOIN against the marked table. Scopes without any marked column are left untouched.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: if a marked column isn't part of a binary predicate, if both sides of
            a predicate are marked, if a marked column is unqualified, if the marked columns
            of a predicate belong to more than one table, or if no table is left to be used
            in the FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from the WHERE clause, it'll become (part of) an ON condition
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # The marked table is either one of the joined tables or the one in the FROM clause
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # No (+) marker was found in this scope, so there's nothing to rewrite. Overwriting the
        # joins below would otherwise discard all of the query's existing joins.
        if not new_joins:
            continue

        if query_from.alias_or_name in new_joins:
            # The table in the FROM clause is now LEFT JOIN-ed, so another one has to replace it
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        # All of the WHERE clause's predicates may have been moved into join conditions
        if not where.this:
            where.pop()

    return expression
927
928
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation, so ANY operators that wrap a query are left as they are.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_expr in expression.find_all(exp.Any):
        operand = any_expr.this
        comparison = any_expr.parent

        if isinstance(operand, exp.Query) or not isinstance(comparison, exp.Binary):
            continue

        # The ANY operator is swapped for the lambda's parameter inside of the comparison,
        # which then becomes the body of the lambda
        placeholder = exp.to_identifier("x")
        any_expr.replace(placeholder)
        predicate = exp.Lambda(this=comparison.copy(), expressions=[placeholder])
        comparison.replace(exp.Exists(this=operand.unnest(), expression=predicate))

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as is.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remembered so that re-entering this very function can be detected further down
        expression_type = type(expression)

        try:
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Generation carries on with the output of the last transform that succeeded
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 65def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 66    if isinstance(expression, exp.Select):
 67        count = 0
 68        recursive_ctes = []
 69
 70        for unnest in expression.find_all(exp.Unnest):
 71            if (
 72                not isinstance(unnest.parent, (exp.From, exp.Join))
 73                or len(unnest.expressions) != 1
 74                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 75            ):
 76                continue
 77
 78            generate_date_array = unnest.expressions[0]
 79            start = generate_date_array.args.get("start")
 80            end = generate_date_array.args.get("end")
 81            step = generate_date_array.args.get("step")
 82
 83            if not start or not end or not isinstance(step, exp.Interval):
 84                continue
 85
 86            alias = unnest.args.get("alias")
 87            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 88
 89            start = exp.cast(start, "date")
 90            date_add = exp.func(
 91                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 92            )
 93            cast_date_add = exp.cast(date_add, "date")
 94
 95            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 96
 97            base_query = exp.select(start.as_(column_name))
 98            recursive_query = (
 99                exp.select(cast_date_add)
100                .from_(cte_name)
101                .where(cast_date_add <= exp.cast(end, "date"))
102            )
103            cte_query = base_query.union(recursive_query, distinct=False)
104
105            generate_dates_query = exp.select(column_name).from_(cte_name)
106            unnest.replace(generate_dates_query.subquery(cte_name))
107
108            recursive_ctes.append(
109                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
110            )
111            count += 1
112
113        if recursive_ctes:
114            with_expression = expression.args.get("with") or exp.With()
115            with_expression.set("recursive", True)
116            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
117            expression.set("with", with_expression)
118
119    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])

    table_alias = expression.alias
    if not table_alias:
        return unnested

    # The table's alias becomes the name of the column produced by the UNNEST
    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses with the 1-based position of the
    corresponding projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # Maps each select alias to the position of its projection
    positions: t.Dict[str, int] = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for key in expression.expressions:
        # Only unqualified columns can refer to a select alias
        if not isinstance(key, exp.Column) or key.table:
            continue

        if key.name in positions:
            key.replace(exp.Literal.number(positions[key.name]))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The query is extended with a ROW_NUMBER() projection, partitioned by the DISTINCT ON columns,
    and wrapped into an outer query that only keeps the first row of each partition.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    distinct.pop()
    partition_cols = on.expressions

    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    window = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # The query's own ordering is moved into the window. Without one, the rows are ordered
    # by the DISTINCT ON columns themselves
    order = expression.args.get("order")
    if order:
        window_order = order.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    window.set("order", window_order)

    expression.select(exp.alias_(window, row_number_alias), copy=False)

    outer_query = exp.select("*", copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (aliased as `_t`), filtered by the former
        QUALIFY condition, or the unchanged expression if it isn't a SELECT with a QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give a generated "_c*" alias to every projection that has neither an alias nor a name,
        # so that the outer query can refer to it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Returns the outer projection for the given inner one. A column is built whenever
            # an identifier is available, in order to carry over its "quoted" flag
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed before any helper projection is added below,
        # so the latter are not exposed by the outer query
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # Columns are only looked for when the query doesn't already select everything
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # References to select aliases inside of the window are replaced by the
                    # aliased expressions
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # The window becomes a "_w*" projection of the inner query and the filter
                # refers to it by its alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window was the whole QUALIFY condition
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # The filter refers to a column that isn't selected, so it's added to the inner query
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        # Everything but the type's parameters is kept
        kept = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Only the UNNESTs that act as sources of this query are taken into account
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for col in expression.find_all(exp.Column):
        if col.table in unnest_aliases:
            col.set("table", None)
        elif col.db in unnest_aliases:
            col.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table that wraps the corresponding UDTF,
    whereas joined UNNESTs are removed from the joins and appended to the query as LATERAL VIEWs.

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: whether an UNNEST with multiple arguments can be transpiled
            by zipping them using ARRAYS_ZIP.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: if an UNNEST has multiple arguments and `unnest_using_arrays_zip` is
            False, or if a joined UNNEST with a single argument doesn't have one or two
            column aliases.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Returns the arguments of the UDTF that'll replace the UNNEST. Multiple arguments are
        # collapsed into a single ARRAYS_ZIP call, which also replaces them in the UNNEST itself
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # An offset takes precedence: POSEXPLODE is used whenever the UNNEST has one
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: the UNNEST is the source in the FROM clause
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: the UNNEST is joined, possibly wrapped in a LATERAL. A copy of the list is
        # iterated because joins are removed from it inside of the loop
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For lateral joins, the alias is attached to the LATERAL instead of the UNNEST
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                # NOTE(review): `column` is unused, zip only bounds the number of laterals by
                # the number of column aliases -- confirm that this is intended
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the number that the generated position series starts at.

    Returns:
        A transform that rewrites the [POS]EXPLODE projections of a SELECT. The query is
        joined with an UNNEST over a generated series of positions, plus one UNNEST (with
        offset) per exploded array, and filtered so that the positions line up.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # These keep track of the names in use, in order to avoid collisions
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks an unused name based on `name` and marks it as taken
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Collects the size of each exploded array
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")

            # The series of positions shared by all of the explodes. Its "end" argument is
            # only set after the loop, once the array sizes are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Make sure that the projection is a plain alias, reusing the names that
                    # were given to it, if any
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # The array is swapped for a single-element array when its size is 0
                        # (a NULL array is coalesced to an empty one for the size check)
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only produced for the row in which the series'
                    # position matches the position within the array
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position, which is projected right after
                        # the exploded value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series is added to the query only once, when the first explode is found
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep the rows where the positions match, as well as those where the
                    # series ran past this array and the array is at the position equal to `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series has to cover the largest of the exploded arrays
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wraps a two-argument percentile call in a WITHIN GROUP clause.

    The percentile's first argument becomes the ORDER BY operand of the new WITHIN GROUP
    node, and its second argument becomes the percentile's only argument. Percentiles that
    are already the child of a WITHIN GROUP node, or that have no second argument, are
    returned unchanged, as is any other expression.
    """
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_bare_percentile:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replaces `<percentile> WITHIN GROUP (ORDER BY <value>)` with an `ApproxQuantile` node.

    The ordered value becomes the quantile's input and the percentile's own argument
    becomes the `quantile` arg. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not (
        isinstance(percentile, exp.PERCENTILES) and isinstance(expression.expression, exp.Order)
    ):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)

    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Gives every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already lists columns are skipped. For the others, the column names
    are taken from the output names of the CTE query's projections (for a set operation,
    the projections of its left-hand query); projections without an output name receive
    generated `_c_`-prefixed names.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        alias = cte.args["alias"]
        if alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or fresh_name())
            for projection in query.selects
        ]
        alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites `CAST('epoch' AS <temporal type>)` so it casts an explicit timestamp literal.

    Applies to both CAST and TRY_CAST; the operand name is compared case-insensitively and
    is replaced in place by the string literal '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI/ANTI join of a SELECT that has an ON condition is removed from the query
    and replaced by a `WHERE [NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on>)`
    predicate. Joins of other kinds, and SEMI/ANTI joins without an ON condition, are
    left as they are.

    Args:
        expression: The expression to transform; only `exp.Select` nodes are affected.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: `join.pop()` below detaches the join from
        # this SELECT, and removing items from a list that is being iterated makes the
        # loop skip the element that follows each removed one (e.g. a second SEMI join).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites a SELECT containing exactly one FULL OUTER join as a UNION ALL of two queries.

    The left branch is the original query with the join turned into a LEFT join and with
    its LIMIT and ORDER BY removed. The right branch is a copy of the original query with
    the join turned into a RIGHT join, an extra `NOT EXISTS` filter built from the join
    condition, and its CTEs removed. Queries with zero or several FULL OUTER joins, and
    non-SELECT expressions, are returned unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, candidate)
        for position, candidate in enumerate(expression.args.get("joins") or [])
        if candidate.side == "FULL"
    ]
    if len(full_joins) != 1:
        return expression

    # The copy must be taken before the original is modified below
    right_query = expression.copy()
    expression.set("limit", None)
    position, full_join = full_joins[0]

    left_table = expression.args["from"].alias_or_name
    right_table = full_join.alias_or_name

    # Without an ON clause, the condition is derived from the USING columns
    on_condition = full_join.args.get("on") or exp.and_(
        *[
            exp.column(name, left_table).eq(exp.column(name, right_table))
            for name in full_join.args.get("using")
        ]
    )

    full_join.set("side", "left")
    unmatched_probe = exp.select("1").from_(expression.args["from"]).where(on_condition)
    right_query.args["joins"][position].set("side", "right")
    right_query = right_query.where(exp.Exists(this=unmatched_probe).not_())
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side
    expression.args.pop("order", None)  # remove order by from LEFT side

    return exp.union(expression, right_query, copy=False, distinct=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Hoists every nested WITH clause into the WITH clause of `expression` itself.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. If `expression` has no WITH clause yet, the first nested
    one found is moved up as-is. CTEs hoisted out of another CTE are inserted right before
    that CTE; all others are appended at the end. A recursive nested WITH marks the
    top-level WITH as recursive too.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    outer_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not outer_with:
            outer_with = nested_with.pop()
            expression.set("with", outer_with)
            continue

        if nested_with.recursive:
            outer_with.set("recursive", True)

        # Look up the enclosing CTE before detaching the nested WITH from the tree
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            merged = outer_with.expressions
            position = merged.index(enclosing_cte)
            merged[position:position] = nested_with.expressions
        else:
            merged = outer_with.expressions + nested_with.expressions

        outer_with.set("expressions", merged)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites numeric values used as conditions into explicit `<value> <> 0` comparisons.

    Every node of the tree is handed to the canonicalizer's `ensure_bools` together with a
    callback that performs the replacement on number literals, on nodes typed as UNKNOWN
    or as a numeric type (subquery predicates excluded), and on untyped columns.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        looks_numeric = node.is_number or (
            not isinstance(node, exp.SubqueryPredicate)
            and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )
        if looks_numeric or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for descendant in expression.walk():
        _canonicalize_bools(descendant, _to_explicit_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Removes every part but the last one from each column found in the expression."""
    # Only the leading parts (the table, db, catalog args) are detached
    qualifiers = (
        part for column in expression.find_all(exp.Column) for part in column.parts[:-1]
    )
    for qualifier in qualifiers:
        qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Removes, from a CREATE statement, the parent node of every unique column constraint.

    Raises an AssertionError if `expression` is not an `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique_constraint.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Turns `CREATE TEMPORARY TABLE ... AS <query>` into `CREATE TEMPORARY VIEW ... AS <query>`.

    A temporary table created without a query is passed to `tmp_storage_provider` and its
    result is returned (the default provider returns the statement as-is). Any other CREATE
    statement is returned unchanged. Raises an AssertionError if `expression` is not an
    `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if not (expression.kind == "TABLE" and is_temporary):
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    query = expression.expression
    if query:
        return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Moves partition columns out of a CREATE statement's schema into its PARTITIONED BY property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    property of a CREATE TABLE/VIEW with a schema is not itself a schema, the schema columns
    whose upper-cased names match the listed values are removed from the schema and become
    the property's new schema. Raises an AssertionError if `expression` is not an
    `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not (isinstance(schema, exp.Schema) and expression.kind in {"TABLE", "VIEW"}):
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {value.name.upper() for value in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Moves typed PARTITIONED BY column definitions into the CREATE statement's schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE, and
    SQLGlot currently uses the DATASOURCE format for Spark 3. When the PARTITIONED BY
    property holds a schema made only of column definitions that have a kind, those
    definitions are appended to the statement's schema and the property is reduced to a
    tuple of the column identifiers. Raises an AssertionError if `expression` is not an
    `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(
        isinstance(column_def, exp.ColumnDef) and column_def.kind for column_def in partition_defs
    ):
        return expression

    column_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )
    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)
    prop.set("this", column_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites the key-value (`PropertyEQ`) arguments of a struct as aliased values,
    e.g. STRUCT(1 AS y). Other struct arguments and non-struct expressions are kept as-is.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(field: exp.Expression) -> exp.Expression:
        if isinstance(field, exp.PropertyEQ):
            return exp.alias_(field.expression, field.this)
        return field

    expression.set("expressions", [_as_alias(field) for field in expression.expressions])

    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Scopes that have no WHERE clause, no joins, or no marked column are left untouched.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: If a marked column is not part of a binary predicate or is not
            qualified with a table, if the (+) marker appears on both sides of a predicate,
            if the marked columns of one predicate belong to more than one table, or if no
            table is left to be used in the new FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent first: once popped, the predicate no longer points to it
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                # Clearing the mark also makes the outer loop skip this column later on
                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column; all of them share one table at this point.
            # A table that is not among the joins is taken from the FROM clause instead.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # Nothing in this scope was marked: keep its joins as they are. Without this guard
        # the `query.set("joins", ...)` call below would replace them with an empty list.
        if not new_joins:
            continue

        if query_from.alias_or_name in new_joins:
            # The FROM table became the target of a LEFT JOIN, so another table that was
            # not turned into a new join has to take its place in the FROM clause
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # NOTE(review): only the joins rebuilt from marked predicates are written back, so
        # joins of this scope that no marked predicate referenced are not carried over --
        # confirm whether they should be preserved.
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if all of its predicates were moved into joins
        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, the query

    SELECT * FROM a, b WHERE a.id = b.id(+)

is converted to

    SELECT * FROM a LEFT JOIN b ON a.id = b.id

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites comparisons against `ANY(<array>)` into Spark's lambda-based EXISTS.

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Within a SELECT, each ANY node whose parent is a binary operation is replaced by the
    lambda parameter `x`, and the resulting comparison becomes the lambda body. Both ANY
    and EXISTS accept queries, but ANY over a query is skipped: only array expressions are
    currently supported by this transformation.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_node in expression.find_all(exp.Any):
        operand = any_node.this
        if isinstance(operand, exp.Query):
            continue

        comparison = any_node.parent
        if not isinstance(comparison, exp.Binary):
            continue

        placeholder = exp.to_identifier("x")
        any_node.replace(placeholder)
        predicate = exp.Lambda(this=comparison.copy(), expressions=[placeholder])
        comparison.replace(exp.Exists(this=operand.unnest(), expression=predicate))

    return expression

Transform ANY operator to Spark's EXISTS

For example:
  • Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
  • Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation