Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects, copy=False)
 79            .from_(expression.subquery("_t", copy=False), copy=False)
 80            .where(exp.column(row_number).eq(1), copy=False)
 81        )
 82
 83    return expression
 84
 85
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression: a new outer SELECT when a QUALIFY clause was present,
        otherwise the input expression unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a generated alias so the outer query can select it by name
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer query projects the original selections by name; it is built before any
        # helper projections are appended to the inner query below
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])

        # Detach the QUALIFY clause and keep only its condition
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection only window functions are considered; otherwise plain columns
        # referenced by the condition are candidates for projection as well
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window under a fresh alias and refer to it through that alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window is the entire condition, so the alias column becomes the filter
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Column used by the condition but not selected: expose it from the subquery
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
128
129
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the parameters (e.g. precision) from every parameterized data type in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes it from the types found in `expression`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression
141
142
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed from the join list and replaced
    by one LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is used instead
    of EXPLODE when the UNNEST carries an offset.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it would skip the join right after each removal
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # Without an alias there are no columns to pair with, so no lateral is emitted
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
166
167
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Args:
        index_offset: the start value of the generated position series. It also controls
            whether array sizes are shifted by one before being compared against positions.

    Returns:
        A transform that rewrites the [POS]EXPLODE projections of a SELECT in place and
        returns the (mutated) expression; other expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so that generated aliases don't clash with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name based on `name` that isn't in `names` yet and reserves it
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Size expressions of all exploded arrays; used for the series' upper bound
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")

            # Position series source shared by all explodes. Only its start is known at this
            # point; its "end" is filled in after the loop
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Make sure the projection is wrapped in a single Alias node (`alias`),
                    # remembering any user-provided names for the value and position columns
                    if isinstance(select, exp.Alias):
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # Look the explode up again inside the new wrapper
                        # NOTE(review): presumably needed because alias_ copies its input - confirm
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever of the value / position names the query didn't provide
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only produced when the series position equals
                    # this unnest's own position
                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also yields the position, so add it as a projection right
                    # after the value projection
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # First explode encountered: attach the position series as a source
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Note that `arrays` keeps the unshifted size; only the local `size` used
                    # in the filter below is adjusted
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches this unnest's position, or
                    # where the series has moved past `size` and the unnest sits at `size`
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            # The series' upper bound is derived from the largest of the exploded array sizes
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
297
298
# Percentile expression types that the WITHIN GROUP transforms below operate on
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
300
301
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile in a WITHIN GROUP clause.

    The percentile's first argument is moved into the ORDER BY of the new WITHIN GROUP node and
    its second argument takes the first one's place. Percentiles that already sit inside a
    WITHIN GROUP, or that have no second argument, are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new WITHIN GROUP expression, or the input expression if nothing was changed.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)
315
316
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a percentile's WITHIN GROUP clause with an approximate quantile call.

    The value being ordered becomes the input of the new `ApproxQuantile` node and the
    percentile's own argument becomes its quantile.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement expression, or the input expression if it isn't a percentile WITHIN GROUP.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)

    return expression.replace(approx_quantile)
329
330
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Fill in missing column lists of CTEs that belong to a recursive WITH clause.

    Column names are taken from the output names of the CTE query's projections (for a union,
    those of its left-hand side); projections without a name get one generated with the
    "_c_" prefix.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # A single sequence is shared by all CTEs of this WITH clause
    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        definition = cte.this
        if isinstance(definition, exp.Union):
            definition = definition.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in definition.selects
        ]
        table_alias.set("columns", column_names)

    return expression
348
349
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts to temporal types by the equivalent date literal.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    # The name check is case-insensitive; the target type is only inspected if it matches
    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
