Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query gains an aliased ROW_NUMBER() window projection partitioned by the
    DISTINCT ON columns and is wrapped in a subquery aliased "_t"; the new outer query keeps
    only the rows whose row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression: a new outer SELECT when the input is a SELECT whose
        DISTINCT has a tuple-valued ON, otherwise the input expression itself.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach the DISTINCT node; its ON columns become the window's partition keys
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        # Captured before the row-number projection is added below
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The query's ORDER BY is moved into the window
            window.set("order", order.pop())
        else:
            # No ORDER BY: order the window by copies of the DISTINCT ON columns
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
 84
 85
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (as subquery "_t") when the input is a
        SELECT with a QUALIFY clause, otherwise the input expression itself.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c"-based alias so it can be selected by name
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer projections are fixed here, before helper projections are appended below
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        # Detach the QUALIFY clause and keep its condition for the outer WHERE
        qualify_filters = expression.args["qualify"].pop().this
        # Alias name -> underlying expression, for every aliased projection
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star queries only windows are collected; otherwise referenced columns are too
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections that are referenced inside the window
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh "_w"-based alias and refer to it by name
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window was the entire QUALIFY condition
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column used by the filter but not projected: add it to the subquery
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
141
142
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the type parameters (e.g. precision) from every data type found in the expression.

    Some dialects only accept parameterized types inside DDL, so the parameters are dropped
    from each `DataType` node reachable from `expression`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if isinstance(param, exp.DataTypeParam):
                continue
            kept.append(param)

        data_type.set("expressions", kept)

    return expression
154
155
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Drop column qualifiers that point at UNNEST aliases (added by the optimizer's
    qualify_columns step).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; SELECTs are mutated in place, anything else is untouched.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of the UNNESTs that act as sources (FROM / JOIN) of this SELECT.
    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for col in expression.find_all(exp.Column):
        if col.table in aliases:
            col.set("table", None)
        elif col.db in aliases:
            col.set("db", None)

    return expression
174
175
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed from the query's joins and
    replaced by one LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is
    used when the UNNEST has an offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; SELECTs are mutated in place, anything else is untouched.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: the "joins" list is mutated inside the loop, and removing
        # from the list being iterated would skip the join that follows each removed one.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): without an alias (or alias columns) the zip below is empty, so
                # the join is removed and no lateral view replaces it -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
199
200
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series; it also adjusts the
            array-size bounds below (presumably the target dialect's array index base --
            TODO confirm against callers).

    Returns:
        A transform function. For SELECT expressions it rewrites every projection that
        contains an EXPLODE into an IF over a cross-joined UNNEST, aligned on a shared
        generated position series; other expressions are returned untouched.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so generated aliases don't clash with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a name based on `name` that isn't in `names` yet
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Sizes of the exploded arrays; they determine the end of the series (set last)
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Position series shared by all explodes; its "end" is only known after the loop
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into a single Alias node, remembering any
                    # user-provided aliases
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # exp.alias_ is called without copy=False here, so look the explode
                        # up again inside the new alias node
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # NOTE(review): presumably makes an empty/NULL array still produce one
                    # row for the OUTER variant -- confirm
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever aliases the user didn't provide
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # The exploded value is only emitted on the row where the series position
                    # matches the unnest's own position
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is added once, when the first explode is found
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Only the local bound used in the filter is adjusted; `arrays` keeps
                    # the unadjusted size
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, plus the rows where the series has
                    # run past this array and the unnest sits at its last position
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the largest of the exploded arrays
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
350
351
# Percentile expression types handled by the WITHIN GROUP transforms defined below.
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
353
354
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP clause.

    The first argument becomes the ORDER BY key of the WITHIN GROUP and the second argument
    becomes the percentile's only argument. Percentiles that already sit inside a WITHIN
    GROUP, or that have no second argument, are left as they are.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new WITHIN GROUP node when a rewrite happened, otherwise the input expression.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    fraction = expression.expression.pop()
    expression.set("this", fraction)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
368
369
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a percentile's WITHIN GROUP clause with an APPROX_QUANTILE call.

    The percentile's argument becomes the quantile and the first ordered expression found
    in the clause becomes the input value.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The result of replacing the WITHIN GROUP node, or the input expression when it is
        not a percentile WITHIN GROUP with an ORDER BY.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)

    return expression.replace(approx_quantile)
382
383
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs that already declare columns are skipped. For the rest, the column names come from
    the output names of the CTE query's projections (the left-hand side is used when the
    query is a UNION); projections without a name get one from a shared "_c_" sequence.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place when it is a recursive WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
401
402
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts to temporal types by the equivalent date literal.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; the cast's operand is swapped in place when it applies.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
413
414
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SEMI and ANTI joins that have an ON condition as [NOT] EXISTS filters.

    Each such join is removed from the SELECT and a `SELECT 1 FROM <joined source> WHERE <on>`
    subquery wrapped in EXISTS (negated for ANTI) is added to the WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; SELECTs are mutated in place, anything else is untouched.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        predicate = exp.Exists(this=probe)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
430
431
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with exactly one FULL OUTER join as a union of two copies of the query,
    one using a LEFT join and the other a RIGHT join.

    Queries with zero or several FULL OUTER joins are returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of the two rewritten queries, or the input expression when no rewrite
        applies.
    """
    if not isinstance(expression, exp.Select):
        return expression

    candidates = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            candidates.append((position, join))

    if len(candidates) != 1:
        return expression

    position, left_join = candidates[0]

    # The copy is taken before the limit is cleared, so only the copy keeps the LIMIT.
    right_query = expression.copy()
    expression.set("limit", None)

    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
456
457
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the root expression whose nested WITH clauses will be hoisted.

    Returns:
        The same expression, mutated in place.
    """
    top_level_with = expression.args.get("with")
    for node in expression.find_all(exp.With):
        # The WITH clause attached directly to the root is already at the top level
        if node.parent is expression:
            continue

        # Detach the nested WITH clause from its parent
        inner_with = node.pop()
        if not top_level_with:
            # No top-level WITH yet: the first nested one found becomes it
            top_level_with = inner_with
            expression.set("with", top_level_with)
        else:
            # One recursive nested WITH is enough to mark the merged clause as recursive
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # The hoisted CTEs are placed before the ones already at the top level
            top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions)

    return expression
486
487
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used as conditions into explicit `<> 0` comparisons.

    Every node of the tree is handed to the optimizer's `ensure_bools` helper together with
    a callback that rewrites numbers, nodes of numeric or unknown type, and untyped columns.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    # Intentionally shadows this function's own name inside its body.
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _compare_to_zero(candidate: exp.Expression) -> None:
        numeric_like = (
            candidate.is_number
            or candidate.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(candidate, exp.Column) and not candidate.type)
        )
        if numeric_like:
            candidate.replace(candidate.neq(0))

    for descendant in expression.walk():
        ensure_bools(descendant, _compare_to_zero)

    return expression
504
505
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Remove every qualifier from the columns found in the expression, keeping only the last
    part of each column.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    for column in expression.find_all(exp.Column):
        # Everything but the last part is a qualifier (table, db, catalog).
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
513
514
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parent node of every unique column constraint found in a CREATE statement.

    Args:
        expression: the CREATE expression that will be transformed (asserted).

    Returns:
        The same expression, mutated in place.
    """
    assert isinstance(expression, exp.Create)

    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        holder = constraint.parent
        if holder:
            # The node wrapping the constraint is dropped, not just the constraint itself.
            holder.pop()

    return expression
522
523
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map the creation of temporary tables onto temporary views.

    A CREATE TABLE carrying a temporary property becomes CREATE TEMPORARY VIEW when it has a
    defining query (CTAS); without one, the statement is delegated to `tmp_storage_provider`.

    Args:
        expression: the CREATE expression that will be transformed (asserted).
        tmp_storage_provider: callback applied to temporary CREATE TABLE statements that have
            no defining query; the default returns them unchanged.

    Returns:
        The new CREATE TEMPORARY VIEW node, the callback's result, or the input expression
        when it is not a temporary table creation.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    declared_properties = properties.expressions if properties else []

    is_temporary = False
    for prop in declared_properties:
        if isinstance(prop, exp.TemporaryProperty):
            is_temporary = True
            break

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
546
547
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names rather than a schema, the matching column
    definitions are moved out of the CREATE statement's schema and into the property.

    Column names are compared case-insensitively. Only CREATE TABLE / CREATE VIEW statements
    that have a schema are affected.

    Args:
        expression: the CREATE expression that will be transformed (asserted).

    Returns:
        The same expression, mutated in place.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {value.name.upper() for value in prop.this.expressions}

    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining = [e for e in schema.expressions if e not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression
569
570
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made only of typed column definitions, those
    definitions are appended to the CREATE statement's schema and the property is reduced
    to a tuple of the column identifiers.

    Args:
        expression: the CREATE expression that will be transformed (asserted).

    Returns:
        The same expression, mutated in place.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in partition_defs):
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )

    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)

    return expression
594
595
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Convert key-value struct arguments into aliases, e.g. STRUCT(1 AS y).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; structs are mutated in place, anything else is untouched.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            # key = value  ->  value AS key
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression
608
609
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression has neither
            a "_sql" method nor a `Generator.TRANSFORMS` entry, or when dispatching through
            `Generator.TRANSFORMS` would re-enter the same transform (see below).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # A single uniform loop, so an empty `transforms` is a no-op instead of an IndexError
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY keys that name a select alias into that projection's ordinal position.

    Only a GROUP BY attached to a SELECT is touched, and only unqualified column keys whose
    name matches an aliased projection are rewritten (positions are 1-based).

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group) or not isinstance(expression.parent, exp.Select):
        return expression

    # Alias name -> 1-based position of the aliased projection in the parent SELECT.
    ordinal_by_alias = {}
    for ordinal, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            ordinal_by_alias[projection.alias] = ordinal

    for key in expression.expressions:
        # Qualified columns (t.b) cannot refer to a select alias, so they are left alone.
        if not isinstance(key, exp.Column) or key.table:
            continue

        ordinal = ordinal_by_alias.get(key.name)
        if ordinal is not None:
            key.replace(exp.Literal.number(ordinal))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query gains an aliased ROW_NUMBER() window projection partitioned by the
    DISTINCT ON columns and is wrapped in a subquery aliased "_t"; the new outer query keeps
    only the rows whose row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression: a new outer SELECT when the input is a SELECT whose
        DISTINCT has a tuple-valued ON, otherwise the input expression itself.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach the DISTINCT node; its ON columns become the window's partition keys
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        # Captured before the row-number projection is added below
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The query's ORDER BY is moved into the window
            window.set("order", order.pop())
        else:
            # No ORDER BY: order the window by copies of the DISTINCT ON columns
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 87def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 88    """
 89    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 90
 91    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 92    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 93
 94    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 95    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 96    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 97    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
 98    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
 99    corresponding expression to avoid creating invalid column references.