360
361
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT, and an
    `EXISTS (SELECT 1 FROM <join target> WHERE <on>)` predicate (negated for ANTI joins) is
    added to the SELECT's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        predicate = exp.Exists(this=exp.select("1").from_(join.this).where(condition))
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
377
378
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with one FULL OUTER join as a union of a LEFT and a RIGHT OUTER join query.

    The input query keeps its shape but its join side is switched to LEFT; a copy of it with
    the side switched to RIGHT (and without its WITH clause) forms the other branch of the
    union. Queries with zero or more than one FULL OUTER join are returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of both queries, or the input expression if nothing was changed.
    """
    if not isinstance(expression, exp.Select):
        return expression

    candidates = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL" and join.kind == "OUTER"
    ]
    if len(candidates) != 1:
        return expression

    # The copy is taken before either side is changed
    right_query = expression.copy()
    position, left_join = candidates[0]

    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
402
403
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of the given expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    hoisted = expression.args.get("with")

    for with_clause in expression.find_all(exp.With):
        # The expression's own WITH clause is already where it needs to be
        if with_clause.parent is expression:
            continue

        nested = with_clause.pop()

        if hoisted:
            # Merge into the existing top-level clause, which turns recursive if any part is
            if nested.recursive:
                hoisted.set("recursive", True)

            hoisted.expressions.extend(nested.expressions)
            continue

        # No top-level WITH yet: the first nested clause found takes that role
        hoisted = nested
        expression.set("with", hoisted)

    return expression
432
433
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric-looking operands into explicit `<> 0` comparisons across the whole tree.

    Every node visited by `expression.walk()` is handed to the optimizer's `ensure_bools`
    helper together with a callback; which operands that helper passes to the callback is
    decided there. The callback replaces an operand with `operand.neq(0)` when it is a number,
    has an unknown or numeric type, or is a column without a type.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    # Aliased because the helper has the same name as this function
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _to_comparison(operand: exp.Expression) -> None:
        numeric_like = (
            operand.is_number
            or operand.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(operand, exp.Column) and not operand.type)
        )
        if numeric_like:
            operand.replace(operand.neq(0))

    for visited, *_ in expression.walk():
        _canonicalize_bools(visited, _to_comparison)

    return expression
449
450
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as-is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # A single loop handles any number of transforms (including none). The loop variable
        # is deliberately not called `t`, which is this module's alias for `typing`
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the position of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    # Only GROUP BY clauses that hang directly off a SELECT are rewritten
    if not isinstance(expression, exp.Group) or not isinstance(expression.parent, exp.Select):
        return expression

    # Maps each projection alias to its 1-based position in the SELECT list
    alias_positions = {
        projection.alias: position
        for position, projection in enumerate(expression.parent.expressions, start=1)
        if isinstance(projection, exp.Alias)
    }

    for key in expression.expressions:
        # Table-qualified columns can't be references to a select alias
        if not isinstance(key, exp.Column) or key.table:
            continue

        if key.name in alias_positions:
            key.replace(exp.Literal.number(alias_positions.get(key.name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as an outer query over a subquery that numbers rows per group.

    The DISTINCT ON columns become the partition of a ROW_NUMBER() window that is added to the
    original query as an extra projection; the outer query then keeps only the rows numbered 1.
    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach the DISTINCT clause from the query; its ON columns drive the window partition
    distinct.pop()
    partition_keys = on.expressions

    projections = expression.selects
    rank_alias = find_new_name(expression.named_selects, "_row_number")
    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_keys)

    # Reuse the query's ORDER BY when present, otherwise order by the DISTINCT ON columns
    ordering = expression.args.get("order")
    if ordering:
        ranking.set("order", ordering.pop())
    else:
        ranking.set("order", exp.Order(expressions=[key.copy() for key in partition_keys]))

    expression.select(exp.alias_(ranking, rank_alias), copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(rank_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression: a new outer SELECT when a QUALIFY clause was present,
        otherwise the input expression unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a generated alias so the outer query can select it by name
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer query projects the original selections by name; it is built before any
        # helper projections are appended to the inner query below
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])

        # Detach the QUALIFY clause and keep only its condition
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection only window functions are considered; otherwise plain columns
        # referenced by the condition are candidates for projection as well
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window under a fresh alias and refer to it through that alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window is the entire condition, so the alias column becomes the filter
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Column used by the condition but not selected: expose it from the subquery
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the parameters (e.g. precision) from every parameterized data type in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes it from the types found in `expression`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed from the join list and replaced
    by one LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is used instead
    of EXPLODE when the UNNEST carries an offset.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it would skip the join right after each removal
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # Without an alias there are no columns to pair with, so no lateral is emitted
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that converts explode/posexplode into unnest (used in hive -> presto).

    Args:
        index_offset: start value of the generated position series that the unnested arrays
            are aligned against (presumably the target dialect's array index base -- confirm
            against callers).

    Returns:
        A function that rewrites `exp.Select` nodes in place and returns any other
        expression unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so generated aliases can avoid them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Picks a name via find_new_name (presumably one not present in `names`) and
            # records it in `names` so later calls see it as taken
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded projection; used at the bottom to set
            # the end of the generated series
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset, ...)) aliased as a table with a single
            # column; the series' "end" is only set once all exploded arrays are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        # A single alias names the exploded value
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in an alias node (named further
                        # below) and look the explode up again inside that new node
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever of the value / position aliases were not provided
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # IF(<series position> = <array position>, <exploded value>): the value
                    # is only projected on rows where both positions match
                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position: add it as an extra
                        # projection right after the exploded value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached once, when the first explode is found:
                    # joined if the SELECT already has a FROM, otherwise used as the FROM
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # For any offset other than 1 the filter below compares positions
                    # against ARRAY_SIZE - 1 instead of ARRAY_SIZE
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series and array positions match, plus rows where
                    # the series is already past `size` and the array position equals `size`
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest array size, shifted by (1 - index_offset)
                # unless the offset is 1
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest (used in hive -> presto).

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's first argument becomes the ORDER BY value of the new WITHIN GROUP
    clause and its second argument becomes its only remaining argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new `exp.WithinGroup` wrapping the percentile, or the expression unchanged if it
        isn't a two-argument percentile or is already inside a WITHIN GROUP.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    remaining_arg = expression.expression.pop()
    expression.set("this", remaining_arg)

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_value)]),
    )

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        An `exp.ApproxQuantile` that replaces the WITHIN GROUP node, or the expression
        unchanged if it isn't a percentile with an ORDER BY inside a WITHIN GROUP.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not (isinstance(percentile, PERCENTILES) and isinstance(expression.expression, exp.Order)):
        return expression

    quantile = percentile.this
    # The value being aggregated is taken from the first ordered term that is found
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)

    return expression.replace(replacement)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a recursive WITH whose alias has no column list are touched. Projections
    without an output name get a generated name with the "_c_" prefix.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        # For a set operation, the names come from its left-hand side
        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; if it is a (try) cast of 'epoch' (case-insensitive) to a
        temporal type, its operand is replaced by the string '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced
    by a `[NOT] EXISTS (SELECT 1 FROM <join target> WHERE <on>)` predicate that is added to
    the SELECT's WHERE clause. Joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Walk a snapshot of the joins: join.pop() detaches joins from this SELECT while
        # they are being traversed, and iterating a copy guarantees that every join is
        # visited even if detaching a node mutates the underlying list in place.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with a FULL OUTER join as a union of two copies of the query, one using
    a LEFT OUTER join and the other a RIGHT OUTER join.

    This transformation currently only works for queries that have a single FULL OUTER
    join; any other expression is returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of the LEFT and RIGHT variants, or the original expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    candidates = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL" and join.kind == "OUTER":
            candidates.append((position, join))

    if len(candidates) != 1:
        return expression

    # The copy must be taken before the original join is turned into a LEFT join
    right_query = expression.copy()
    position, left_join = candidates[0]

    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place: every nested WITH is detached and its CTEs
        are appended to the top-level WITH (which is created from the first nested WITH
        if the expression has none).
    """
    top_level_with = expression.args.get("with")
    for node in expression.find_all(exp.With):
        # The top-level WITH itself stays where it is
        if node.parent is expression:
            continue

        inner_with = node.pop()
        if not top_level_with:
            top_level_with = inner_with
            expression.set("with", top_level_with)
        else:
            # A single recursive CTE makes the whole merged WITH recursive
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Go through set() instead of extending the list in place, so that the moved
            # CTEs are registered as children of the top-level WITH rather than keeping a
            # reference to the detached inner WITH as their parent.
            top_level_with.set(
                "expressions", top_level_with.expressions + inner_with.expressions
            )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric or untyped operands into explicit `<> 0` comparisons.

    Every node of the tree is handed to `sqlglot.optimizer.canonicalize.ensure_bools`
    together with a callback; the nodes that helper passes to the callback are presumably
    the ones used in boolean contexts (see that helper to confirm).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_node

    def _compare_to_zero(candidate: exp.Expression) -> None:
        # Number literals, nodes typed as numeric / unknown, and columns without a type
        # are rewritten to `candidate <> 0`
        looks_numeric = candidate.is_number or candidate.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        )
        if looks_numeric or (isinstance(candidate, exp.Column) and not candidate.type):
            candidate.replace(candidate.neq(0))

    for visited, *_ in expression.walk():
        _canonicalize_node(visited, _compare_to_zero)

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression has neither
            a "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # A single loop also covers an empty `transforms` list, and the loop variable no
        # longer shadows the module-level `t` (typing) alias.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.