100    """
101    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
102        taken = set(expression.named_selects)
103        for select in expression.selects:
104            if not select.alias_or_name:
105                alias = find_new_name(taken, "_c")
106                select.replace(exp.alias_(select, alias))
107                taken.add(alias)
108
109        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
110        qualify_filters = expression.args["qualify"].pop().this
111        expression_by_alias = {
112            select.alias: select.this
113            for select in expression.selects
114            if isinstance(select, exp.Alias)
115        }
116
117        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
118        for select_candidate in qualify_filters.find_all(select_candidates):
119            if isinstance(select_candidate, exp.Window):
120                if expression_by_alias:
121                    for column in select_candidate.find_all(exp.Column):
122                        expr = expression_by_alias.get(column.name)
123                        if expr:
124                            column.replace(expr)
125
126                alias = find_new_name(expression.named_selects, "_w")
127                expression.select(exp.alias_(select_candidate, alias), copy=False)
128                column = exp.column(alias)
129
130                if isinstance(select_candidate.parent, exp.Qualify):
131                    qualify_filters = column
132                else:
133                    select_candidate.replace(column)
134            elif select_candidate.name not in expression.named_selects:
135                expression.select(select_candidate.copy(), copy=False)
136
137        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
138            qualify_filters, copy=False
139        )
140
141    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types by
    dropping every `DataTypeParam` child of each `DataType` found in the expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [
            param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Strip unnest table aliases from column references (added by qualify_columns)."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Only unnests sitting directly under FROM or JOIN contribute an alias to strip.
    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one
    `LATERAL VIEW` per (unnested expression, alias column) pair. POSEXPLODE is used when the
    unnest has an offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a SELECT.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it skips the element after each removal, which
        # would leave the second of two consecutive UNNEST joins untouched.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): without an alias the zip is empty, so the join is dropped and
                # no lateral view is emitted in its place -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: used as the start of the generated position series, and to adjust the
            array-size bounds that the series and the filters are compared against.

    Returns:
        A transform that rewrites SELECT expressions in place and returns them; any other
        expression is returned untouched.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases never collide with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name not present in `names` and record it as taken.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per explode found; also used as a "first explode" flag.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(start=index_offset)) aliased as a source with a single
            # column named `series_alias`. Its "end" is filled in after the loop.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an alias node (`alias`), remembering any
                    # user-provided output names.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # exp.alias_ is called without copy=False here, so the explode node is
                        # looked up again inside the replacement.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # NOTE(review): `explode_arg[0]` presumably builds an index (bracket)
                        # expression on the argument -- confirm against Expression.__getitem__.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        # IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(bracket), arg)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate output names for whatever the user did not name explicitly.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # IF(<series pos> = <unnest pos>, <unnest value>): replaces the explode call
                    # inside the projection.
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Add the position as its own projection right after the value one.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # First explode only: attach the series source, as the FROM when the query
                    # has none, otherwise as a CROSS JOIN.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `arrays` keeps the unadjusted size; only the filter bound below is shifted.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series has run past
                    # `size` and the unnest position equals `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the largest array size among all explodes, adjusted for
                # the index offset.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                # series.expressions[0] is presumably the GENERATE_SERIES node built above.
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP (ORDER BY ...) clause.

    The call's first argument becomes the ORDER BY key and its second argument becomes the
    call's only argument. Percentiles already inside a WITHIN GROUP, or without a second
    argument, are returned unchanged.
    """
    needs_within_group = (
        isinstance(expression, PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile>(q) WITHIN GROUP (ORDER BY x)` by an `ApproxQuantile(x, q)` node.

    Only WITHIN GROUP nodes that wrap a percentile function and carry an ORDER BY are
    rewritten; anything else is returned unchanged.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(replacement)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give each CTE of a recursive WITH an explicit column list when it has none.

    The column names are the output names of the CTE's projections (for a UNION, those of its
    left-hand query); projections without a name get a generated `_c_`-prefixed name.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        cte_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Substitute the literal '1970-01-01 00:00:00' for 'epoch' in casts to temporal types."""
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join with an ON condition is removed from the SELECT and replaced by a
    `WHERE [NOT] EXISTS (SELECT 1 FROM <join target> WHERE <on condition>)` filter.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        exists = exp.Exists(this=probe)
        predicate = exists.not_(copy=False) if join.kind == "ANTI" else exists

        join.pop()
        expression.where(predicate, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with a FULL OUTER join as a union of two copies of the query, one using a
    LEFT join and the other a RIGHT join in its place.

    Only queries with exactly one FULL OUTER join are rewritten; all others are returned as-is.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, candidate)
        for position, candidate in enumerate(expression.args.get("joins") or [])
        if candidate.side == "FULL"
    ]
    if len(full_joins) != 1:
        return expression

    # Copy first, so that the right-hand query keeps the LIMIT removed from the left-hand one.
    right_query = expression.copy()
    expression.set("limit", None)

    position, left_join = full_joins[0]
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause to the top level of the expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only accept CTEs at the
    top level, so a query such as

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    is invalid there. Nested CTEs are prepended to the top-level WITH (or become it when there
    is none), and the top-level WITH is marked recursive if any hoisted one was.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    outer_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        if with_node.parent is expression:
            continue

        detached = with_node.pop()

        if outer_with:
            if detached.recursive:
                outer_with.set("recursive", True)

            # Hoisted CTEs go first, ahead of those already at the top level.
            outer_with.set("expressions", detached.expressions + outer_with.expressions)
        else:
            outer_with = detached
            expression.set("with", outer_with)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used as conditions into explicit `<> 0` comparisons.

    Every node of the tree is handed to the optimizer's `canonicalize.ensure_bools`, together
    with a callback that rewrites numbers, numeric- or unknown-typed nodes, and untyped columns.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        is_numeric_like = (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        )
        if is_numeric_like:
            node.replace(node.neq(0))

    for visited in expression.walk():
        _canonicalize_ensure_bools(visited, _to_explicit_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Remove every qualifier (all parts except the last) from each column in the expression."""
    for column in expression.find_all(exp.Column):
        # Everything before the final part is a table, db or catalog qualifier.
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove the parent node of every unique column constraint found in a CREATE statement."""
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE statements onto temporary views.

    Args:
        expression: the CREATE statement to transform.
        tmp_storage_provider: callback applied to temporary tables that have no defining
            query; it returns the expression to use instead. Defaults to the identity.

    Returns:
        A `CREATE TEMPORARY VIEW` for a temporary CTAS, the callback's result for any other
        temporary table, and the input unchanged otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    declared = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in declared)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move partition columns out of a table's schema and into its PARTITIONED BY property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    property's value is a list of column names rather than a schema, the matching column
    definitions (compared case-insensitively by name) are removed from the CREATE statement's
    schema and become the schema of the PARTITIONED BY property.
    """
    assert isinstance(expression, exp.Create)

    if not (isinstance(expression.this, exp.Schema) and expression.kind in {"TABLE", "VIEW"}):
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and not isinstance(prop.this, exp.Schema)):
        return expression

    schema = expression.this
    partition_names = {item.name.upper() for item in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining = [column for column in schema.expressions if column not in partition_columns]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the table's schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made only of column definitions that have
    a type, those definitions are appended to the CREATE statement's schema and the property
    is reduced to a tuple of the column identifiers.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    holds_typed_columns = (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(item, exp.ColumnDef) and item.kind for item in prop.this.expressions)
    )
    if not holds_typed_columns:
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(item.this) for item in prop.this.expressions]
    )
    schema = expression.this
    for column_def in prop.this.expressions:
        schema.append("expressions", column_def)
    prop.set("this", partition_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Rewrite `key = value` struct arguments as aliases, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for argument in expression.expressions:
        if isinstance(argument, exp.PropertyEQ):
            converted.append(exp.alias_(argument.expression, argument.this))
        else:
            converted.append(argument)

    expression.set("expressions", converted)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has neither
            a "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input type in order to detect transforms that did not change it.
        expression_type = type(expression)

        # A plain loop, rather than special-casing transforms[0], also handles an empty list.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.